diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..76547e9ee --- /dev/null +++ b/.dockerignore @@ -0,0 +1,3 @@ +bin +*.wasm +.git \ No newline at end of file diff --git a/.editorconfig b/.editorconfig deleted file mode 100644 index 1fee56179..000000000 --- a/.editorconfig +++ /dev/null @@ -1,18 +0,0 @@ -root = true - -[*] -charset = utf-8 - -end_of_line = lf -insert_final_newline = true -trim_trailing_whitespace = true - -[*.go] -indent_style = tab -indent_size = 4 - -[*.md] -trim_trailing_whitespace = false - -[*.{yml,yaml}] -indent_style = space diff --git a/.gitignore b/.gitignore index 1de8887ce..34c5b6805 100644 --- a/.gitignore +++ b/.gitignore @@ -49,3 +49,6 @@ dendrite.yaml # Log files *.log* + +# Generated code +cmd/dendrite-demo-yggdrasil/embed/fs*.go diff --git a/README.md b/README.md index 49f9ca840..aa9060f1b 100644 --- a/README.md +++ b/README.md @@ -3,15 +3,15 @@ Dendrite will be a second-generation Matrix homeserver written in Go. It's still very much a work in progress, but installation instructions can be -found in [INSTALL.md](INSTALL.md). It is not recommended to use Dendrite as a +found in [INSTALL.md](docs/INSTALL.md). It is not recommended to use Dendrite as a production homeserver at this time. -An overview of the design can be found in [DESIGN.md](DESIGN.md). +An overview of the design can be found in [DESIGN.md](docs/DESIGN.md). # Contributing Everyone is welcome to help out and contribute! See -[CONTRIBUTING.md](CONTRIBUTING.md) to get started! +[CONTRIBUTING.md](docs/CONTRIBUTING.md) to get started! Please note that, as of February 2020, Dendrite now only targets Go 1.13 or later. Please ensure that you are using at least Go 1.13 when developing for diff --git a/appservice/api/query.go b/appservice/api/query.go index afd5c5d76..29e374aca 100644 --- a/appservice/api/query.go +++ b/appservice/api/query.go @@ -20,16 +20,11 @@ package api import ( "context" "database/sql" - "errors" - "net/http" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" - - "github.com/matrix-org/dendrite/common" - commonHTTP "github.com/matrix-org/dendrite/common/http" - opentracing "github.com/opentracing/opentracing-go" ) // RoomAliasExistsRequest is a request to an application service @@ -83,60 +78,9 @@ type AppServiceQueryAPI interface { ) error } -// AppServiceRoomAliasExistsPath is the HTTP path for the RoomAliasExists API -const AppServiceRoomAliasExistsPath = "/api/appservice/RoomAliasExists" - -// AppServiceUserIDExistsPath is the HTTP path for the UserIDExists API -const AppServiceUserIDExistsPath = "/api/appservice/UserIDExists" - -// httpAppServiceQueryAPI contains the URL to an appservice query API and a -// reference to a httpClient used to reach it -type httpAppServiceQueryAPI struct { - appserviceURL string - httpClient *http.Client -} - -// NewAppServiceQueryAPIHTTP creates a AppServiceQueryAPI implemented by talking -// to a HTTP POST API. -// If httpClient is nil an error is returned -func NewAppServiceQueryAPIHTTP( - appserviceURL string, - httpClient *http.Client, -) (AppServiceQueryAPI, error) { - if httpClient == nil { - return nil, errors.New("NewRoomserverAliasAPIHTTP: httpClient is ") - } - return &httpAppServiceQueryAPI{appserviceURL, httpClient}, nil -} - -// RoomAliasExists implements AppServiceQueryAPI -func (h *httpAppServiceQueryAPI) RoomAliasExists( - ctx context.Context, - request *RoomAliasExistsRequest, - response *RoomAliasExistsResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "appserviceRoomAliasExists") - defer span.Finish() - - apiURL := h.appserviceURL + AppServiceRoomAliasExistsPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -// UserIDExists implements AppServiceQueryAPI -func (h *httpAppServiceQueryAPI) UserIDExists( - ctx context.Context, - request *UserIDExistsRequest, - response *UserIDExistsResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "appserviceUserIDExists") - defer span.Finish() - - apiURL := h.appserviceURL + AppServiceUserIDExistsPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - // RetrieveUserProfile is a wrapper that queries both the local database and // application services for a given user's profile +// TODO: Remove this, it's called from federationapi and clientapi but is a pure function func RetrieveUserProfile( ctx context.Context, userID string, @@ -165,7 +109,7 @@ func RetrieveUserProfile( // If no user exists, return if !userResp.UserIDExists { - return nil, common.ErrProfileNoExists + return nil, eventutil.ErrProfileNoExists } // Try to query the user from the local database again diff --git a/appservice/appservice.go b/appservice/appservice.go index 181799879..728690414 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -20,36 +20,35 @@ import ( "sync" "time" + "github.com/gorilla/mux" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/appservice/consumers" + "github.com/matrix-org/dendrite/appservice/inthttp" "github.com/matrix-org/dendrite/appservice/query" - "github.com/matrix-org/dendrite/appservice/routing" "github.com/matrix-org/dendrite/appservice/storage" "github.com/matrix-org/dendrite/appservice/types" "github.com/matrix-org/dendrite/appservice/workers" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" - "github.com/matrix-org/dendrite/common/basecomponent" - "github.com/matrix-org/dendrite/common/config" - "github.com/matrix-org/dendrite/common/transactions" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/setup" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/gomatrixserverlib" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/sirupsen/logrus" ) -// SetupAppServiceAPIComponent sets up and registers HTTP handlers for the AppServices -// component. -func SetupAppServiceAPIComponent( - base *basecomponent.BaseDendrite, - accountsDB accounts.Database, - deviceDB devices.Database, - federation *gomatrixserverlib.FederationClient, - roomserverAliasAPI roomserverAPI.RoomserverAliasAPI, - roomserverQueryAPI roomserverAPI.RoomserverQueryAPI, - transactionsCache *transactions.Cache, +// AddInternalRoutes registers HTTP handlers for internal API calls +func AddInternalRoutes(router *mux.Router, queryAPI appserviceAPI.AppServiceQueryAPI) { + inthttp.AddRoutes(queryAPI, router) +} + +// NewInternalAPI returns a concerete implementation of the internal API. Callers +// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. +func NewInternalAPI( + base *setup.BaseDendrite, + userAPI userapi.UserInternalAPI, + rsAPI roomserverAPI.RoomserverInternalAPI, ) appserviceAPI.AppServiceQueryAPI { // Create a connection to the appservice postgres DB - appserviceDB, err := storage.NewDatabase(string(base.Cfg.Database.AppService)) + appserviceDB, err := storage.NewDatabase(string(base.Cfg.Database.AppService), base.Cfg.DbProperties()) if err != nil { logrus.WithError(err).Panicf("failed to connect to appservice db") } @@ -67,7 +66,7 @@ func SetupAppServiceAPIComponent( workerStates[i] = ws // Create bot account for this AS if it doesn't already exist - if err = generateAppServiceAccount(accountsDB, deviceDB, appservice); err != nil { + if err = generateAppServiceAccount(userAPI, appservice); err != nil { logrus.WithFields(logrus.Fields{ "appservice": appservice.ID, }).WithError(err).Panicf("failed to generate bot account for appservice") @@ -76,57 +75,55 @@ func SetupAppServiceAPIComponent( // Create appserivce query API with an HTTP client that will be used for all // outbound and inbound requests (inbound only for the internal API) - appserviceQueryAPI := query.AppServiceQueryAPI{ + appserviceQueryAPI := &query.AppServiceQueryAPI{ HTTPClient: &http.Client{ Timeout: time.Second * 30, }, Cfg: base.Cfg, } - appserviceQueryAPI.SetupHTTP(http.DefaultServeMux) - - consumer := consumers.NewOutputRoomEventConsumer( - base.Cfg, base.KafkaConsumer, accountsDB, appserviceDB, - roomserverQueryAPI, roomserverAliasAPI, workerStates, - ) - if err := consumer.Start(); err != nil { - logrus.WithError(err).Panicf("failed to start appservice roomserver consumer") + // Only consume if we actually have ASes to track, else we'll just chew cycles needlessly. + // We can't add ASes at runtime so this is safe to do. + if len(workerStates) > 0 { + consumer := consumers.NewOutputRoomEventConsumer( + base.Cfg, base.KafkaConsumer, appserviceDB, + rsAPI, workerStates, + ) + if err := consumer.Start(); err != nil { + logrus.WithError(err).Panicf("failed to start appservice roomserver consumer") + } } // Create application service transaction workers if err := workers.SetupTransactionWorkers(appserviceDB, workerStates); err != nil { logrus.WithError(err).Panicf("failed to start app service transaction workers") } - - // Set up HTTP Endpoints - routing.Setup( - base.APIMux, base.Cfg, roomserverQueryAPI, roomserverAliasAPI, - accountsDB, federation, transactionsCache, - ) - - return &appserviceQueryAPI + return appserviceQueryAPI } // generateAppServiceAccounts creates a dummy account based off the // `sender_localpart` field of each application service if it doesn't // exist already func generateAppServiceAccount( - accountsDB accounts.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, as config.ApplicationService, ) error { - ctx := context.Background() - - // Create an account for the application service - acc, err := accountsDB.CreateAccount(ctx, as.SenderLocalpart, "", as.ID) + var accRes userapi.PerformAccountCreationResponse + err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{ + AccountType: userapi.AccountTypeUser, + Localpart: as.SenderLocalpart, + AppServiceID: as.ID, + OnConflict: userapi.ConflictUpdate, + }, &accRes) if err != nil { return err - } else if acc == nil { - // This account already exists - return nil } - - // Create a dummy device with a dummy token for the application service - _, err = deviceDB.CreateDevice(ctx, as.SenderLocalpart, nil, as.ASToken, &as.SenderLocalpart) + var devRes userapi.PerformDeviceCreationResponse + err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{ + Localpart: as.SenderLocalpart, + AccessToken: as.ASToken, + DeviceID: &as.SenderLocalpart, + DeviceDisplayName: &as.SenderLocalpart, + }, &devRes) return err } diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index 6ae58e85c..4c0156b2c 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -20,23 +20,20 @@ import ( "github.com/matrix-org/dendrite/appservice/storage" "github.com/matrix-org/dendrite/appservice/types" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/Shopify/sarama" log "github.com/sirupsen/logrus" - sarama "gopkg.in/Shopify/sarama.v1" ) // OutputRoomEventConsumer consumes events that originated in the room server. type OutputRoomEventConsumer struct { - roomServerConsumer *common.ContinualConsumer - db accounts.Database + roomServerConsumer *internal.ContinualConsumer asDB storage.Database - query api.RoomserverQueryAPI - alias api.RoomserverAliasAPI + rsAPI api.RoomserverInternalAPI serverName string workerStates []types.ApplicationServiceWorkerState } @@ -46,23 +43,19 @@ type OutputRoomEventConsumer struct { func NewOutputRoomEventConsumer( cfg *config.Dendrite, kafkaConsumer sarama.Consumer, - store accounts.Database, appserviceDB storage.Database, - queryAPI api.RoomserverQueryAPI, - aliasAPI api.RoomserverAliasAPI, + rsAPI api.RoomserverInternalAPI, workerStates []types.ApplicationServiceWorkerState, ) *OutputRoomEventConsumer { - consumer := common.ContinualConsumer{ + consumer := internal.ContinualConsumer{ Topic: string(cfg.Kafka.Topics.OutputRoomEvent), Consumer: kafkaConsumer, - PartitionStore: store, + PartitionStore: appserviceDB, } s := &OutputRoomEventConsumer{ roomServerConsumer: &consumer, - db: store, asDB: appserviceDB, - query: queryAPI, - alias: aliasAPI, + rsAPI: rsAPI, serverName: string(cfg.Matrix.ServerName), workerStates: workerStates, } @@ -94,60 +87,13 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { return nil } - ev := output.NewRoomEvent.Event - log.WithFields(log.Fields{ - "event_id": ev.EventID(), - "room_id": ev.RoomID(), - "type": ev.Type(), - }).Info("appservice received an event from roomserver") - - missingEvents, err := s.lookupMissingStateEvents(output.NewRoomEvent.AddsStateEventIDs, ev) - if err != nil { - return err - } - events := append(missingEvents, ev) + events := []gomatrixserverlib.HeaderedEvent{output.NewRoomEvent.Event} + events = append(events, output.NewRoomEvent.AddStateEvents...) // Send event to any relevant application services return s.filterRoomserverEvents(context.TODO(), events) } -// lookupMissingStateEvents looks up the state events that are added by a new event, -// and returns any not already present. -func (s *OutputRoomEventConsumer) lookupMissingStateEvents( - addsStateEventIDs []string, event gomatrixserverlib.HeaderedEvent, -) ([]gomatrixserverlib.HeaderedEvent, error) { - // Fast path if there aren't any new state events. - if len(addsStateEventIDs) == 0 { - return []gomatrixserverlib.HeaderedEvent{}, nil - } - - // Fast path if the only state event added is the event itself. - if len(addsStateEventIDs) == 1 && addsStateEventIDs[0] == event.EventID() { - return []gomatrixserverlib.HeaderedEvent{}, nil - } - - result := []gomatrixserverlib.HeaderedEvent{} - missing := []string{} - for _, id := range addsStateEventIDs { - if id != event.EventID() { - // If the event isn't the current one, add it to the list of events - // to retrieve from the roomserver - missing = append(missing, id) - } - } - - // Request the missing events from the roomserver - eventReq := api.QueryEventsByIDRequest{EventIDs: missing} - var eventResp api.QueryEventsByIDResponse - if err := s.query.QueryEventsByID(context.TODO(), &eventReq, &eventResp); err != nil { - return nil, err - } - - result = append(result, eventResp.Events...) - - return result, nil -} - // filterRoomserverEvents takes in events and decides whether any of them need // to be passed on to an external application service. It does this by checking // each namespace of each registered application service, and if there is a @@ -200,7 +146,7 @@ func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Cont // Check all known room aliases of the room the event came from queryReq := api.GetAliasesForRoomIDRequest{RoomID: event.RoomID()} var queryRes api.GetAliasesForRoomIDResponse - if err := s.alias.GetAliasesForRoomID(ctx, &queryReq, &queryRes); err == nil { + if err := s.rsAPI.GetAliasesForRoomID(ctx, &queryReq, &queryRes); err == nil { for _, alias := range queryRes.Aliases { if appservice.IsInterestedInRoomAlias(alias) { return true diff --git a/appservice/inthttp/client.go b/appservice/inthttp/client.go new file mode 100644 index 000000000..7e3cb208f --- /dev/null +++ b/appservice/inthttp/client.go @@ -0,0 +1,63 @@ +package inthttp + +import ( + "context" + "errors" + "net/http" + + "github.com/matrix-org/dendrite/appservice/api" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/opentracing/opentracing-go" +) + +// HTTP paths for the internal HTTP APIs +const ( + AppServiceRoomAliasExistsPath = "/appservice/RoomAliasExists" + AppServiceUserIDExistsPath = "/appservice/UserIDExists" +) + +// httpAppServiceQueryAPI contains the URL to an appservice query API and a +// reference to a httpClient used to reach it +type httpAppServiceQueryAPI struct { + appserviceURL string + httpClient *http.Client +} + +// NewAppserviceClient creates a AppServiceQueryAPI implemented by talking +// to a HTTP POST API. +// If httpClient is nil an error is returned +func NewAppserviceClient( + appserviceURL string, + httpClient *http.Client, +) (api.AppServiceQueryAPI, error) { + if httpClient == nil { + return nil, errors.New("NewRoomserverAliasAPIHTTP: httpClient is ") + } + return &httpAppServiceQueryAPI{appserviceURL, httpClient}, nil +} + +// RoomAliasExists implements AppServiceQueryAPI +func (h *httpAppServiceQueryAPI) RoomAliasExists( + ctx context.Context, + request *api.RoomAliasExistsRequest, + response *api.RoomAliasExistsResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "appserviceRoomAliasExists") + defer span.Finish() + + apiURL := h.appserviceURL + AppServiceRoomAliasExistsPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// UserIDExists implements AppServiceQueryAPI +func (h *httpAppServiceQueryAPI) UserIDExists( + ctx context.Context, + request *api.UserIDExistsRequest, + response *api.UserIDExistsResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "appserviceUserIDExists") + defer span.Finish() + + apiURL := h.appserviceURL + AppServiceUserIDExistsPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/appservice/inthttp/server.go b/appservice/inthttp/server.go new file mode 100644 index 000000000..009b7b5db --- /dev/null +++ b/appservice/inthttp/server.go @@ -0,0 +1,43 @@ +package inthttp + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/appservice/api" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/util" +) + +// AddRoutes adds the AppServiceQueryAPI handlers to the http.ServeMux. +func AddRoutes(a api.AppServiceQueryAPI, internalAPIMux *mux.Router) { + internalAPIMux.Handle( + AppServiceRoomAliasExistsPath, + httputil.MakeInternalAPI("appserviceRoomAliasExists", func(req *http.Request) util.JSONResponse { + var request api.RoomAliasExistsRequest + var response api.RoomAliasExistsResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := a.RoomAliasExists(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle( + AppServiceUserIDExistsPath, + httputil.MakeInternalAPI("appserviceUserIDExists", func(req *http.Request) util.JSONResponse { + var request api.UserIDExistsRequest + var response api.UserIDExistsResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := a.UserIDExists(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) +} diff --git a/appservice/query/query.go b/appservice/query/query.go index fde3ab09c..fa3844f68 100644 --- a/appservice/query/query.go +++ b/appservice/query/query.go @@ -18,15 +18,12 @@ package query import ( "context" - "encoding/json" "net/http" "net/url" "time" "github.com/matrix-org/dendrite/appservice/api" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" - "github.com/matrix-org/util" + "github.com/matrix-org/dendrite/internal/config" opentracing "github.com/opentracing/opentracing-go" log "github.com/sirupsen/logrus" ) @@ -179,36 +176,3 @@ func makeHTTPClient() *http.Client { Timeout: time.Second * 30, } } - -// SetupHTTP adds the AppServiceQueryPAI handlers to the http.ServeMux. This -// handles and muxes incoming api requests the to internal AppServiceQueryAPI. -func (a *AppServiceQueryAPI) SetupHTTP(servMux *http.ServeMux) { - servMux.Handle( - api.AppServiceRoomAliasExistsPath, - common.MakeInternalAPI("appserviceRoomAliasExists", func(req *http.Request) util.JSONResponse { - var request api.RoomAliasExistsRequest - var response api.RoomAliasExistsResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := a.RoomAliasExists(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - servMux.Handle( - api.AppServiceUserIDExistsPath, - common.MakeInternalAPI("appserviceUserIDExists", func(req *http.Request) util.JSONResponse { - var request api.UserIDExistsRequest - var response api.UserIDExistsResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := a.UserIDExists(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) -} diff --git a/appservice/routing/routing.go b/appservice/routing/routing.go deleted file mode 100644 index 42fa80520..000000000 --- a/appservice/routing/routing.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2018 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package routing - -import ( - "net/http" - - "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" - "github.com/matrix-org/dendrite/common/transactions" - "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" -) - -const pathPrefixApp = "/_matrix/app/v1" - -// Setup registers HTTP handlers with the given ServeMux. It also supplies the given http.Client -// to clients which need to make outbound HTTP requests. -// -// Due to Setup being used to call many other functions, a gocyclo nolint is -// applied: -// nolint: gocyclo -func Setup( - apiMux *mux.Router, cfg *config.Dendrite, // nolint: unparam - queryAPI api.RoomserverQueryAPI, aliasAPI api.RoomserverAliasAPI, // nolint: unparam - accountDB accounts.Database, // nolint: unparam - federation *gomatrixserverlib.FederationClient, // nolint: unparam - transactionsCache *transactions.Cache, // nolint: unparam -) { - appMux := apiMux.PathPrefix(pathPrefixApp).Subrouter() - - appMux.Handle("/alias", - common.MakeExternalAPI("alias", func(req *http.Request) util.JSONResponse { - // TODO: Implement - return util.JSONResponse{ - Code: http.StatusOK, - JSON: nil, - } - }), - ).Methods(http.MethodGet, http.MethodOptions) - appMux.Handle("/user", - common.MakeExternalAPI("user", func(req *http.Request) util.JSONResponse { - // TODO: Implement - return util.JSONResponse{ - Code: http.StatusOK, - JSON: nil, - } - }), - ).Methods(http.MethodGet, http.MethodOptions) -} diff --git a/appservice/storage/interface.go b/appservice/storage/interface.go index 25d35af6c..735e2f90a 100644 --- a/appservice/storage/interface.go +++ b/appservice/storage/interface.go @@ -17,10 +17,12 @@ package storage import ( "context" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/gomatrixserverlib" ) type Database interface { + internal.PartitionStorer StoreEvent(ctx context.Context, appServiceID string, event *gomatrixserverlib.HeaderedEvent) error GetEventsWithAppServiceID(ctx context.Context, appServiceID string, limit int) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) CountEventsWithAppServiceID(ctx context.Context, appServiceID string) (int, error) diff --git a/appservice/storage/postgres/storage.go b/appservice/storage/postgres/storage.go index e145eeee2..03f331d64 100644 --- a/appservice/storage/postgres/storage.go +++ b/appservice/storage/postgres/storage.go @@ -27,21 +27,25 @@ import ( // Database stores events intended to be later sent to application services type Database struct { + sqlutil.PartitionOffsetStatements events eventsStatements txnID txnStatements db *sql.DB } // NewDatabase opens a new database -func NewDatabase(dataSourceName string) (*Database, error) { +func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties) (*Database, error) { var result Database var err error - if result.db, err = sqlutil.Open("postgres", dataSourceName); err != nil { + if result.db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { return nil, err } if err = result.prepare(); err != nil { return nil, err } + if err = result.PartitionOffsetStatements.Prepare(result.db, "appservice"); err != nil { + return nil, err + } return &result, nil } diff --git a/appservice/storage/sqlite3/storage.go b/appservice/storage/sqlite3/storage.go index 0cd1e4abc..cb55c8d94 100644 --- a/appservice/storage/sqlite3/storage.go +++ b/appservice/storage/sqlite3/storage.go @@ -20,7 +20,6 @@ import ( "database/sql" // Import SQLite database driver - "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" _ "github.com/mattn/go-sqlite3" @@ -28,6 +27,7 @@ import ( // Database stores events intended to be later sent to application services type Database struct { + sqlutil.PartitionOffsetStatements events eventsStatements txnID txnStatements db *sql.DB @@ -37,12 +37,19 @@ type Database struct { func NewDatabase(dataSourceName string) (*Database, error) { var result Database var err error - if result.db, err = sqlutil.Open(common.SQLiteDriverName(), dataSourceName); err != nil { + cs, err := sqlutil.ParseFileURI(dataSourceName) + if err != nil { + return nil, err + } + if result.db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil); err != nil { return nil, err } if err = result.prepare(); err != nil { return nil, err } + if err = result.PartitionOffsetStatements.Prepare(result.db, "appservice"); err != nil { + return nil, err + } return &result, nil } diff --git a/appservice/storage/storage.go b/appservice/storage/storage.go index 9fbd2a1f3..c848d15d7 100644 --- a/appservice/storage/storage.go +++ b/appservice/storage/storage.go @@ -21,19 +21,22 @@ import ( "github.com/matrix-org/dendrite/appservice/storage/postgres" "github.com/matrix-org/dendrite/appservice/storage/sqlite3" + "github.com/matrix-org/dendrite/internal/sqlutil" ) -func NewDatabase(dataSourceName string) (Database, error) { +// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) +// and sets DB connection parameters +func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { - return postgres.NewDatabase(dataSourceName) + return postgres.NewDatabase(dataSourceName, dbProperties) } switch uri.Scheme { case "postgres": - return postgres.NewDatabase(dataSourceName) + return postgres.NewDatabase(dataSourceName, dbProperties) case "file": return sqlite3.NewDatabase(dataSourceName) default: - return postgres.NewDatabase(dataSourceName) + return postgres.NewDatabase(dataSourceName, dbProperties) } } diff --git a/appservice/storage/storage_wasm.go b/appservice/storage/storage_wasm.go index 2bd1433f9..1d6c4b4a9 100644 --- a/appservice/storage/storage_wasm.go +++ b/appservice/storage/storage_wasm.go @@ -19,9 +19,13 @@ import ( "net/url" "github.com/matrix-org/dendrite/appservice/storage/sqlite3" + "github.com/matrix-org/dendrite/internal/sqlutil" ) -func NewDatabase(dataSourceName string) (Database, error) { +func NewDatabase( + dataSourceName string, + dbProperties sqlutil.DbProperties, // nolint:unparam +) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { return nil, fmt.Errorf("Cannot use postgres implementation") diff --git a/appservice/types/types.go b/appservice/types/types.go index aac731550..b6386df67 100644 --- a/appservice/types/types.go +++ b/appservice/types/types.go @@ -15,7 +15,7 @@ package types import ( "sync" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" ) const ( diff --git a/appservice/workers/transaction_scheduler.go b/appservice/workers/transaction_scheduler.go index 10c7ef911..63ec58aa1 100644 --- a/appservice/workers/transaction_scheduler.go +++ b/appservice/workers/transaction_scheduler.go @@ -21,11 +21,12 @@ import ( "fmt" "math" "net/http" + "net/url" "time" "github.com/matrix-org/dendrite/appservice/storage" "github.com/matrix-org/dendrite/appservice/types" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -207,9 +208,15 @@ func send( txnID int, transaction []byte, ) error { - // POST a transaction to our AS - address := fmt.Sprintf("%s/transactions/%d", appservice.URL, txnID) - resp, err := client.Post(address, "application/json", bytes.NewBuffer(transaction)) + // PUT a transaction to our AS + // https://matrix.org/docs/spec/application_service/r0.1.2#put-matrix-app-v1-transactions-txnid + address := fmt.Sprintf("%s/transactions/%d?access_token=%s", appservice.URL, txnID, url.QueryEscape(appservice.HSToken)) + req, err := http.NewRequest("PUT", address, bytes.NewBuffer(transaction)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + resp, err := client.Do(req) if err != nil { return err } diff --git a/are-we-synapse-yet.list b/are-we-synapse-yet.list index 5a900b3ef..f59f80675 100644 --- a/are-we-synapse-yet.list +++ b/are-we-synapse-yet.list @@ -28,6 +28,10 @@ log POST /login returns the same device_id as that in the request log POST /login can log in as a user with just the local part of the id log POST /login as non-existing user is rejected log POST /login wrong password is rejected +log Interactive authentication types include SSO +log Can perform interactive authentication with SSO +log The user must be consistent through an interactive authentication session with SSO +log The operation must be consistent through an interactive authentication session v1s GET /events initially v1s GET /initialSync initially csa Version responds 200 OK with valid structure @@ -44,6 +48,7 @@ dev DELETE /device/{deviceId} dev DELETE /device/{deviceId} requires UI auth user to match device owner dev DELETE /device/{deviceId} with no body gives a 401 dev The deleted device must be consistent through an interactive auth session +dev Users receive device_list updates for their own devices pre GET /presence/:user_id/status fetches initial status pre PUT /presence/:user_id/status updates my presence crm POST /createRoom makes a public room @@ -92,8 +97,8 @@ rst PUT power_levels should not explode if the old power levels were empty rst Both GET and PUT work rct POST /rooms/:room_id/receipt can create receipts red POST /rooms/:room_id/read_markers can create read marker -med POST /media/v1/upload can create an upload -med GET /media/v1/download can fetch the value again +med POST /media/r0/upload can create an upload +med GET /media/r0/download can fetch the value again cap GET /capabilities is present and well formed for registered user cap GET /r0/capabilities is not public reg Register with a recaptcha @@ -450,6 +455,19 @@ rmv User can invite remote user to room with version 5 rmv Remote user can backfill in a room with version 5 rmv Can reject invites over federation for rooms with version 5 rmv Can receive redactions from regular users over federation in room version 5 +rmv User can create and send/receive messages in a room with version 6 +rmv User can create and send/receive messages in a room with version 6 (2 subtests) +rmv local user can join room with version 6 +rmv User can invite local user to room with version 6 +rmv remote user can join room with version 6 +rmv User can invite remote user to room with version 6 +rmv Remote user can backfill in a room with version 6 +rmv Can reject invites over federation for rooms with version 6 +rmv Can receive redactions from regular users over federation in room version 6 +rmv Inbound federation rejects invites which include invalid JSON for room version 6 +rmv Outbound federation rejects invite response which include invalid JSON for room version 6 +rmv Inbound federation rejects invite rejections which include invalid JSON for room version 6 +rmv Server rejects invalid JSON in a version 6 room pre Presence changes are reported to local room members f,pre Presence changes are also reported to remote room members pre Presence changes to UNAVAILABLE are reported to local room members @@ -531,11 +549,11 @@ std Can recv device messages until they are acknowledged std Device messages with the same txn_id are deduplicated std Device messages wake up /sync std Can recv device messages over federation -std Device messages over federation wake up /sync +fsd Device messages over federation wake up /sync std Can send messages with a wildcard device id std Can send messages with a wildcard device id to two devices std Wildcard device messages wake up /sync -std Wildcard device messages over federation wake up /sync +fsd Wildcard device messages over federation wake up /sync adm /whois nsp /purge_history nsp /purge_history by ts @@ -573,6 +591,7 @@ frv A pair of servers can establish a join in a v2 room fsj Outbound federation rejects send_join responses with no m.room.create event frv Outbound federation rejects m.room.create events with an unknown room version fsj Event with an invalid signature in the send_join response should not cause room join to fail +fsj Inbound: send_join rejects invalid JSON for room version 6 fed Outbound federation can send events fed Inbound federation can receive events fed Inbound federation can receive redacted events @@ -631,6 +650,7 @@ fst Name/topic keys are correct fau Remote servers cannot set power levels in rooms without existing powerlevels fau Remote servers should reject attempts by non-creators to set the power levels fau Inbound federation rejects typing notifications from wrong remote +fau Users cannot set notifications powerlevel higher than their own fed Forward extremities remain so even after the next events are populated as outliers fau Banned servers cannot send events fau Banned servers cannot /make_join @@ -827,4 +847,10 @@ syn Multiple calls to /sync should not cause 500 errors gst Guest user can call /events on another world_readable room (SYN-606) gst Real user can call /events on another world_readable room (SYN-606) gst Events come down the correct room -pub Asking for a remote rooms list, but supplying the local server's name, returns the local rooms list \ No newline at end of file +pub Asking for a remote rooms list, but supplying the local server's name, returns the local rooms list +std Can send a to-device message to two users which both receive it using /sync +fme Outbound federation will ignore a missing event with bad JSON for room version 6 +fbk Outbound federation rejects backfill containing invalid JSON for events in room version 6 +jso Invalid JSON integers +jso Invalid JSON floats +jso Invalid JSON special values \ No newline at end of file diff --git a/are-we-synapse-yet.py b/are-we-synapse-yet.py index ffed8d384..30979a129 100755 --- a/are-we-synapse-yet.py +++ b/are-we-synapse-yet.py @@ -33,6 +33,7 @@ import sys test_mappings = { "nsp": "Non-Spec API", + "unk": "Unknown API (no group specified)", "f": "Federation", # flag to mark test involves federation "federation_apis": { @@ -50,6 +51,7 @@ test_mappings = { "fpb": "Public Room API", "fdk": "Device Key APIs", "fed": "Federation API", + "fsd": "Send-to-Device APIs", }, "client_apis": { @@ -99,6 +101,7 @@ test_mappings = { "ign": "Ignore Users", "udr": "User Directory APIs", "app": "Application Services API", + "jso": "Enforced canonical JSON", }, } @@ -212,7 +215,8 @@ def main(results_tap_path, verbose): # } }, "nonspec": { - "nsp": {} + "nsp": {}, + "unk": {} }, } with open(results_tap_path, "r") as f: @@ -223,7 +227,7 @@ def main(results_tap_path, verbose): name = test_result["name"] group_id = test_name_to_group_id.get(name) if not group_id: - raise Exception("The test '%s' doesn't have a group" % (name,)) + summary["nonspec"]["unk"][name] = test_result["ok"] if group_id == "nsp": summary["nonspec"]["nsp"][name] = test_result["ok"] elif group_id in test_mappings["federation_apis"]: diff --git a/build-dendritejs.sh b/build-dendritejs.sh new file mode 100755 index 000000000..cd42a6bee --- /dev/null +++ b/build-dendritejs.sh @@ -0,0 +1,4 @@ +#!/bin/bash -eu + +export GIT_COMMIT=$(git rev-list -1 HEAD) && \ +GOOS=js GOARCH=wasm go build -ldflags "-X main.GitCommit=$GIT_COMMIT" -o main.wasm ./cmd/dendritejs \ No newline at end of file diff --git a/build.sh b/build.sh index 3ef148891..087f4ae72 100755 --- a/build.sh +++ b/build.sh @@ -3,4 +3,6 @@ # Put installed packages into ./bin export GOBIN=$PWD/`dirname $0`/bin -go install -v $PWD/`dirname $0`/cmd/... \ No newline at end of file +go install -v $PWD/`dirname $0`/cmd/... + +GOOS=js GOARCH=wasm go build -o main.wasm ./cmd/dendritejs diff --git a/build/docker/DendriteJS.Dockerfile b/build/docker/DendriteJS.Dockerfile new file mode 100644 index 000000000..4467c9c70 --- /dev/null +++ b/build/docker/DendriteJS.Dockerfile @@ -0,0 +1,111 @@ +# This dockerfile will build dendritejs and hook it up to riot-web, build that then dump the +# resulting HTML/JS onto an nginx container for hosting. It requires no specific build context +# as it pulls archives straight from github branches. +# +# $ docker build -t dendritejs -f DendriteJS.Dockerfile . +# $ docker run --rm -p 8888:80 dendritejs +# Then visit http://localhost:8888 +FROM golang:1.13.7-alpine3.11 AS gobuild + +# Download and build dendrite +WORKDIR /build +ADD https://github.com/matrix-org/dendrite/archive/master.tar.gz /build/master.tar.gz +RUN tar xvfz master.tar.gz +WORKDIR /build/dendrite-master +RUN GOOS=js GOARCH=wasm go build -o main.wasm ./cmd/dendritejs + + +FROM node:14-stretch AS jsbuild +# apparently some deps require python +RUN apt-get update && apt-get -y install python + +# Download riot-web and libp2p repos +WORKDIR /build +ADD https://github.com/matrix-org/go-http-js-libp2p/archive/master.tar.gz /build/libp2p.tar.gz +RUN tar xvfz libp2p.tar.gz +ADD https://github.com/vector-im/riot-web/archive/matthew/p2p.tar.gz /build/p2p.tar.gz +RUN tar xvfz p2p.tar.gz + +# Install deps for riot-web, symlink in libp2p repo and build that too +WORKDIR /build/riot-web-matthew-p2p +RUN yarn install +RUN ln -s /build/go-http-js-libp2p-master /build/riot-web-matthew-p2p/node_modules/go-http-js-libp2p +RUN (cd node_modules/go-http-js-libp2p && yarn install) +COPY --from=gobuild /build/dendrite-master/main.wasm ./src/vector/dendrite.wasm +# build it all +RUN yarn build:p2p + +SHELL ["/bin/bash", "-c"] +RUN echo $'\ +{ \n\ + "default_server_config": { \n\ + "m.homeserver": { \n\ + "base_url": "https://p2p.riot.im", \n\ + "server_name": "p2p.riot.im" \n\ + }, \n\ + "m.identity_server": { \n\ + "base_url": "https://vector.im" \n\ + } \n\ + }, \n\ + "disable_custom_urls": false, \n\ + "disable_guests": true, \n\ + "disable_login_language_selector": false, \n\ + "disable_3pid_login": true, \n\ + "brand": "Riot", \n\ + "integrations_ui_url": "https://scalar.vector.im/", \n\ + "integrations_rest_url": "https://scalar.vector.im/api", \n\ + "integrations_widgets_urls": [ \n\ + "https://scalar.vector.im/_matrix/integrations/v1", \n\ + "https://scalar.vector.im/api", \n\ + "https://scalar-staging.vector.im/_matrix/integrations/v1", \n\ + "https://scalar-staging.vector.im/api", \n\ + "https://scalar-staging.riot.im/scalar/api" \n\ + ], \n\ + "integrations_jitsi_widget_url": "https://scalar.vector.im/api/widgets/jitsi.html", \n\ + "bug_report_endpoint_url": "https://riot.im/bugreports/submit", \n\ + "defaultCountryCode": "GB", \n\ + "showLabsSettings": false, \n\ + "features": { \n\ + "feature_pinning": "labs", \n\ + "feature_custom_status": "labs", \n\ + "feature_custom_tags": "labs", \n\ + "feature_state_counters": "labs" \n\ + }, \n\ + "default_federate": true, \n\ + "default_theme": "light", \n\ + "roomDirectory": { \n\ + "servers": [ \n\ + "matrix.org" \n\ + ] \n\ + }, \n\ + "welcomeUserId": "", \n\ + "piwik": { \n\ + "url": "https://piwik.riot.im/", \n\ + "whitelistedHSUrls": ["https://matrix.org"], \n\ + "whitelistedISUrls": ["https://vector.im", "https://matrix.org"], \n\ + "siteId": 1 \n\ + }, \n\ + "enable_presence_by_hs_url": { \n\ + "https://matrix.org": false, \n\ + "https://matrix-client.matrix.org": false \n\ + }, \n\ + "settingDefaults": { \n\ + "breadcrumbs": true \n\ + } \n\ +}' > webapp/config.json + +FROM nginx +# Add "Service-Worker-Allowed: /" header so the worker can sniff traffic on this domain rather +# than just the path this gets hosted under. NB this newline echo syntax only works on bash. +SHELL ["/bin/bash", "-c"] +RUN echo $'\ +server { \n\ + listen 80; \n\ + add_header \'Service-Worker-Allowed\' \'/\'; \n\ + location / { \n\ + root /usr/share/nginx/html; \n\ + index index.html index.htm; \n\ + } \n\ +}' > /etc/nginx/conf.d/default.conf +RUN sed -i 's/}/ application\/wasm wasm;\n}/g' /etc/nginx/mime.types +COPY --from=jsbuild /build/riot-web-matthew-p2p/webapp /usr/share/nginx/html diff --git a/build/docker/Dockerfile b/build/docker/Dockerfile new file mode 100644 index 000000000..d8e07681f --- /dev/null +++ b/build/docker/Dockerfile @@ -0,0 +1,10 @@ +FROM docker.io/golang:1.13.7-alpine3.11 AS builder + +RUN apk --update --no-cache add bash build-base + +WORKDIR /build + +COPY . /build + +RUN mkdir -p bin +RUN sh ./build.sh \ No newline at end of file diff --git a/build/docker/Dockerfile.component b/build/docker/Dockerfile.component new file mode 100644 index 000000000..13634391a --- /dev/null +++ b/build/docker/Dockerfile.component @@ -0,0 +1,13 @@ +FROM matrixdotorg/dendrite:latest AS base + +FROM alpine:latest + +ARG component=monolith +ENV entrypoint=${component} + +COPY --from=base /build/bin/${component} /usr/bin + +VOLUME /etc/dendrite +WORKDIR /etc/dendrite + +ENTRYPOINT /usr/bin/${entrypoint} $@ \ No newline at end of file diff --git a/build/docker/README.md b/build/docker/README.md new file mode 100644 index 000000000..45d96d1cb --- /dev/null +++ b/build/docker/README.md @@ -0,0 +1,70 @@ +# Docker images + +These are Docker images for Dendrite! + +## Dockerfiles + +The `Dockerfile` builds the base image which contains all of the Dendrite +components. The `Dockerfile.component` file takes the given component, as +specified with `--buildarg component=` from the base image and produce +smaller component-specific images, which are substantially smaller and do +not contain the Go toolchain etc. + +## Compose files + +There are three sample `docker-compose` files: + +- `docker-compose.deps.yml` which runs the Postgres and Kafka prerequisites +- `docker-compose.monolith.yml` which runs a monolith Dendrite deployment +- `docker-compose.polylith.yml` which runs a polylith Dendrite deployment + +## Configuration + +The `docker-compose` files refer to the `/etc/dendrite` volume as where the +runtime config should come from. The mounted folder must contain: + +- `dendrite.yaml` configuration file (based on the sample `dendrite-config.yaml` + in the `docker/config` folder in the [Dendrite repository](https://github.com/matrix-org/dendrite) +- `matrix_key.pem` server key, as generated using `cmd/generate-keys` +- `server.crt` certificate file +- `server.key` private key file for the above certificate + +To generate keys: + +``` +go run github.com/matrix-org/dendrite/cmd/generate-keys \ + --private-key=matrix_key.pem \ + --tls-cert=server.crt \ + --tls-key=server.key +``` + +## Starting Dendrite + +Once in place, start the dependencies: + +``` +docker-compose -f docker-compose.deps.yml up +``` + +Wait a few seconds for Kafka and Postgres to finish starting up, and then start a monolith: + +``` +docker-compose -f docker-compose.monolith.yml up +``` + +... or start the polylith components: + +``` +docker-compose -f docker-compose.polylith.yml up +``` + +## Building the images + +The `docker/images-build.sh` script will build the base image, followed by +all of the component images. + +The `docker/images-push.sh` script will push them to Docker Hub (subject +to permissions). + +If you wish to build and push your own images, rename `matrixdotorg/dendrite` to +the name of another Docker Hub repository in `images-build.sh` and `images-push.sh`. diff --git a/docker/dendrite-docker.yml b/build/docker/config/dendrite-config.yaml similarity index 94% rename from docker/dendrite-docker.yml rename to build/docker/config/dendrite-config.yaml index a72ff3ddc..53d9f7b02 100644 --- a/docker/dendrite-docker.yml +++ b/build/docker/config/dendrite-config.yaml @@ -80,7 +80,7 @@ kafka: # Kafka can be used both with a monolithic server and when running the # components as separate servers. # If enabled database.naffka must also be specified. - use_naffka: true + use_naffka: false # The names of the kafka topics to use. topics: output_room_event: roomserverOutput @@ -101,7 +101,7 @@ database: public_rooms_api: "postgres://dendrite:itsasecret@postgres/dendrite_publicroomsapi?sslmode=disable" appservice: "postgres://dendrite:itsasecret@postgres/dendrite_appservice?sslmode=disable" # If using naffka you need to specify a naffka database - naffka: "postgres://dendrite:itsasecret@postgres/dendrite_naffka?sslmode=disable" + #naffka: "postgres://dendrite:itsasecret@postgres/dendrite_naffka?sslmode=disable" # The TCP host:port pairs to bind the internal HTTP APIs to. # These shouldn't be exposed to the public internet. @@ -110,11 +110,15 @@ listen: room_server: "room_server:7770" client_api: "client_api:7771" federation_api: "federation_api:7772" + server_key_api: "server_key_api:7778" sync_api: "sync_api:7773" media_api: "media_api:7774" public_rooms_api: "public_rooms_api:7775" federation_sender: "federation_sender:7776" - edu_server: "typing_server:7777" + edu_server: "edu_server:7777" + key_server: "key_server:7779" + user_api: "user_api:7780" + appservice_api: "appservice_api:7781" # The configuration for tracing the dendrite components. tracing: diff --git a/build/docker/docker-compose.deps.yml b/build/docker/docker-compose.deps.yml new file mode 100644 index 000000000..facfc01b3 --- /dev/null +++ b/build/docker/docker-compose.deps.yml @@ -0,0 +1,36 @@ +version: "3.4" +services: + postgres: + hostname: postgres + image: postgres:9.5 + restart: always + volumes: + - ./postgres/create_db.sh:/docker-entrypoint-initdb.d/20-create_db.sh + environment: + POSTGRES_PASSWORD: itsasecret + POSTGRES_USER: dendrite + networks: + - internal + + zookeeper: + hostname: zookeeper + image: zookeeper + networks: + - internal + + kafka: + container_name: dendrite_kafka + hostname: kafka + image: wurstmeister/kafka + environment: + KAFKA_ADVERTISED_HOST_NAME: "kafka" + KAFKA_DELETE_TOPIC_ENABLE: "true" + KAFKA_ZOOKEEPER_CONNECT: "zookeeper:2181" + depends_on: + - zookeeper + networks: + - internal + +networks: + internal: + attachable: true diff --git a/build/docker/docker-compose.monolith.yml b/build/docker/docker-compose.monolith.yml new file mode 100644 index 000000000..336a43984 --- /dev/null +++ b/build/docker/docker-compose.monolith.yml @@ -0,0 +1,18 @@ +version: "3.4" +services: + monolith: + hostname: monolith + image: matrixdotorg/dendrite:monolith + command: [ + "--config=dendrite.yaml", + "--tls-cert=server.crt", + "--tls-key=server.key" + ] + volumes: + - ./config:/etc/dendrite + networks: + - internal + +networks: + internal: + attachable: true diff --git a/build/docker/docker-compose.polylith.yml b/build/docker/docker-compose.polylith.yml new file mode 100644 index 000000000..d424d43b1 --- /dev/null +++ b/build/docker/docker-compose.polylith.yml @@ -0,0 +1,182 @@ +version: "3.4" +services: + client_api_proxy: + hostname: client_api_proxy + image: matrixdotorg/dendrite:clientproxy + command: [ + "--bind-address=:8008", + "--client-api-server-url=http://client_api:7771", + "--sync-api-server-url=http://sync_api:7773", + "--media-api-server-url=http://media_api:7774", + "--public-rooms-api-server-url=http://public_rooms_api:7775" + ] + volumes: + - ./config:/etc/dendrite + networks: + - internal + depends_on: + - sync_api + - client_api + - media_api + - public_rooms_api + ports: + - "8008:8008" + + client_api: + hostname: client_api + image: matrixdotorg/dendrite:clientapi + command: [ + "--config=dendrite.yaml" + ] + volumes: + - ./config:/etc/dendrite + - room_server + networks: + - internal + + media_api: + hostname: media_api + image: matrixdotorg/dendrite:mediaapi + command: [ + "--config=dendrite.yaml" + ] + volumes: + - ./config:/etc/dendrite + networks: + - internal + + public_rooms_api: + hostname: public_rooms_api + image: matrixdotorg/dendrite:publicroomsapi + command: [ + "--config=dendrite.yaml" + ] + volumes: + - ./config:/etc/dendrite + networks: + - internal + + sync_api: + hostname: sync_api + image: matrixdotorg/dendrite:syncapi + command: [ + "--config=dendrite.yaml" + ] + volumes: + - ./config:/etc/dendrite + networks: + - internal + + room_server: + hostname: room_server + image: matrixdotorg/dendrite:roomserver + command: [ + "--config=dendrite.yaml" + ] + volumes: + - ./config:/etc/dendrite + networks: + - internal + + edu_server: + hostname: edu_server + image: matrixdotorg/dendrite:eduserver + command: [ + "--config=dendrite.yaml" + ] + volumes: + - ./config:/etc/dendrite + networks: + - internal + + federation_api_proxy: + hostname: federation_api_proxy + image: matrixdotorg/dendrite:federationproxy + command: [ + "--bind-address=:8448", + "--federation-api-url=http://federation_api:7772", + "--media-api-server-url=http://media_api:7774" + ] + volumes: + - ./config:/etc/dendrite + depends_on: + - federation_api + - federation_sender + - media_api + networks: + - internal + ports: + - "8448:8448" + + federation_api: + hostname: federation_api + image: matrixdotorg/dendrite:federationapi + command: [ + "--config=dendrite.yaml" + ] + volumes: + - ./config:/etc/dendrite + networks: + - internal + + federation_sender: + hostname: federation_sender + image: matrixdotorg/dendrite:federationsender + command: [ + "--config=dendrite.yaml" + ] + volumes: + - ./config:/etc/dendrite + networks: + - internal + + key_server: + hostname: key_server + image: matrixdotorg/dendrite:keyserver + command: [ + "--config=dendrite.yaml" + ] + volumes: + - ./config:/etc/dendrite + networks: + - internal + + server_key_api: + hostname: server_key_api + image: matrixdotorg/dendrite:serverkeyapi + command: [ + "--config=dendrite.yaml" + ] + volumes: + - ./config:/etc/dendrite + networks: + - internal + + user_api: + hostname: user_api + image: matrixdotorg/dendrite:userapi + command: [ + "--config=dendrite.yaml" + ] + volumes: + - ./config:/etc/dendrite + networks: + - internal + + appservice_api: + hostname: appservice_api + image: matrixdotorg/dendrite:appservice + command: [ + "--config=dendrite.yaml" + ] + volumes: + - ./config:/etc/dendrite + networks: + - internal + depends_on: + - room_server + - user_api + +networks: + internal: + attachable: true diff --git a/build/docker/images-build.sh b/build/docker/images-build.sh new file mode 100755 index 000000000..9ee5a09de --- /dev/null +++ b/build/docker/images-build.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +cd $(git rev-parse --show-toplevel) + +docker build -f build/docker/Dockerfile -t matrixdotorg/dendrite:latest . + +docker build -t matrixdotorg/dendrite:monolith --build-arg component=dendrite-monolith-server -f build/docker/Dockerfile.component . + +docker build -t matrixdotorg/dendrite:appservice --build-arg component=dendrite-appservice-server -f build/docker/Dockerfile.component . +docker build -t matrixdotorg/dendrite:clientapi --build-arg component=dendrite-client-api-server -f build/docker/Dockerfile.component . +docker build -t matrixdotorg/dendrite:clientproxy --build-arg component=client-api-proxy -f build/docker/Dockerfile.component . +docker build -t matrixdotorg/dendrite:eduserver --build-arg component=dendrite-edu-server -f build/docker/Dockerfile.component . +docker build -t matrixdotorg/dendrite:federationapi --build-arg component=dendrite-federation-api-server -f build/docker/Dockerfile.component . +docker build -t matrixdotorg/dendrite:federationsender --build-arg component=dendrite-federation-sender-server -f build/docker/Dockerfile.component . +docker build -t matrixdotorg/dendrite:federationproxy --build-arg component=federation-api-proxy -f build/docker/Dockerfile.component . +docker build -t matrixdotorg/dendrite:keyserver --build-arg component=dendrite-key-server -f build/docker/Dockerfile.component . +docker build -t matrixdotorg/dendrite:mediaapi --build-arg component=dendrite-media-api-server -f build/docker/Dockerfile.component . +docker build -t matrixdotorg/dendrite:publicroomsapi --build-arg component=dendrite-public-rooms-api-server -f build/docker/Dockerfile.component . +docker build -t matrixdotorg/dendrite:roomserver --build-arg component=dendrite-room-server -f build/docker/Dockerfile.component . +docker build -t matrixdotorg/dendrite:syncapi --build-arg component=dendrite-sync-api-server -f build/docker/Dockerfile.component . +docker build -t matrixdotorg/dendrite:serverkeyapi --build-arg component=dendrite-server-key-api-server -f build/docker/Dockerfile.component . +docker build -t matrixdotorg/dendrite:userapi --build-arg component=dendrite-user-api-server -f build/docker/Dockerfile.component . diff --git a/build/docker/images-pull.sh b/build/docker/images-pull.sh new file mode 100755 index 000000000..da08a7325 --- /dev/null +++ b/build/docker/images-pull.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +docker pull matrixdotorg/dendrite:monolith + +docker pull matrixdotorg/dendrite:appservice +docker pull matrixdotorg/dendrite:clientapi +docker pull matrixdotorg/dendrite:clientproxy +docker pull matrixdotorg/dendrite:eduserver +docker pull matrixdotorg/dendrite:federationapi +docker pull matrixdotorg/dendrite:federationsender +docker pull matrixdotorg/dendrite:federationproxy +docker pull matrixdotorg/dendrite:keyserver +docker pull matrixdotorg/dendrite:mediaapi +docker pull matrixdotorg/dendrite:publicroomsapi +docker pull matrixdotorg/dendrite:roomserver +docker pull matrixdotorg/dendrite:syncapi +docker pull matrixdotorg/dendrite:userapi diff --git a/build/docker/images-push.sh b/build/docker/images-push.sh new file mode 100755 index 000000000..1ac60b921 --- /dev/null +++ b/build/docker/images-push.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +docker push matrixdotorg/dendrite:monolith + +docker push matrixdotorg/dendrite:appservice +docker push matrixdotorg/dendrite:clientapi +docker push matrixdotorg/dendrite:clientproxy +docker push matrixdotorg/dendrite:eduserver +docker push matrixdotorg/dendrite:federationapi +docker push matrixdotorg/dendrite:federationsender +docker push matrixdotorg/dendrite:federationproxy +docker push matrixdotorg/dendrite:keyserver +docker push matrixdotorg/dendrite:mediaapi +docker push matrixdotorg/dendrite:publicroomsapi +docker push matrixdotorg/dendrite:roomserver +docker push matrixdotorg/dendrite:syncapi +docker push matrixdotorg/dendrite:serverkeyapi +docker push matrixdotorg/dendrite:userapi diff --git a/docker/postgres/create_db.sh b/build/docker/postgres/create_db.sh similarity index 100% rename from docker/postgres/create_db.sh rename to build/docker/postgres/create_db.sh diff --git a/build/gobind/build.sh b/build/gobind/build.sh new file mode 100644 index 000000000..3a80d374a --- /dev/null +++ b/build/gobind/build.sh @@ -0,0 +1,6 @@ +#!/bin/sh + +gomobile bind -v \ + -ldflags "-X $github.com/yggdrasil-network/yggdrasil-go/src/version.buildName=riot-ios-p2p" \ + -target ios \ + github.com/matrix-org/dendrite/build/gobind \ No newline at end of file diff --git a/build/gobind/monolith.go b/build/gobind/monolith.go new file mode 100644 index 000000000..750babad8 --- /dev/null +++ b/build/gobind/monolith.go @@ -0,0 +1,161 @@ +package gobind + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "time" + + "github.com/matrix-org/dendrite/appservice" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/yggconn" + "github.com/matrix-org/dendrite/eduserver" + "github.com/matrix-org/dendrite/eduserver/cache" + "github.com/matrix-org/dendrite/federationsender" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/setup" + "github.com/matrix-org/dendrite/publicroomsapi/storage" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/userapi" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +type DendriteMonolith struct { + StorageDirectory string + listener net.Listener +} + +func (m *DendriteMonolith) BaseURL() string { + return fmt.Sprintf("http://%s", m.listener.Addr().String()) +} + +func (m *DendriteMonolith) Start() { + logger := logrus.Logger{ + Out: BindLogger{}, + } + logrus.SetOutput(BindLogger{}) + + var err error + m.listener, err = net.Listen("tcp", "localhost:65432") + if err != nil { + panic(err) + } + + ygg, err := yggconn.Setup("dendrite", "", m.StorageDirectory) + if err != nil { + panic(err) + } + + cfg := &config.Dendrite{} + cfg.SetDefaults() + cfg.Matrix.ServerName = gomatrixserverlib.ServerName(ygg.DerivedServerName()) + cfg.Matrix.PrivateKey = ygg.SigningPrivateKey() + cfg.Matrix.KeyID = gomatrixserverlib.KeyID(signing.KeyID) + cfg.Kafka.UseNaffka = true + cfg.Kafka.Topics.OutputRoomEvent = "roomserverOutput" + cfg.Kafka.Topics.OutputClientData = "clientapiOutput" + cfg.Kafka.Topics.OutputTypingEvent = "typingServerOutput" + cfg.Kafka.Topics.OutputSendToDeviceEvent = "sendToDeviceOutput" + cfg.Database.Account = config.DataSource(fmt.Sprintf("file:%s/dendrite-account.db", m.StorageDirectory)) + cfg.Database.Device = config.DataSource(fmt.Sprintf("file:%s/dendrite-device.db", m.StorageDirectory)) + cfg.Database.MediaAPI = config.DataSource(fmt.Sprintf("file:%s/dendrite-mediaapi.db", m.StorageDirectory)) + cfg.Database.SyncAPI = config.DataSource(fmt.Sprintf("file:%s/dendrite-syncapi.db", m.StorageDirectory)) + cfg.Database.RoomServer = config.DataSource(fmt.Sprintf("file:%s/dendrite-roomserver.db", m.StorageDirectory)) + cfg.Database.ServerKey = config.DataSource(fmt.Sprintf("file:%s/dendrite-serverkey.db", m.StorageDirectory)) + cfg.Database.FederationSender = config.DataSource(fmt.Sprintf("file:%s/dendrite-federationsender.db", m.StorageDirectory)) + cfg.Database.AppService = config.DataSource(fmt.Sprintf("file:%s/dendrite-appservice.db", m.StorageDirectory)) + cfg.Database.PublicRoomsAPI = config.DataSource(fmt.Sprintf("file:%s/dendrite-publicroomsa.db", m.StorageDirectory)) + cfg.Database.Naffka = config.DataSource(fmt.Sprintf("file:%s/dendrite-naffka.db", m.StorageDirectory)) + if err = cfg.Derive(); err != nil { + panic(err) + } + + base := setup.NewBaseDendrite(cfg, "Monolith", false) + defer base.Close() // nolint: errcheck + + accountDB := base.CreateAccountsDB() + deviceDB := base.CreateDeviceDB() + federation := ygg.CreateFederationClient(base) + + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Matrix.ServerName, cfg.Derived.ApplicationServices) + + rsAPI := roomserver.NewInternalAPI( + base, keyRing, federation, + ) + + eduInputAPI := eduserver.NewInternalAPI( + base, cache.New(), userAPI, + ) + + asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) + + fsAPI := federationsender.NewInternalAPI( + base, federation, rsAPI, keyRing, + ) + + // The underlying roomserver implementation needs to be able to call the fedsender. + // This is different to rsAPI which can be the http client which doesn't need this dependency + rsAPI.SetFederationSenderAPI(fsAPI) + + publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI), base.Cfg.DbProperties(), cfg.Matrix.ServerName) + if err != nil { + logrus.WithError(err).Panicf("failed to connect to public rooms db") + } + + monolith := setup.Monolith{ + Config: base.Cfg, + AccountDB: accountDB, + DeviceDB: deviceDB, + Client: ygg.CreateClient(base), + FedClient: federation, + KeyRing: keyRing, + KafkaConsumer: base.KafkaConsumer, + KafkaProducer: base.KafkaProducer, + + AppserviceAPI: asAPI, + EDUInternalAPI: eduInputAPI, + FederationSenderAPI: fsAPI, + RoomserverAPI: rsAPI, + UserAPI: userAPI, + //ServerKeyAPI: serverKeyAPI, + + PublicRoomsDB: publicRoomsDB, + } + monolith.AddAllPublicRoutes(base.PublicAPIMux) + + httputil.SetupHTTPAPI( + base.BaseMux, + base.PublicAPIMux, + base.InternalAPIMux, + cfg, + base.UseHTTPAPIs, + ) + + // Build both ends of a HTTP multiplex. + httpServer := &http.Server{ + Addr: ":0", + TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){}, + ReadTimeout: 15 * time.Second, + WriteTimeout: 45 * time.Second, + IdleTimeout: 60 * time.Second, + BaseContext: func(_ net.Listener) context.Context { + return context.Background() + }, + Handler: base.BaseMux, + } + + go func() { + logger.Info("Listening on ", ygg.DerivedServerName()) + logger.Fatal(httpServer.Serve(ygg)) + }() + go func() { + logger.Info("Listening on ", m.BaseURL()) + logger.Fatal(httpServer.Serve(m.listener)) + }() +} diff --git a/build/gobind/platform_ios.go b/build/gobind/platform_ios.go new file mode 100644 index 000000000..01f8a6a04 --- /dev/null +++ b/build/gobind/platform_ios.go @@ -0,0 +1,25 @@ +// +build ios + +package gobind + +/* +#cgo CFLAGS: -x objective-c +#cgo LDFLAGS: -framework Foundation +#import +void Log(const char *text) { + NSString *nss = [NSString stringWithUTF8String:text]; + NSLog(@"%@", nss); +} +*/ +import "C" +import "unsafe" + +type BindLogger struct { +} + +func (nsl BindLogger) Write(p []byte) (n int, err error) { + p = append(p, 0) + cstr := (*C.char)(unsafe.Pointer(&p[0])) + C.Log(cstr) + return len(p), nil +} diff --git a/build/gobind/platform_other.go b/build/gobind/platform_other.go new file mode 100644 index 000000000..fdfb13bc0 --- /dev/null +++ b/build/gobind/platform_other.go @@ -0,0 +1,12 @@ +// +build !ios + +package gobind + +import "log" + +type BindLogger struct{} + +func (nsl BindLogger) Write(p []byte) (n int, err error) { + log.Println(string(p)) + return len(p), nil +} diff --git a/scripts/README.md b/build/scripts/README.md similarity index 100% rename from scripts/README.md rename to build/scripts/README.md diff --git a/scripts/build-test-lint.sh b/build/scripts/build-test-lint.sh similarity index 87% rename from scripts/build-test-lint.sh rename to build/scripts/build-test-lint.sh index d2b2b4b16..8f0b775b1 100755 --- a/scripts/build-test-lint.sh +++ b/build/scripts/build-test-lint.sh @@ -10,7 +10,7 @@ set -eu echo "Checking that it builds..." go build ./cmd/... -./scripts/find-lint.sh +./build/scripts/find-lint.sh echo "Testing..." -go test ./... +go test -v ./... diff --git a/scripts/find-lint.sh b/build/scripts/find-lint.sh similarity index 98% rename from scripts/find-lint.sh rename to build/scripts/find-lint.sh index c9663e4e8..7e37e1548 100755 --- a/scripts/find-lint.sh +++ b/build/scripts/find-lint.sh @@ -14,7 +14,7 @@ set -eux -cd `dirname $0`/.. +cd `dirname $0`/../.. args="" if [ ${1:-""} = "fast" ] diff --git a/scripts/install-local-kafka.sh b/build/scripts/install-local-kafka.sh similarity index 100% rename from scripts/install-local-kafka.sh rename to build/scripts/install-local-kafka.sh diff --git a/clientapi/auth/auth.go b/clientapi/auth/auth.go index 87a2f6677..b4c39ae38 100644 --- a/clientapi/auth/auth.go +++ b/clientapi/auth/auth.go @@ -18,17 +18,13 @@ package auth import ( "context" "crypto/rand" - "database/sql" "encoding/base64" "fmt" "net/http" "strings" - "github.com/matrix-org/dendrite/appservice/types" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/clientapi/userutil" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" ) @@ -39,21 +35,13 @@ var tokenByteLength = 32 // DeviceDatabase represents a device database. type DeviceDatabase interface { // Look up the device matching the given access token. - GetDeviceByAccessToken(ctx context.Context, token string) (*authtypes.Device, error) + GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) } // AccountDatabase represents an account database. type AccountDatabase interface { // Look up the account matching the given localpart. - GetAccountByLocalpart(ctx context.Context, localpart string) (*authtypes.Account, error) -} - -// Data contains information required to authenticate a request. -type Data struct { - AccountDB AccountDatabase - DeviceDB DeviceDatabase - // AppServices is the list of all registered AS - AppServices []config.ApplicationService + GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) } // VerifyUserFromRequest authenticates the HTTP request, @@ -62,8 +50,8 @@ type Data struct { // Note: For an AS user, AS dummy device is returned. // On failure returns an JSON error response which can be sent to the client. func VerifyUserFromRequest( - req *http.Request, data Data, -) (*authtypes.Device, *util.JSONResponse) { + req *http.Request, userAPI api.UserInternalAPI, +) (*api.Device, *util.JSONResponse) { // Try to find the Application Service user token, err := ExtractAccessToken(req) if err != nil { @@ -72,105 +60,31 @@ func VerifyUserFromRequest( JSON: jsonerror.MissingToken(err.Error()), } } - - // Search for app service with given access_token - var appService *config.ApplicationService - for _, as := range data.AppServices { - if as.ASToken == token { - appService = &as - break - } + var res api.QueryAccessTokenResponse + err = userAPI.QueryAccessToken(req.Context(), &api.QueryAccessTokenRequest{ + AccessToken: token, + AppServiceUserID: req.URL.Query().Get("user_id"), + }, &res) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccessToken failed") + jsonErr := jsonerror.InternalServerError() + return nil, &jsonErr } - - if appService != nil { - // Create a dummy device for AS user - dev := authtypes.Device{ - // Use AS dummy device ID - ID: types.AppServiceDeviceID, - // AS dummy device has AS's token. - AccessToken: token, - } - - userID := req.URL.Query().Get("user_id") - localpart, err := userutil.ParseUsernameParam(userID, nil) - if err != nil { - return nil, &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername(err.Error()), - } - } - - if localpart != "" { // AS is masquerading as another user - // Verify that the user is registered - account, err := data.AccountDB.GetAccountByLocalpart(req.Context(), localpart) - // Verify that account exists & appServiceID matches - if err == nil && account.AppServiceID == appService.ID { - // Set the userID of dummy device - dev.UserID = userID - return &dev, nil - } - + if res.Err != nil { + if forbidden, ok := res.Err.(*api.ErrorForbidden); ok { return nil, &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Application service has not registered this user"), + JSON: jsonerror.Forbidden(forbidden.Message), } } - - // AS is not masquerading as any user, so use AS's sender_localpart - dev.UserID = appService.SenderLocalpart - return &dev, nil } - - // Try to find local user from device database - dev, devErr := verifyAccessToken(req, data.DeviceDB) - if devErr == nil { - return dev, verifyUserParameters(req) - } - - return nil, &util.JSONResponse{ - Code: http.StatusUnauthorized, - JSON: jsonerror.UnknownToken("Unrecognized access token"), // nolint: misspell - } -} - -// verifyUserParameters ensures that a request coming from a regular user is not -// using any query parameters reserved for an application service -func verifyUserParameters(req *http.Request) *util.JSONResponse { - if req.URL.Query().Get("ts") != "" { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("parameter 'ts' not allowed without valid parameter 'access_token'"), - } - } - return nil -} - -// verifyAccessToken verifies that an access token was supplied in the given HTTP request -// and returns the device it corresponds to. Returns resErr (an error response which can be -// sent to the client) if the token is invalid or there was a problem querying the database. -func verifyAccessToken(req *http.Request, deviceDB DeviceDatabase) (device *authtypes.Device, resErr *util.JSONResponse) { - token, err := ExtractAccessToken(req) - if err != nil { - resErr = &util.JSONResponse{ + if res.Device == nil { + return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.MissingToken(err.Error()), - } - return - } - device, err = deviceDB.GetDeviceByAccessToken(req.Context(), token) - if err != nil { - if err == sql.ErrNoRows { - resErr = &util.JSONResponse{ - Code: http.StatusUnauthorized, - JSON: jsonerror.UnknownToken("Unknown token"), - } - } else { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDeviceByAccessToken failed") - jsonErr := jsonerror.InternalServerError() - resErr = &jsonErr + JSON: jsonerror.UnknownToken("Unknown token"), } } - return + return res.Device, nil } // GenerateAccessToken creates a new access token. Returns an error if failed to generate diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index 1339f7c8c..174eb1bf1 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -15,60 +15,55 @@ package clientapi import ( + "github.com/Shopify/sarama" + "github.com/gorilla/mux" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/consumers" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/routing" - "github.com/matrix-org/dendrite/common/basecomponent" - "github.com/matrix-org/dendrite/common/transactions" eduServerAPI "github.com/matrix-org/dendrite/eduserver/api" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/transactions" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) -// SetupClientAPIComponent sets up and registers HTTP handlers for the ClientAPI -// component. -func SetupClientAPIComponent( - base *basecomponent.BaseDendrite, +// AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component. +func AddPublicRoutes( + router *mux.Router, + cfg *config.Dendrite, + consumer sarama.Consumer, + producer sarama.SyncProducer, deviceDB devices.Database, accountsDB accounts.Database, federation *gomatrixserverlib.FederationClient, - keyRing *gomatrixserverlib.KeyRing, - aliasAPI roomserverAPI.RoomserverAliasAPI, - inputAPI roomserverAPI.RoomserverInputAPI, - queryAPI roomserverAPI.RoomserverQueryAPI, + rsAPI roomserverAPI.RoomserverInternalAPI, eduInputAPI eduServerAPI.EDUServerInputAPI, asAPI appserviceAPI.AppServiceQueryAPI, transactionsCache *transactions.Cache, - fedSenderAPI federationSenderAPI.FederationSenderQueryAPI, + fsAPI federationSenderAPI.FederationSenderInternalAPI, + userAPI userapi.UserInternalAPI, ) { - roomserverProducer := producers.NewRoomserverProducer(inputAPI, queryAPI) - eduProducer := producers.NewEDUServerProducer(eduInputAPI) - - userUpdateProducer := &producers.UserUpdateProducer{ - Producer: base.KafkaProducer, - Topic: string(base.Cfg.Kafka.Topics.UserUpdates), - } - syncProducer := &producers.SyncAPIProducer{ - Producer: base.KafkaProducer, - Topic: string(base.Cfg.Kafka.Topics.OutputClientData), + Producer: producer, + Topic: string(cfg.Kafka.Topics.OutputClientData), } - consumer := consumers.NewOutputRoomEventConsumer( - base.Cfg, base.KafkaConsumer, accountsDB, queryAPI, + roomEventConsumer := consumers.NewOutputRoomEventConsumer( + cfg, consumer, accountsDB, rsAPI, ) - if err := consumer.Start(); err != nil { + if err := roomEventConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start room server consumer") } routing.Setup( - base.APIMux, base.Cfg, roomserverProducer, queryAPI, aliasAPI, asAPI, - accountsDB, deviceDB, federation, *keyRing, userUpdateProducer, - syncProducer, eduProducer, transactionsCache, fedSenderAPI, + router, cfg, eduInputAPI, rsAPI, asAPI, + accountsDB, deviceDB, userAPI, federation, + syncProducer, transactionsCache, fsAPI, ) } diff --git a/clientapi/consumers/roomserver.go b/clientapi/consumers/roomserver.go index 6d5bb09a6..beeda042b 100644 --- a/clientapi/consumers/roomserver.go +++ b/clientapi/consumers/roomserver.go @@ -18,22 +18,22 @@ import ( "context" "encoding/json" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" + "github.com/Shopify/sarama" log "github.com/sirupsen/logrus" - sarama "gopkg.in/Shopify/sarama.v1" ) // OutputRoomEventConsumer consumes events that originated in the room server. type OutputRoomEventConsumer struct { - roomServerConsumer *common.ContinualConsumer - db accounts.Database - query api.RoomserverQueryAPI - serverName string + rsAPI api.RoomserverInternalAPI + rsConsumer *internal.ContinualConsumer + db accounts.Database + serverName string } // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. @@ -41,19 +41,19 @@ func NewOutputRoomEventConsumer( cfg *config.Dendrite, kafkaConsumer sarama.Consumer, store accounts.Database, - queryAPI api.RoomserverQueryAPI, + rsAPI api.RoomserverInternalAPI, ) *OutputRoomEventConsumer { - consumer := common.ContinualConsumer{ + consumer := internal.ContinualConsumer{ Topic: string(cfg.Kafka.Topics.OutputRoomEvent), Consumer: kafkaConsumer, PartitionStore: store, } s := &OutputRoomEventConsumer{ - roomServerConsumer: &consumer, - db: store, - query: queryAPI, - serverName: string(cfg.Matrix.ServerName), + rsConsumer: &consumer, + db: store, + rsAPI: rsAPI, + serverName: string(cfg.Matrix.ServerName), } consumer.ProcessMessage = s.onMessage @@ -62,7 +62,7 @@ func NewOutputRoomEventConsumer( // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { - return s.roomServerConsumer.Start() + return s.rsConsumer.Start() } // onMessage is called when the sync server receives a new event from the room server output log. @@ -84,63 +84,9 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { return nil } - ev := output.NewRoomEvent.Event - log.WithFields(log.Fields{ - "event_id": ev.EventID(), - "room_id": ev.RoomID(), - "type": ev.Type(), - }).Info("received event from roomserver") - - events, err := s.lookupStateEvents(output.NewRoomEvent.AddsStateEventIDs, ev.Event) - if err != nil { - return err - } - - return s.db.UpdateMemberships(context.TODO(), events, output.NewRoomEvent.RemovesStateEventIDs) -} - -// lookupStateEvents looks up the state events that are added by a new event. -func (s *OutputRoomEventConsumer) lookupStateEvents( - addsStateEventIDs []string, event gomatrixserverlib.Event, -) ([]gomatrixserverlib.Event, error) { - // Fast path if there aren't any new state events. - if len(addsStateEventIDs) == 0 { - // If the event is a membership update (e.g. for a profile update), it won't - // show up in AddsStateEventIDs, so we need to add it manually - if event.Type() == "m.room.member" { - return []gomatrixserverlib.Event{event}, nil - } - return nil, nil - } - - // Fast path if the only state event added is the event itself. - if len(addsStateEventIDs) == 1 && addsStateEventIDs[0] == event.EventID() { - return []gomatrixserverlib.Event{event}, nil - } - - result := []gomatrixserverlib.Event{} - missing := []string{} - for _, id := range addsStateEventIDs { - // Append the current event in the results if its ID is in the events list - if id == event.EventID() { - result = append(result, event) - } else { - // If the event isn't the current one, add it to the list of events - // to retrieve from the roomserver - missing = append(missing, id) - } - } - - // Request the missing events from the roomserver - eventReq := api.QueryEventsByIDRequest{EventIDs: missing} - var eventResp api.QueryEventsByIDResponse - if err := s.query.QueryEventsByID(context.TODO(), &eventReq, &eventResp); err != nil { - return nil, err - } - - for _, headeredEvent := range eventResp.Events { - result = append(result, headeredEvent.Event) - } - - return result, nil + return s.db.UpdateMemberships( + context.TODO(), + gomatrixserverlib.UnwrapEventHeaders(output.NewRoomEvent.AddsState()), + output.NewRoomEvent.RemovesStateEventIDs, + ) } diff --git a/clientapi/jsonerror/jsonerror.go b/clientapi/jsonerror/jsonerror.go index 735de5bea..85e887aec 100644 --- a/clientapi/jsonerror/jsonerror.go +++ b/clientapi/jsonerror/jsonerror.go @@ -18,6 +18,7 @@ import ( "fmt" "net/http" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -124,6 +125,12 @@ func GuestAccessForbidden(msg string) *MatrixError { return &MatrixError{"M_GUEST_ACCESS_FORBIDDEN", msg} } +// IncompatibleRoomVersion is an error which is returned when the client +// requests a room with a version that is unsupported. +func IncompatibleRoomVersion(roomVersion gomatrixserverlib.RoomVersion) *MatrixError { + return &MatrixError{"M_INCOMPATIBLE_ROOM_VERSION", string(roomVersion)} +} + // UnsupportedRoomVersion is an error which is returned when the client // requests a room with a version that is unsupported. func UnsupportedRoomVersion(msg string) *MatrixError { diff --git a/clientapi/producers/eduserver.go b/clientapi/producers/eduserver.go deleted file mode 100644 index 30c40fb7f..000000000 --- a/clientapi/producers/eduserver.go +++ /dev/null @@ -1,54 +0,0 @@ -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package producers - -import ( - "context" - "time" - - "github.com/matrix-org/dendrite/eduserver/api" - "github.com/matrix-org/gomatrixserverlib" -) - -// EDUServerProducer produces events for the EDU server to consume -type EDUServerProducer struct { - InputAPI api.EDUServerInputAPI -} - -// NewEDUServerProducer creates a new EDUServerProducer -func NewEDUServerProducer(inputAPI api.EDUServerInputAPI) *EDUServerProducer { - return &EDUServerProducer{ - InputAPI: inputAPI, - } -} - -// SendTyping sends a typing event to EDU server -func (p *EDUServerProducer) SendTyping( - ctx context.Context, userID, roomID string, - typing bool, timeoutMS int64, -) error { - requestData := api.InputTypingEvent{ - UserID: userID, - RoomID: roomID, - Typing: typing, - TimeoutMS: timeoutMS, - OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), - } - - var response api.InputTypingEventResponse - err := p.InputAPI.InputTypingEvent( - ctx, &api.InputTypingEventRequest{InputTypingEvent: requestData}, &response, - ) - - return err -} diff --git a/clientapi/producers/roomserver.go b/clientapi/producers/roomserver.go deleted file mode 100644 index 391ea07bf..000000000 --- a/clientapi/producers/roomserver.go +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package producers - -import ( - "context" - - "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/gomatrixserverlib" -) - -// RoomserverProducer produces events for the roomserver to consume. -type RoomserverProducer struct { - InputAPI api.RoomserverInputAPI - QueryAPI api.RoomserverQueryAPI -} - -// NewRoomserverProducer creates a new RoomserverProducer -func NewRoomserverProducer(inputAPI api.RoomserverInputAPI, queryAPI api.RoomserverQueryAPI) *RoomserverProducer { - return &RoomserverProducer{ - InputAPI: inputAPI, - QueryAPI: queryAPI, - } -} - -// SendEvents writes the given events to the roomserver input log. The events are written with KindNew. -func (c *RoomserverProducer) SendEvents( - ctx context.Context, events []gomatrixserverlib.HeaderedEvent, sendAsServer gomatrixserverlib.ServerName, - txnID *api.TransactionID, -) (string, error) { - ires := make([]api.InputRoomEvent, len(events)) - for i, event := range events { - ires[i] = api.InputRoomEvent{ - Kind: api.KindNew, - Event: event, - AuthEventIDs: event.AuthEventIDs(), - SendAsServer: string(sendAsServer), - TransactionID: txnID, - } - } - return c.SendInputRoomEvents(ctx, ires) -} - -// SendEventWithState writes an event with KindNew to the roomserver input log -// with the state at the event as KindOutlier before it. -func (c *RoomserverProducer) SendEventWithState( - ctx context.Context, state gomatrixserverlib.RespState, event gomatrixserverlib.HeaderedEvent, -) error { - outliers, err := state.Events() - if err != nil { - return err - } - - var ires []api.InputRoomEvent - for _, outlier := range outliers { - ires = append(ires, api.InputRoomEvent{ - Kind: api.KindOutlier, - Event: outlier.Headered(event.RoomVersion), - AuthEventIDs: outlier.AuthEventIDs(), - }) - } - - stateEventIDs := make([]string, len(state.StateEvents)) - for i := range state.StateEvents { - stateEventIDs[i] = state.StateEvents[i].EventID() - } - - ires = append(ires, api.InputRoomEvent{ - Kind: api.KindNew, - Event: event, - AuthEventIDs: event.AuthEventIDs(), - HasState: true, - StateEventIDs: stateEventIDs, - }) - - _, err = c.SendInputRoomEvents(ctx, ires) - return err -} - -// SendInputRoomEvents writes the given input room events to the roomserver input API. -func (c *RoomserverProducer) SendInputRoomEvents( - ctx context.Context, ires []api.InputRoomEvent, -) (eventID string, err error) { - request := api.InputRoomEventsRequest{InputRoomEvents: ires} - var response api.InputRoomEventsResponse - err = c.InputAPI.InputRoomEvents(ctx, &request, &response) - eventID = response.EventID - return -} - -// SendInvite writes the invite event to the roomserver input API. -// This should only be needed for invite events that occur outside of a known room. -// If we are in the room then the event should be sent using the SendEvents method. -func (c *RoomserverProducer) SendInvite( - ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, - inviteRoomState []gomatrixserverlib.InviteV2StrippedState, -) error { - request := api.InputRoomEventsRequest{ - InputInviteEvents: []api.InputInviteEvent{{ - Event: inviteEvent, - InviteRoomState: inviteRoomState, - RoomVersion: inviteEvent.RoomVersion, - }}, - } - var response api.InputRoomEventsResponse - return c.InputAPI.InputRoomEvents(ctx, &request, &response) -} diff --git a/clientapi/producers/syncapi.go b/clientapi/producers/syncapi.go index 6bfcd51aa..6ab8eef28 100644 --- a/clientapi/producers/syncapi.go +++ b/clientapi/producers/syncapi.go @@ -17,9 +17,9 @@ package producers import ( "encoding/json" - "github.com/matrix-org/dendrite/common" - - sarama "gopkg.in/Shopify/sarama.v1" + "github.com/Shopify/sarama" + "github.com/matrix-org/dendrite/internal/eventutil" + log "github.com/sirupsen/logrus" ) // SyncAPIProducer produces events for the sync API server to consume @@ -32,7 +32,7 @@ type SyncAPIProducer struct { func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string) error { var m sarama.ProducerMessage - data := common.AccountData{ + data := eventutil.AccountData{ RoomID: roomID, Type: dataType, } @@ -44,6 +44,11 @@ func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string m.Topic = string(p.Topic) m.Key = sarama.StringEncoder(userID) m.Value = sarama.ByteEncoder(value) + log.WithFields(log.Fields{ + "user_id": userID, + "room_id": roomID, + "data_type": dataType, + }).Infof("Producing to topic '%s'", p.Topic) _, _, err = p.Producer.SendMessage(&m) return err diff --git a/clientapi/producers/userupdate.go b/clientapi/producers/userupdate.go deleted file mode 100644 index 2a5dfc70a..000000000 --- a/clientapi/producers/userupdate.go +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package producers - -import ( - "encoding/json" - - sarama "gopkg.in/Shopify/sarama.v1" -) - -// UserUpdateProducer produces events related to user updates. -type UserUpdateProducer struct { - Topic string - Producer sarama.SyncProducer -} - -// TODO: Move this struct to `common` so the components that consume the topic -// can use it when parsing incoming messages -type profileUpdate struct { - Updated string `json:"updated"` // Which attribute is updated (can be either `avatar_url` or `displayname`) - OldValue string `json:"old_value"` // The attribute's value before the update - NewValue string `json:"new_value"` // The attribute's value after the update -} - -// SendUpdate sends an update using kafka to notify the roomserver of the -// profile update. Returns an error if the update failed to send. -func (p *UserUpdateProducer) SendUpdate( - userID string, updatedAttribute string, oldValue string, newValue string, -) error { - var update profileUpdate - var m sarama.ProducerMessage - - m.Topic = string(p.Topic) - m.Key = sarama.StringEncoder(userID) - - update = profileUpdate{ - Updated: updatedAttribute, - OldValue: oldValue, - NewValue: newValue, - } - - value, err := json.Marshal(update) - if err != nil { - return err - } - m.Value = sarama.ByteEncoder(value) - - _, _, err = p.Producer.SendMessage(&m) - return err -} diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go index a5d53c326..d5fafedb1 100644 --- a/clientapi/routing/account_data.go +++ b/clientapi/routing/account_data.go @@ -16,21 +16,20 @@ package routing import ( "encoding/json" + "fmt" "io/ioutil" "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" ) // GetAccountData implements GET /user/{userId}/[rooms/{roomid}/]account_data/{type} func GetAccountData( - req *http.Request, accountDB accounts.Database, device *authtypes.Device, + req *http.Request, userAPI api.UserInternalAPI, device *api.Device, userID string, roomID string, dataType string, ) util.JSONResponse { if userID != device.UserID { @@ -40,15 +39,25 @@ func GetAccountData( } } - localpart, _, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + dataReq := api.QueryAccountDataRequest{ + UserID: userID, + DataType: dataType, + RoomID: roomID, + } + dataRes := api.QueryAccountDataResponse{} + if err := userAPI.QueryAccountData(req.Context(), &dataReq, &dataRes); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed") + return util.ErrorResponse(fmt.Errorf("userAPI.QueryAccountData: %w", err)) } - if data, err := accountDB.GetAccountDataByType( - req.Context(), localpart, roomID, dataType, - ); err == nil { + var data json.RawMessage + var ok bool + if roomID != "" { + data, ok = dataRes.RoomAccountData[roomID][dataType] + } else { + data, ok = dataRes.GlobalAccountData[dataType] + } + if ok { return util.JSONResponse{ Code: http.StatusOK, JSON: data, @@ -63,7 +72,7 @@ func GetAccountData( // SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type} func SaveAccountData( - req *http.Request, accountDB accounts.Database, device *authtypes.Device, + req *http.Request, userAPI api.UserInternalAPI, device *api.Device, userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer, ) util.JSONResponse { if userID != device.UserID { @@ -73,12 +82,6 @@ func SaveAccountData( } } - localpart, _, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - defer req.Body.Close() // nolint: errcheck if req.Body == http.NoBody { @@ -101,13 +104,19 @@ func SaveAccountData( } } - if err := accountDB.SaveAccountData( - req.Context(), localpart, roomID, dataType, string(body), - ); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("accountDB.SaveAccountData failed") - return jsonerror.InternalServerError() + dataReq := api.InputAccountDataRequest{ + UserID: userID, + DataType: dataType, + RoomID: roomID, + AccountData: json.RawMessage(body), + } + dataRes := api.InputAccountDataResponse{} + if err := userAPI.InputAccountData(req.Context(), &dataReq, &dataRes); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed") + return util.ErrorResponse(err) } + // TODO: user API should do this since it's account data if err := syncProducer.SendData(userID, roomID, dataType); err != nil { util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed") return jsonerror.InternalServerError() diff --git a/clientapi/routing/auth_fallback.go b/clientapi/routing/auth_fallback.go index 8cb6b3d9b..b7f2cd6d3 100644 --- a/clientapi/routing/auth_fallback.go +++ b/clientapi/routing/auth_fallback.go @@ -20,7 +20,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/util" ) diff --git a/clientapi/routing/capabilities.go b/clientapi/routing/capabilities.go index 1792c6308..199b15240 100644 --- a/clientapi/routing/capabilities.go +++ b/clientapi/routing/capabilities.go @@ -26,11 +26,11 @@ import ( // SendMembership implements PUT /rooms/{roomID}/(join|kick|ban|unban|leave|invite) // by building a m.room.member event then sending it to the room server func GetCapabilities( - req *http.Request, queryAPI roomserverAPI.RoomserverQueryAPI, + req *http.Request, rsAPI roomserverAPI.RoomserverInternalAPI, ) util.JSONResponse { roomVersionsQueryReq := roomserverAPI.QueryRoomVersionCapabilitiesRequest{} roomVersionsQueryRes := roomserverAPI.QueryRoomVersionCapabilitiesResponse{} - if err := queryAPI.QueryRoomVersionCapabilities( + if err := rsAPI.QueryRoomVersionCapabilities( req.Context(), &roomVersionsQueryReq, &roomVersionsQueryRes, diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index ef11e8b3e..8682b03a4 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -24,14 +24,14 @@ import ( appserviceAPI "github.com/matrix-org/dendrite/appservice/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverVersion "github.com/matrix-org/dendrite/roomserver/version" + "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/clientapi/producers" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/clientapi/threepid" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" @@ -98,7 +98,7 @@ func (r createRoomRequest) Validate() *util.JSONResponse { // Validate creation_content fields defined in the spec by marshalling the // creation_content map into bytes and then unmarshalling the bytes into - // common.CreateContent. + // eventutil.CreateContent. creationContentBytes, err := json.Marshal(r.CreationContent) if err != nil { @@ -135,23 +135,23 @@ type fledglingEvent struct { // CreateRoom implements /createRoom func CreateRoom( - req *http.Request, device *authtypes.Device, - cfg *config.Dendrite, producer *producers.RoomserverProducer, - accountDB accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI, + req *http.Request, device *api.Device, + cfg *config.Dendrite, + accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { // TODO (#267): Check room ID doesn't clash with an existing one, and we // probably shouldn't be using pseudo-random strings, maybe GUIDs? roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), cfg.Matrix.ServerName) - return createRoom(req, device, cfg, roomID, producer, accountDB, aliasAPI, asAPI) + return createRoom(req, device, cfg, roomID, accountDB, rsAPI, asAPI) } // createRoom implements /createRoom // nolint: gocyclo func createRoom( - req *http.Request, device *authtypes.Device, - cfg *config.Dendrite, roomID string, producer *producers.RoomserverProducer, - accountDB accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI, + req *http.Request, device *api.Device, + cfg *config.Dendrite, roomID string, + accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { logger := util.GetLogger(req.Context()) @@ -212,6 +212,25 @@ func createRoom( return jsonerror.InternalServerError() } + var roomAlias string + if r.RoomAliasName != "" { + roomAlias = fmt.Sprintf("#%s:%s", r.RoomAliasName, cfg.Matrix.ServerName) + // check it's free TODO: This races but is better than nothing + hasAliasReq := roomserverAPI.GetRoomIDForAliasRequest{ + Alias: roomAlias, + } + + var aliasResp roomserverAPI.GetRoomIDForAliasResponse + err = rsAPI.GetRoomIDForAlias(req.Context(), &hasAliasReq, &aliasResp) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed") + return jsonerror.InternalServerError() + } + if aliasResp.RoomID != "" { + return util.MessageResponse(400, "Alias already exists") + } + } + membershipContent := gomatrixserverlib.MemberContent{ Membership: gomatrixserverlib.Join, DisplayName: profile.DisplayName, @@ -243,9 +262,9 @@ func createRoom( // 1- m.room.create // 2- room creator join member // 3- m.room.power_levels - // 4- m.room.canonical_alias (opt) TODO - // 5- m.room.join_rules - // 6- m.room.history_visibility + // 4- m.room.join_rules + // 5- m.room.history_visibility + // 6- m.room.canonical_alias (opt) // 7- m.room.guest_access (opt) // 8- other initial state items // 9- m.room.name (opt) @@ -260,24 +279,28 @@ func createRoom( eventsToMake := []fledglingEvent{ {"m.room.create", "", r.CreationContent}, {"m.room.member", userID, membershipContent}, - {"m.room.power_levels", "", common.InitialPowerLevelsContent(userID)}, - // TODO: m.room.canonical_alias + {"m.room.power_levels", "", eventutil.InitialPowerLevelsContent(userID)}, {"m.room.join_rules", "", gomatrixserverlib.JoinRuleContent{JoinRule: joinRules}}, - {"m.room.history_visibility", "", common.HistoryVisibilityContent{HistoryVisibility: historyVisibility}}, + {"m.room.history_visibility", "", eventutil.HistoryVisibilityContent{HistoryVisibility: historyVisibility}}, + } + if roomAlias != "" { + // TODO: bit of a chicken and egg problem here as the alias doesn't exist and cannot until we have made the room. + // This means we might fail creating the alias but say the canonical alias is something that doesn't exist. + // m.room.aliases is handled when we call roomserver.SetRoomAlias + eventsToMake = append(eventsToMake, fledglingEvent{"m.room.canonical_alias", "", eventutil.CanonicalAlias{Alias: roomAlias}}) } if r.GuestCanJoin { - eventsToMake = append(eventsToMake, fledglingEvent{"m.room.guest_access", "", common.GuestAccessContent{GuestAccess: "can_join"}}) + eventsToMake = append(eventsToMake, fledglingEvent{"m.room.guest_access", "", eventutil.GuestAccessContent{GuestAccess: "can_join"}}) } eventsToMake = append(eventsToMake, r.InitialState...) if r.Name != "" { - eventsToMake = append(eventsToMake, fledglingEvent{"m.room.name", "", common.NameContent{Name: r.Name}}) + eventsToMake = append(eventsToMake, fledglingEvent{"m.room.name", "", eventutil.NameContent{Name: r.Name}}) } if r.Topic != "" { - eventsToMake = append(eventsToMake, fledglingEvent{"m.room.topic", "", common.TopicContent{Topic: r.Topic}}) + eventsToMake = append(eventsToMake, fledglingEvent{"m.room.topic", "", eventutil.TopicContent{Topic: r.Topic}}) } // TODO: invite events // TODO: 3pid invite events - // TODO: m.room.aliases authEvents := gomatrixserverlib.NewAuthEvents(nil) for i, e := range eventsToMake { @@ -320,19 +343,16 @@ func createRoom( } // send events to the room server - _, err = producer.SendEvents(req.Context(), builtEvents, cfg.Matrix.ServerName, nil) + _, err = roomserverAPI.SendEvents(req.Context(), rsAPI, builtEvents, cfg.Matrix.ServerName, nil) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("producer.SendEvents failed") + util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } // TODO(#269): Reserve room alias while we create the room. This stops us // from creating the room but still failing due to the alias having already // been taken. - var roomAlias string - if r.RoomAliasName != "" { - roomAlias = fmt.Sprintf("#%s:%s", r.RoomAliasName, cfg.Matrix.ServerName) - + if roomAlias != "" { aliasReq := roomserverAPI.SetRoomAliasRequest{ Alias: roomAlias, RoomID: roomID, @@ -340,7 +360,7 @@ func createRoom( } var aliasResp roomserverAPI.SetRoomAliasResponse - err = aliasAPI.SetRoomAlias(req.Context(), &aliasReq, &aliasResp) + err = rsAPI.SetRoomAlias(req.Context(), &aliasReq, &aliasResp) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.SetRoomAlias failed") return jsonerror.InternalServerError() @@ -351,6 +371,50 @@ func createRoom( } } + // If this is a direct message then we should invite the participants. + for _, invitee := range r.Invite { + // Build the membership request. + body := threepid.MembershipRequest{ + UserID: invitee, + } + // Build the invite event. + inviteEvent, err := buildMembershipEvent( + req.Context(), body, accountDB, device, gomatrixserverlib.Invite, + roomID, true, cfg, evTime, rsAPI, asAPI, + ) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvent failed") + continue + } + // Build some stripped state for the invite. + candidates := append(gomatrixserverlib.UnwrapEventHeaders(builtEvents), *inviteEvent) + var strippedState []gomatrixserverlib.InviteV2StrippedState + for _, event := range candidates { + switch event.Type() { + // TODO: case gomatrixserverlib.MRoomEncryption: + // fallthrough + case gomatrixserverlib.MRoomMember: + fallthrough + case gomatrixserverlib.MRoomJoinRules: + strippedState = append( + strippedState, + gomatrixserverlib.NewInviteV2StrippedState(&event), + ) + } + } + // Send the invite event to the roomserver. + if perr := roomserverAPI.SendInvite( + req.Context(), rsAPI, + inviteEvent.Headered(roomVersion), + strippedState, // invite room state + cfg.Matrix.ServerName, // send as server + nil, // transaction ID + ); perr != nil { + util.GetLogger(req.Context()).WithError(perr).Error("SendInvite failed") + return perr.JSONResponse() + } + } + response := createRoomResponse{ RoomID: roomID, RoomAlias: roomAlias, diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go index 89c394913..51a15a882 100644 --- a/clientapi/routing/device.go +++ b/clientapi/routing/device.go @@ -19,16 +19,19 @@ import ( "encoding/json" "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) +// https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-devices type deviceJSON struct { - DeviceID string `json:"device_id"` - UserID string `json:"user_id"` + DeviceID string `json:"device_id"` + DisplayName string `json:"display_name"` + LastSeenIP string `json:"last_seen_ip"` + LastSeenTS uint64 `json:"last_seen_ts"` } type devicesJSON struct { @@ -45,7 +48,7 @@ type devicesDeleteJSON struct { // GetDeviceByID handles /devices/{deviceID} func GetDeviceByID( - req *http.Request, deviceDB devices.Database, device *authtypes.Device, + req *http.Request, deviceDB devices.Database, device *api.Device, deviceID string, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) @@ -70,14 +73,13 @@ func GetDeviceByID( Code: http.StatusOK, JSON: deviceJSON{ DeviceID: dev.ID, - UserID: dev.UserID, }, } } // GetDevicesByLocalpart handles /devices func GetDevicesByLocalpart( - req *http.Request, deviceDB devices.Database, device *authtypes.Device, + req *http.Request, deviceDB devices.Database, device *api.Device, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { @@ -98,7 +100,6 @@ func GetDevicesByLocalpart( for _, dev := range deviceList { res.Devices = append(res.Devices, deviceJSON{ DeviceID: dev.ID, - UserID: dev.UserID, }) } @@ -110,7 +111,7 @@ func GetDevicesByLocalpart( // UpdateDeviceByID handles PUT on /devices/{deviceID} func UpdateDeviceByID( - req *http.Request, deviceDB devices.Database, device *authtypes.Device, + req *http.Request, deviceDB devices.Database, device *api.Device, deviceID string, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) @@ -160,7 +161,7 @@ func UpdateDeviceByID( // DeleteDeviceById handles DELETE requests to /devices/{deviceId} func DeleteDeviceById( - req *http.Request, deviceDB devices.Database, device *authtypes.Device, + req *http.Request, deviceDB devices.Database, device *api.Device, deviceID string, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) @@ -185,7 +186,7 @@ func DeleteDeviceById( // DeleteDevices handles POST requests to /delete_devices func DeleteDevices( - req *http.Request, deviceDB devices.Database, device *authtypes.Device, + req *http.Request, deviceDB devices.Database, device *api.Device, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index 248696ab2..0dc4d5605 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -18,12 +18,12 @@ import ( "fmt" "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/common/config" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/config" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -46,8 +46,8 @@ func DirectoryRoom( roomAlias string, federation *gomatrixserverlib.FederationClient, cfg *config.Dendrite, - rsAPI roomserverAPI.RoomserverAliasAPI, - fedSenderAPI federationSenderAPI.FederationSenderQueryAPI, + rsAPI roomserverAPI.RoomserverInternalAPI, + fedSenderAPI federationSenderAPI.FederationSenderInternalAPI, ) util.JSONResponse { _, domain, err := gomatrixserverlib.SplitID('#', roomAlias) if err != nil { @@ -77,7 +77,7 @@ func DirectoryRoom( if fedErr != nil { // TODO: Return 502 if the remote server errored. // TODO: Return 504 if the remote server timed out. - util.GetLogger(req.Context()).WithError(err).Error("federation.LookupRoomAlias failed") + util.GetLogger(req.Context()).WithError(fedErr).Error("federation.LookupRoomAlias failed") return jsonerror.InternalServerError() } res.RoomID = fedRes.RoomID @@ -112,10 +112,10 @@ func DirectoryRoom( // TODO: Check if the user has the power level to set an alias func SetLocalAlias( req *http.Request, - device *authtypes.Device, + device *api.Device, alias string, cfg *config.Dendrite, - aliasAPI roomserverAPI.RoomserverAliasAPI, + aliasAPI roomserverAPI.RoomserverInternalAPI, ) util.JSONResponse { _, domain, err := gomatrixserverlib.SplitID('#', alias) if err != nil { @@ -188,9 +188,9 @@ func SetLocalAlias( // RemoveLocalAlias implements DELETE /directory/room/{roomAlias} func RemoveLocalAlias( req *http.Request, - device *authtypes.Device, + device *api.Device, alias string, - aliasAPI roomserverAPI.RoomserverAliasAPI, + aliasAPI roomserverAPI.RoomserverInternalAPI, ) util.JSONResponse { creatorQueryReq := roomserverAPI.GetCreatorIDForAliasRequest{ diff --git a/clientapi/routing/filter.go b/clientapi/routing/filter.go index 505e09279..6520e6e40 100644 --- a/clientapi/routing/filter.go +++ b/clientapi/routing/filter.go @@ -17,17 +17,17 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) // GetFilter implements GET /_matrix/client/r0/user/{userId}/filter/{filterId} func GetFilter( - req *http.Request, device *authtypes.Device, accountDB accounts.Database, userID string, filterID string, + req *http.Request, device *api.Device, accountDB accounts.Database, userID string, filterID string, ) util.JSONResponse { if userID != device.UserID { return util.JSONResponse{ @@ -64,7 +64,7 @@ type filterResponse struct { //PutFilter implements POST /_matrix/client/r0/user/{userId}/filter func PutFilter( - req *http.Request, device *authtypes.Device, accountDB accounts.Database, userID string, + req *http.Request, device *api.Device, accountDB accounts.Database, userID string, ) util.JSONResponse { if userID != device.UserID { return util.JSONResponse{ diff --git a/clientapi/routing/getevent.go b/clientapi/routing/getevent.go index 2d3152510..2a51db730 100644 --- a/clientapi/routing/getevent.go +++ b/clientapi/routing/getevent.go @@ -17,22 +17,21 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) type getEventRequest struct { req *http.Request - device *authtypes.Device + device *userapi.Device roomID string eventID string cfg *config.Dendrite federation *gomatrixserverlib.FederationClient - keyRing gomatrixserverlib.KeyRing requestedEvent gomatrixserverlib.Event } @@ -40,19 +39,18 @@ type getEventRequest struct { // https://matrix.org/docs/spec/client_server/r0.4.0.html#get-matrix-client-r0-rooms-roomid-event-eventid func GetEvent( req *http.Request, - device *authtypes.Device, + device *userapi.Device, roomID string, eventID string, cfg *config.Dendrite, - queryAPI api.RoomserverQueryAPI, + rsAPI api.RoomserverInternalAPI, federation *gomatrixserverlib.FederationClient, - keyRing gomatrixserverlib.KeyRing, ) util.JSONResponse { eventsReq := api.QueryEventsByIDRequest{ EventIDs: []string{eventID}, } var eventsResp api.QueryEventsByIDResponse - err := queryAPI.QueryEventsByID(req.Context(), &eventsReq, &eventsResp) + err := rsAPI.QueryEventsByID(req.Context(), &eventsReq, &eventsResp) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("queryAPI.QueryEventsByID failed") return jsonerror.InternalServerError() @@ -75,7 +73,6 @@ func GetEvent( eventID: eventID, cfg: cfg, federation: federation, - keyRing: keyRing, requestedEvent: requestedEvent, } @@ -88,7 +85,7 @@ func GetEvent( }}, } var stateResp api.QueryStateAfterEventsResponse - if err := queryAPI.QueryStateAfterEvents(req.Context(), &stateReq, &stateResp); err != nil { + if err := rsAPI.QueryStateAfterEvents(req.Context(), &stateReq, &stateResp); err != nil { util.GetLogger(req.Context()).WithError(err).Error("queryAPI.QueryStateAfterEvents failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index f72bb9162..cb68fe196 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -15,434 +15,64 @@ package routing import ( - "fmt" "net/http" - "strings" - "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/clientapi/producers" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" - "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/gomatrix" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" - "github.com/sirupsen/logrus" ) -// JoinRoomByIDOrAlias implements the "/join/{roomIDOrAlias}" API. -// https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-client-r0-join-roomidoralias func JoinRoomByIDOrAlias( req *http.Request, - device *authtypes.Device, - roomIDOrAlias string, - cfg *config.Dendrite, - federation *gomatrixserverlib.FederationClient, - producer *producers.RoomserverProducer, - queryAPI roomserverAPI.RoomserverQueryAPI, - aliasAPI roomserverAPI.RoomserverAliasAPI, - keyRing gomatrixserverlib.KeyRing, + device *api.Device, + rsAPI roomserverAPI.RoomserverInternalAPI, accountDB accounts.Database, + roomIDOrAlias string, ) util.JSONResponse { - var content map[string]interface{} // must be a JSON object - if resErr := httputil.UnmarshalJSONRequest(req, &content); resErr != nil { - return *resErr + // Prepare to ask the roomserver to perform the room join. + joinReq := roomserverAPI.PerformJoinRequest{ + RoomIDOrAlias: roomIDOrAlias, + UserID: device.UserID, + Content: map[string]interface{}{}, } + joinRes := roomserverAPI.PerformJoinResponse{} - evTime, err := httputil.ParseTSParam(req) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(err.Error()), - } - } + // If content was provided in the request then incude that + // in the request. It'll get used as a part of the membership + // event content. + _ = httputil.UnmarshalJSONRequest(req, &joinReq.Content) + // Work out our localpart for the client profile request. localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - - profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetProfileByLocalpart failed") - return jsonerror.InternalServerError() - } - - content["membership"] = gomatrixserverlib.Join - content["displayname"] = profile.DisplayName - content["avatar_url"] = profile.AvatarURL - - r := joinRoomReq{ - req, evTime, content, device.UserID, cfg, federation, producer, queryAPI, aliasAPI, keyRing, - } - - if strings.HasPrefix(roomIDOrAlias, "!") { - return r.joinRoomByID(roomIDOrAlias) - } - if strings.HasPrefix(roomIDOrAlias, "#") { - return r.joinRoomByAlias(roomIDOrAlias) - } - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON( - fmt.Sprintf("Invalid first character '%s' for room ID or alias", - string([]rune(roomIDOrAlias)[0])), // Wrapping with []rune makes this call UTF-8 safe - ), - } -} - -type joinRoomReq struct { - req *http.Request - evTime time.Time - content map[string]interface{} - userID string - cfg *config.Dendrite - federation *gomatrixserverlib.FederationClient - producer *producers.RoomserverProducer - queryAPI roomserverAPI.RoomserverQueryAPI - aliasAPI roomserverAPI.RoomserverAliasAPI - keyRing gomatrixserverlib.KeyRing -} - -// joinRoomByID joins a room by room ID -func (r joinRoomReq) joinRoomByID(roomID string) util.JSONResponse { - // A client should only join a room by room ID when it has an invite - // to the room. If the server is already in the room then we can - // lookup the invite and process the request as a normal state event. - // If the server is not in the room the we will need to look up the - // remote server the invite came from in order to request a join event - // from that server. - queryReq := roomserverAPI.QueryInvitesForUserRequest{ - RoomID: roomID, TargetUserID: r.userID, - } - var queryRes roomserverAPI.QueryInvitesForUserResponse - if err := r.queryAPI.QueryInvitesForUser(r.req.Context(), &queryReq, &queryRes); err != nil { - util.GetLogger(r.req.Context()).WithError(err).Error("r.queryAPI.QueryInvitesForUser failed") - return jsonerror.InternalServerError() - } - - servers := []gomatrixserverlib.ServerName{} - seenInInviterIDs := map[gomatrixserverlib.ServerName]bool{} - for _, userID := range queryRes.InviteSenderUserIDs { - _, domain, err := gomatrixserverlib.SplitID('@', userID) + } else { + // Request our profile content to populate the request content with. + var profile *authtypes.Profile + profile, err = accountDB.GetProfileByLocalpart(req.Context(), localpart) if err != nil { - util.GetLogger(r.req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - if !seenInInviterIDs[domain] { - servers = append(servers, domain) - seenInInviterIDs[domain] = true + util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetProfileByLocalpart failed") + } else { + joinReq.Content["displayname"] = profile.DisplayName + joinReq.Content["avatar_url"] = profile.AvatarURL } } - // Also add the domain extracted from the roomID as a last resort to join - // in case the client is erroneously trying to join by ID without an invite - // or all previous attempts at domains extracted from the inviter IDs fail - // Note: It's no guarantee we'll succeed because a room isn't bound to the domain in its ID - _, domain, err := gomatrixserverlib.SplitID('!', roomID) - if err != nil { - util.GetLogger(r.req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - if domain != r.cfg.Matrix.ServerName && !seenInInviterIDs[domain] { - servers = append(servers, domain) + // Ask the roomserver to perform the join. + rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes) + if joinRes.Error != nil { + return joinRes.Error.JSONResponse() } - return r.joinRoomUsingServers(roomID, servers) - -} - -// joinRoomByAlias joins a room using a room alias. -func (r joinRoomReq) joinRoomByAlias(roomAlias string) util.JSONResponse { - _, domain, err := gomatrixserverlib.SplitID('#', roomAlias) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Room alias must be in the form '#localpart:domain'"), - } - } - if domain == r.cfg.Matrix.ServerName { - queryReq := roomserverAPI.GetRoomIDForAliasRequest{Alias: roomAlias} - var queryRes roomserverAPI.GetRoomIDForAliasResponse - if err = r.aliasAPI.GetRoomIDForAlias(r.req.Context(), &queryReq, &queryRes); err != nil { - util.GetLogger(r.req.Context()).WithError(err).Error("r.aliasAPI.GetRoomIDForAlias failed") - return jsonerror.InternalServerError() - } - - if len(queryRes.RoomID) > 0 { - return r.joinRoomUsingServers(queryRes.RoomID, []gomatrixserverlib.ServerName{r.cfg.Matrix.ServerName}) - } - // If the response doesn't contain a non-empty string, return an error - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Room alias " + roomAlias + " not found."), - } - } - // If the room isn't local, use federation to join - return r.joinRoomByRemoteAlias(domain, roomAlias) -} - -func (r joinRoomReq) joinRoomByRemoteAlias( - domain gomatrixserverlib.ServerName, roomAlias string, -) util.JSONResponse { - resp, err := r.federation.LookupRoomAlias(r.req.Context(), domain, roomAlias) - if err != nil { - switch x := err.(type) { - case gomatrix.HTTPError: - if x.Code == http.StatusNotFound { - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Room alias not found"), - } - } - } - util.GetLogger(r.req.Context()).WithError(err).Error("r.federation.LookupRoomAlias failed") - return jsonerror.InternalServerError() - } - - return r.joinRoomUsingServers(resp.RoomID, resp.Servers) -} - -func (r joinRoomReq) writeToBuilder(eb *gomatrixserverlib.EventBuilder, roomID string) error { - eb.Type = "m.room.member" - - err := eb.SetContent(r.content) - if err != nil { - return err - } - - err = eb.SetUnsigned(struct{}{}) - if err != nil { - return err - } - - eb.Sender = r.userID - eb.StateKey = &r.userID - eb.RoomID = roomID - eb.Redacts = "" - - return nil -} - -func (r joinRoomReq) joinRoomUsingServers( - roomID string, servers []gomatrixserverlib.ServerName, -) util.JSONResponse { - var eb gomatrixserverlib.EventBuilder - err := r.writeToBuilder(&eb, roomID) - if err != nil { - util.GetLogger(r.req.Context()).WithError(err).Error("r.writeToBuilder failed") - return jsonerror.InternalServerError() - } - - queryRes := roomserverAPI.QueryLatestEventsAndStateResponse{} - event, err := common.BuildEvent(r.req.Context(), &eb, r.cfg, r.evTime, r.queryAPI, &queryRes) - if err == nil { - // If we have successfully built an event at this point then we can - // assert that the room is a local room, as BuildEvent was able to - // add prev_events etc successfully. - if _, err = r.producer.SendEvents( - r.req.Context(), - []gomatrixserverlib.HeaderedEvent{ - (*event).Headered(queryRes.RoomVersion), - }, - r.cfg.Matrix.ServerName, - nil, - ); err != nil { - util.GetLogger(r.req.Context()).WithError(err).Error("r.producer.SendEvents failed") - return jsonerror.InternalServerError() - } - return util.JSONResponse{ - Code: http.StatusOK, - JSON: struct { - RoomID string `json:"room_id"` - }{roomID}, - } - } - - // Otherwise, if we've reached here, then we haven't been able to populate - // prev_events etc for the room, therefore the room is probably federated. - - // TODO: This needs to be re-thought, as in the case of an invite, the room - // will exist in the database in roomserver_rooms but won't have any state - // events, therefore this below check fails. - if err != common.ErrRoomNoExists { - util.GetLogger(r.req.Context()).WithError(err).Error("common.BuildEvent failed") - return jsonerror.InternalServerError() - } - - if len(servers) == 0 { - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("No candidate servers found for room"), - } - } - - var lastErr error - for _, server := range servers { - var response *util.JSONResponse - response, lastErr = r.joinRoomUsingServer(roomID, server) - if lastErr != nil { - // There was a problem talking to one of the servers. - util.GetLogger(r.req.Context()).WithError(lastErr).WithField("server", server).Warn("Failed to join room using server") - // Try the next server. - if r.req.Context().Err() != nil { - // The request context has expired so don't bother trying any - // more servers - they will immediately fail due to the expired - // context. - break - } else { - // The request context hasn't expired yet so try the next server. - continue - } - } - return *response - } - - // Every server we tried to join through resulted in an error. - // We return the error from the last server. - - // TODO: Generate the correct HTTP status code for all different - // kinds of errors that could have happened. - // The possible errors include: - // 1) We can't connect to the remote servers. - // 2) None of the servers we could connect to think we are allowed - // to join the room. - // 3) The remote server returned something invalid. - // 4) We couldn't fetch the public keys needed to verify the - // signatures on the state events. - // 5) ... - util.GetLogger(r.req.Context()).WithError(lastErr).Error("failed to join through any server") - return jsonerror.InternalServerError() -} - -// joinRoomUsingServer tries to join a remote room using a given matrix server. -// If there was a failure communicating with the server or the response from the -// server was invalid this returns an error. -// Otherwise this returns a JSONResponse. -func (r joinRoomReq) joinRoomUsingServer(roomID string, server gomatrixserverlib.ServerName) (*util.JSONResponse, error) { - // Ask the room server for information about room versions. - var request api.QueryRoomVersionCapabilitiesRequest - var response api.QueryRoomVersionCapabilitiesResponse - if err := r.queryAPI.QueryRoomVersionCapabilities(r.req.Context(), &request, &response); err != nil { - return nil, err - } - var supportedVersions []gomatrixserverlib.RoomVersion - for version := range response.AvailableRoomVersions { - supportedVersions = append(supportedVersions, version) - } - respMakeJoin, err := r.federation.MakeJoin(r.req.Context(), server, roomID, r.userID, supportedVersions) - if err != nil { - // TODO: Check if the user was not allowed to join the room. - return nil, fmt.Errorf("r.federation.MakeJoin: %w", err) - } - - // Set all the fields to be what they should be, this should be a no-op - // but it's possible that the remote server returned us something "odd" - err = r.writeToBuilder(&respMakeJoin.JoinEvent, roomID) - if err != nil { - return nil, fmt.Errorf("r.writeToBuilder: %w", err) - } - - if respMakeJoin.RoomVersion == "" { - respMakeJoin.RoomVersion = gomatrixserverlib.RoomVersionV1 - } - if _, err = respMakeJoin.RoomVersion.EventFormat(); err != nil { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.UnsupportedRoomVersion( - fmt.Sprintf("Room version '%s' is not supported", respMakeJoin.RoomVersion), - ), - }, nil - } - - event, err := respMakeJoin.JoinEvent.Build( - r.evTime, r.cfg.Matrix.ServerName, r.cfg.Matrix.KeyID, - r.cfg.Matrix.PrivateKey, respMakeJoin.RoomVersion, - ) - if err != nil { - return nil, fmt.Errorf("respMakeJoin.JoinEvent.Build: %w", err) - } - - respSendJoin, err := r.federation.SendJoin(r.req.Context(), server, event, respMakeJoin.RoomVersion) - if err != nil { - return nil, fmt.Errorf("r.federation.SendJoin: %w", err) - } - - if err = r.checkSendJoinResponse(event, server, respMakeJoin, respSendJoin); err != nil { - return nil, err - } - - util.GetLogger(r.req.Context()).WithFields(logrus.Fields{ - "room_id": roomID, - "num_auth_events": len(respSendJoin.AuthEvents), - "num_state_events": len(respSendJoin.StateEvents), - }).Info("Room join signature and auth verification passed") - - if err = r.producer.SendEventWithState( - r.req.Context(), - gomatrixserverlib.RespState(respSendJoin.RespState), - event.Headered(respMakeJoin.RoomVersion), - ); err != nil { - util.GetLogger(r.req.Context()).WithError(err).Error("r.producer.SendEventWithState") - } - - return &util.JSONResponse{ + return util.JSONResponse{ Code: http.StatusOK, - // TODO: Put the response struct somewhere common. + // TODO: Put the response struct somewhere internal. JSON: struct { RoomID string `json:"room_id"` - }{roomID}, - }, nil -} - -// checkSendJoinResponse checks that all of the signatures are correct -// and that the join is allowed by the supplied state. -func (r joinRoomReq) checkSendJoinResponse( - event gomatrixserverlib.Event, - server gomatrixserverlib.ServerName, - respMakeJoin gomatrixserverlib.RespMakeJoin, - respSendJoin gomatrixserverlib.RespSendJoin, -) error { - // A list of events that we have retried, if they were not included in - // the auth events supplied in the send_join. - retries := map[string]bool{} - -retryCheck: - // TODO: Can we expand Check here to return a list of missing auth - // events rather than failing one at a time? - if err := respSendJoin.Check(r.req.Context(), r.keyRing, event); err != nil { - switch e := err.(type) { - case gomatrixserverlib.MissingAuthEventError: - // Check that we haven't already retried for this event, prevents - // us from ending up in endless loops - if !retries[e.AuthEventID] { - // Ask the server that we're talking to right now for the event - tx, txerr := r.federation.GetEvent(r.req.Context(), server, e.AuthEventID) - if txerr != nil { - return fmt.Errorf("r.federation.GetEvent: %w", txerr) - } - // For each event returned, add it to the auth events. - for _, pdu := range tx.PDUs { - ev, everr := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, respMakeJoin.RoomVersion) - if everr != nil { - return fmt.Errorf("gomatrixserverlib.NewEventFromUntrustedJSON: %w", everr) - } - respSendJoin.AuthEvents = append(respSendJoin.AuthEvents, ev) - } - // Mark the event as retried and then give the check another go. - retries[e.AuthEventID] = true - goto retryCheck - } - return fmt.Errorf("respSendJoin (after retries): %w", e) - default: - return fmt.Errorf("respSendJoin: %w", err) - } + }{joinRes.RoomID}, } - return nil } diff --git a/clientapi/routing/leaveroom.go b/clientapi/routing/leaveroom.go new file mode 100644 index 000000000..38cef118e --- /dev/null +++ b/clientapi/routing/leaveroom.go @@ -0,0 +1,51 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "net/http" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +func LeaveRoomByID( + req *http.Request, + device *api.Device, + rsAPI roomserverAPI.RoomserverInternalAPI, + roomID string, +) util.JSONResponse { + // Prepare to ask the roomserver to perform the room join. + leaveReq := roomserverAPI.PerformLeaveRequest{ + RoomID: roomID, + UserID: device.UserID, + } + leaveRes := roomserverAPI.PerformLeaveResponse{} + + // Ask the roomserver to perform the leave. + if err := rsAPI.PerformLeave(req.Context(), &leaveReq, &leaveRes); err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.Unknown(err.Error()), + } + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 21b947200..dc0180da6 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -20,13 +20,13 @@ import ( "context" "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -47,6 +47,7 @@ type loginIdentifier struct { type passwordRequest struct { Identifier loginIdentifier `json:"identifier"` + User string `json:"user"` // deprecated in favour of identifier Password string `json:"password"` // Both DeviceID and InitialDisplayName can be omitted, or empty strings ("") // Thus a pointer is needed to differentiate between the two @@ -80,7 +81,8 @@ func Login( } } else if req.Method == http.MethodPost { var r passwordRequest - var acc *authtypes.Account + var acc *api.Account + var errJSON *util.JSONResponse resErr := httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { return *resErr @@ -93,30 +95,22 @@ func Login( JSON: jsonerror.BadJSON("'user' must be supplied."), } } - - util.GetLogger(req.Context()).WithField("user", r.Identifier.User).Info("Processing login request") - - localpart, err := userutil.ParseUsernameParam(r.Identifier.User, &cfg.Matrix.ServerName) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername(err.Error()), - } - } - - acc, err = accountDB.GetAccountByPassword(req.Context(), localpart, r.Password) - if err != nil { - // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows - // but that would leak the existence of the user. - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("username or password was incorrect, or the account does not exist"), - } + acc, errJSON = r.processUsernamePasswordLoginRequest(req, accountDB, cfg, r.Identifier.User) + if errJSON != nil { + return *errJSON } default: - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("login identifier '" + r.Identifier.Type + "' not supported"), + // TODO: The below behaviour is deprecated but without it Riot iOS won't log in + if r.User != "" { + acc, errJSON = r.processUsernamePasswordLoginRequest(req, accountDB, cfg, r.User) + if errJSON != nil { + return *errJSON + } + } else { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("login identifier '" + r.Identifier.Type + "' not supported"), + } } } @@ -155,11 +149,40 @@ func getDevice( ctx context.Context, r passwordRequest, deviceDB devices.Database, - acc *authtypes.Account, + acc *api.Account, token string, -) (dev *authtypes.Device, err error) { +) (dev *api.Device, err error) { dev, err = deviceDB.CreateDevice( ctx, acc.Localpart, r.DeviceID, token, r.InitialDisplayName, ) return } + +func (r *passwordRequest) processUsernamePasswordLoginRequest( + req *http.Request, accountDB accounts.Database, + cfg *config.Dendrite, username string, +) (acc *api.Account, errJSON *util.JSONResponse) { + util.GetLogger(req.Context()).WithField("user", username).Info("Processing login request") + + localpart, err := userutil.ParseUsernameParam(username, &cfg.Matrix.ServerName) + if err != nil { + errJSON = &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(err.Error()), + } + return + } + + acc, err = accountDB.GetAccountByPassword(req.Context(), localpart, r.Password) + if err != nil { + // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows + // but that would leak the existence of the user. + errJSON = &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("username or password was incorrect, or the account does not exist"), + } + return + } + + return +} diff --git a/clientapi/routing/logout.go b/clientapi/routing/logout.go index 26b7f117e..3ce47169e 100644 --- a/clientapi/routing/logout.go +++ b/clientapi/routing/logout.go @@ -17,16 +17,16 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) // Logout handles POST /logout func Logout( - req *http.Request, deviceDB devices.Database, device *authtypes.Device, + req *http.Request, deviceDB devices.Database, device *api.Device, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { @@ -47,7 +47,7 @@ func Logout( // LogoutAll handles POST /logout/all func LogoutAll( - req *http.Request, deviceDB devices.Database, device *authtypes.Device, + req *http.Request, deviceDB devices.Database, device *api.Device, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 9f386b718..aff1730c5 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -22,15 +22,15 @@ import ( appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/threepid" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -40,15 +40,16 @@ var errMissingUserID = errors.New("'user_id' must be supplied") // SendMembership implements PUT /rooms/{roomID}/(join|kick|ban|unban|leave|invite) // by building a m.room.member event then sending it to the room server +// TODO: Can we improve the cyclo count here? Separate code paths for invites? +// nolint:gocyclo func SendMembership( - req *http.Request, accountDB accounts.Database, device *authtypes.Device, + req *http.Request, accountDB accounts.Database, device *userapi.Device, roomID string, membership string, cfg *config.Dendrite, - queryAPI roomserverAPI.RoomserverQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI, - producer *producers.RoomserverProducer, + rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} verRes := api.QueryRoomVersionForRoomResponse{} - if err := queryAPI.QueryRoomVersionForRoom(req.Context(), &verReq, &verRes); err != nil { + if err := rsAPI.QueryRoomVersionForRoom(req.Context(), &verReq, &verRes); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.UnsupportedRoomVersion(err.Error()), @@ -69,7 +70,7 @@ func SendMembership( } inviteStored, jsonErrResp := checkAndProcessThreepid( - req, device, &body, cfg, queryAPI, accountDB, producer, + req, device, &body, cfg, rsAPI, accountDB, membership, roomID, evTime, ) if jsonErrResp != nil { @@ -87,14 +88,15 @@ func SendMembership( } event, err := buildMembershipEvent( - req.Context(), body, accountDB, device, membership, roomID, cfg, evTime, queryAPI, asAPI, + req.Context(), body, accountDB, device, membership, + roomID, false, cfg, evTime, rsAPI, asAPI, ) if err == errMissingUserID { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(err.Error()), } - } else if err == common.ErrRoomNoExists { + } else if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound(err.Error()), @@ -104,23 +106,39 @@ func SendMembership( return jsonerror.InternalServerError() } - if _, err := producer.SendEvents( - req.Context(), - []gomatrixserverlib.HeaderedEvent{(*event).Headered(verRes.RoomVersion)}, - cfg.Matrix.ServerName, - nil, - ); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("producer.SendEvents failed") - return jsonerror.InternalServerError() - } - var returnData interface{} = struct{}{} - // The join membership requires the room id to be sent in the response - if membership == gomatrixserverlib.Join { + switch membership { + case gomatrixserverlib.Invite: + // Invites need to be handled specially + perr := roomserverAPI.SendInvite( + req.Context(), rsAPI, + event.Headered(verRes.RoomVersion), + nil, // ask the roomserver to draw up invite room state for us + cfg.Matrix.ServerName, + nil, + ) + if perr != nil { + util.GetLogger(req.Context()).WithError(perr).Error("producer.SendInvite failed") + return perr.JSONResponse() + } + case gomatrixserverlib.Join: + // The join membership requires the room id to be sent in the response returnData = struct { RoomID string `json:"room_id"` }{roomID} + fallthrough + default: + _, err = roomserverAPI.SendEvents( + req.Context(), rsAPI, + []gomatrixserverlib.HeaderedEvent{event.Headered(verRes.RoomVersion)}, + cfg.Matrix.ServerName, + nil, + ) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") + return jsonerror.InternalServerError() + } } return util.JSONResponse{ @@ -132,10 +150,10 @@ func SendMembership( func buildMembershipEvent( ctx context.Context, body threepid.MembershipRequest, accountDB accounts.Database, - device *authtypes.Device, - membership, roomID string, + device *userapi.Device, + membership, roomID string, isDirect bool, cfg *config.Dendrite, evTime time.Time, - queryAPI roomserverAPI.RoomserverQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI, + rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) (*gomatrixserverlib.Event, error) { stateKey, reason, err := getMembershipStateKey(body, device, membership) if err != nil { @@ -164,13 +182,14 @@ func buildMembershipEvent( DisplayName: profile.DisplayName, AvatarURL: profile.AvatarURL, Reason: reason, + IsDirect: isDirect, } if err = builder.SetContent(content); err != nil { return nil, err } - return common.BuildEvent(ctx, &builder, cfg, evTime, queryAPI, nil) + return eventutil.BuildEvent(ctx, &builder, cfg, evTime, rsAPI, nil) } // loadProfile lookups the profile of a given user from the database and returns @@ -205,7 +224,7 @@ func loadProfile( // In the latter case, if there was an issue retrieving the user ID from the request body, // returns a JSONResponse with a corresponding error code and message. func getMembershipStateKey( - body threepid.MembershipRequest, device *authtypes.Device, membership string, + body threepid.MembershipRequest, device *userapi.Device, membership string, ) (stateKey string, reason string, err error) { if membership == gomatrixserverlib.Ban || membership == "unban" || membership == "kick" || membership == gomatrixserverlib.Invite { // If we're in this case, the state key is contained in the request body, @@ -227,18 +246,17 @@ func getMembershipStateKey( func checkAndProcessThreepid( req *http.Request, - device *authtypes.Device, + device *userapi.Device, body *threepid.MembershipRequest, cfg *config.Dendrite, - queryAPI roomserverAPI.RoomserverQueryAPI, + rsAPI roomserverAPI.RoomserverInternalAPI, accountDB accounts.Database, - producer *producers.RoomserverProducer, membership, roomID string, evTime time.Time, ) (inviteStored bool, errRes *util.JSONResponse) { inviteStored, err := threepid.CheckAndProcessInvite( - req.Context(), device, body, cfg, queryAPI, accountDB, producer, + req.Context(), device, body, cfg, rsAPI, accountDB, membership, roomID, evTime, ) if err == threepid.ErrMissingParameter { @@ -251,12 +269,18 @@ func checkAndProcessThreepid( Code: http.StatusBadRequest, JSON: jsonerror.NotTrusted(body.IDServer), } - } else if err == common.ErrRoomNoExists { + } else if err == eventutil.ErrRoomNoExists { return inviteStored, &util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound(err.Error()), } - } else if err != nil { + } else if e, ok := err.(gomatrixserverlib.BadJSONError); ok { + return inviteStored, &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(e.Error()), + } + } + if err != nil { util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAndProcessInvite failed") er := jsonerror.InternalServerError() return inviteStored, &er diff --git a/clientapi/routing/memberships.go b/clientapi/routing/memberships.go index 0b846e5e3..1c9800b66 100644 --- a/clientapi/routing/memberships.go +++ b/clientapi/routing/memberships.go @@ -15,14 +15,15 @@ package routing import ( + "encoding/json" "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -35,11 +36,21 @@ type getJoinedRoomsResponse struct { JoinedRooms []string `json:"joined_rooms"` } +// https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-rooms-roomid-joined-members +type getJoinedMembersResponse struct { + Joined map[string]joinedMember `json:"joined"` +} + +type joinedMember struct { + DisplayName string `json:"display_name"` + AvatarURL string `json:"avatar_url"` +} + // GetMemberships implements GET /rooms/{roomId}/members func GetMemberships( - req *http.Request, device *authtypes.Device, roomID string, joinedOnly bool, + req *http.Request, device *userapi.Device, roomID string, joinedOnly bool, _ *config.Dendrite, - queryAPI api.RoomserverQueryAPI, + rsAPI api.RoomserverInternalAPI, ) util.JSONResponse { queryReq := api.QueryMembershipsForRoomRequest{ JoinedOnly: joinedOnly, @@ -47,8 +58,8 @@ func GetMemberships( Sender: device.UserID, } var queryRes api.QueryMembershipsForRoomResponse - if err := queryAPI.QueryMembershipsForRoom(req.Context(), &queryReq, &queryRes); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("queryAPI.QueryMembershipsForRoom failed") + if err := rsAPI.QueryMembershipsForRoom(req.Context(), &queryReq, &queryRes); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed") return jsonerror.InternalServerError() } @@ -59,6 +70,22 @@ func GetMemberships( } } + if joinedOnly { + var res getJoinedMembersResponse + res.Joined = make(map[string]joinedMember) + for _, ev := range queryRes.JoinEvents { + var content joinedMember + if err := json.Unmarshal(ev.Content, &content); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("failed to unmarshal event content") + return jsonerror.InternalServerError() + } + res.Joined[ev.Sender] = content + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: res, + } + } return util.JSONResponse{ Code: http.StatusOK, JSON: getMembershipResponse{queryRes.JoinEvents}, @@ -67,7 +94,7 @@ func GetMemberships( func GetJoinedRooms( req *http.Request, - device *authtypes.Device, + device *userapi.Device, accountsDB accounts.Database, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index a51c55ea5..7c2cd19bc 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -21,13 +21,13 @@ import ( appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/clientapi/producers" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrix" @@ -43,7 +43,7 @@ func GetProfile( ) util.JSONResponse { profile, err := getProfile(req.Context(), accountDB, cfg, userID, asAPI, federation) if err != nil { - if err == common.ErrProfileNoExists { + if err == eventutil.ErrProfileNoExists { return util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound("The user does not exist or does not have a profile"), @@ -56,7 +56,7 @@ func GetProfile( return util.JSONResponse{ Code: http.StatusOK, - JSON: common.ProfileResponse{ + JSON: eventutil.ProfileResponse{ AvatarURL: profile.AvatarURL, DisplayName: profile.DisplayName, }, @@ -71,7 +71,7 @@ func GetAvatarURL( ) util.JSONResponse { profile, err := getProfile(req.Context(), accountDB, cfg, userID, asAPI, federation) if err != nil { - if err == common.ErrProfileNoExists { + if err == eventutil.ErrProfileNoExists { return util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound("The user does not exist or does not have a profile"), @@ -84,17 +84,17 @@ func GetAvatarURL( return util.JSONResponse{ Code: http.StatusOK, - JSON: common.AvatarURL{ + JSON: eventutil.AvatarURL{ AvatarURL: profile.AvatarURL, }, } } // SetAvatarURL implements PUT /profile/{userID}/avatar_url +// nolint:gocyclo func SetAvatarURL( - req *http.Request, accountDB accounts.Database, device *authtypes.Device, - userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite, - rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI, + req *http.Request, accountDB accounts.Database, device *userapi.Device, + userID string, cfg *config.Dendrite, rsAPI api.RoomserverInternalAPI, ) util.JSONResponse { if userID != device.UserID { return util.JSONResponse{ @@ -103,9 +103,7 @@ func SetAvatarURL( } } - changedKey := "avatar_url" - - var r common.AvatarURL + var r eventutil.AvatarURL if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil { return *resErr } @@ -154,20 +152,22 @@ func SetAvatarURL( } events, err := buildMembershipEvents( - req.Context(), memberships, newProfile, userID, cfg, evTime, queryAPI, + req.Context(), memberships, newProfile, userID, cfg, evTime, rsAPI, ) - if err != nil { + switch e := err.(type) { + case nil: + case gomatrixserverlib.BadJSONError: + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(e.Error()), + } + default: util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed") return jsonerror.InternalServerError() } - if _, err := rsProducer.SendEvents(req.Context(), events, cfg.Matrix.ServerName, nil); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("rsProducer.SendEvents failed") - return jsonerror.InternalServerError() - } - - if err := producer.SendUpdate(userID, changedKey, oldProfile.AvatarURL, r.AvatarURL); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("producer.SendUpdate failed") + if _, err := api.SendEvents(req.Context(), rsAPI, events, cfg.Matrix.ServerName, nil); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } @@ -185,7 +185,7 @@ func GetDisplayName( ) util.JSONResponse { profile, err := getProfile(req.Context(), accountDB, cfg, userID, asAPI, federation) if err != nil { - if err == common.ErrProfileNoExists { + if err == eventutil.ErrProfileNoExists { return util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound("The user does not exist or does not have a profile"), @@ -198,17 +198,17 @@ func GetDisplayName( return util.JSONResponse{ Code: http.StatusOK, - JSON: common.DisplayName{ + JSON: eventutil.DisplayName{ DisplayName: profile.DisplayName, }, } } // SetDisplayName implements PUT /profile/{userID}/displayname +// nolint:gocyclo func SetDisplayName( - req *http.Request, accountDB accounts.Database, device *authtypes.Device, - userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite, - rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI, + req *http.Request, accountDB accounts.Database, device *userapi.Device, + userID string, cfg *config.Dendrite, rsAPI api.RoomserverInternalAPI, ) util.JSONResponse { if userID != device.UserID { return util.JSONResponse{ @@ -217,9 +217,7 @@ func SetDisplayName( } } - changedKey := "displayname" - - var r common.DisplayName + var r eventutil.DisplayName if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil { return *resErr } @@ -268,20 +266,22 @@ func SetDisplayName( } events, err := buildMembershipEvents( - req.Context(), memberships, newProfile, userID, cfg, evTime, queryAPI, + req.Context(), memberships, newProfile, userID, cfg, evTime, rsAPI, ) - if err != nil { + switch e := err.(type) { + case nil: + case gomatrixserverlib.BadJSONError: + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(e.Error()), + } + default: util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed") return jsonerror.InternalServerError() } - if _, err := rsProducer.SendEvents(req.Context(), events, cfg.Matrix.ServerName, nil); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("rsProducer.SendEvents failed") - return jsonerror.InternalServerError() - } - - if err := producer.SendUpdate(userID, changedKey, oldProfile.DisplayName, r.DisplayName); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("producer.SendUpdate failed") + if _, err := api.SendEvents(req.Context(), rsAPI, events, cfg.Matrix.ServerName, nil); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } @@ -294,7 +294,7 @@ func SetDisplayName( // getProfile gets the full profile of a user by querying the database or a // remote homeserver. // Returns an error when something goes wrong or specifically -// common.ErrProfileNoExists when the profile doesn't exist. +// eventutil.ErrProfileNoExists when the profile doesn't exist. func getProfile( ctx context.Context, accountDB accounts.Database, cfg *config.Dendrite, userID string, @@ -311,7 +311,7 @@ func getProfile( if fedErr != nil { if x, ok := fedErr.(gomatrix.HTTPError); ok { if x.Code == http.StatusNotFound { - return nil, common.ErrProfileNoExists + return nil, eventutil.ErrProfileNoExists } } @@ -337,14 +337,14 @@ func buildMembershipEvents( ctx context.Context, memberships []authtypes.Membership, newProfile authtypes.Profile, userID string, cfg *config.Dendrite, - evTime time.Time, queryAPI api.RoomserverQueryAPI, + evTime time.Time, rsAPI api.RoomserverInternalAPI, ) ([]gomatrixserverlib.HeaderedEvent, error) { evs := []gomatrixserverlib.HeaderedEvent{} for _, membership := range memberships { verReq := api.QueryRoomVersionForRoomRequest{RoomID: membership.RoomID} verRes := api.QueryRoomVersionForRoomResponse{} - if err := queryAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { + if err := rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { return []gomatrixserverlib.HeaderedEvent{}, err } @@ -366,7 +366,7 @@ func buildMembershipEvents( return nil, err } - event, err := common.BuildEvent(ctx, &builder, cfg, evTime, queryAPI, nil) + event, err := eventutil.BuildEvent(ctx, &builder, cfg, evTime, rsAPI, nil) if err != nil { return nil, err } diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index b67e68e19..69ebdfd70 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -32,16 +32,16 @@ import ( "sync" "time" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" - "github.com/matrix-org/dendrite/common" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/tokens" "github.com/matrix-org/util" @@ -136,7 +136,7 @@ type registerRequest struct { DeviceID *string `json:"device_id"` // Prevent this user from logging in - InhibitLogin common.WeakBoolean `json:"inhibit_login"` + InhibitLogin eventutil.WeakBoolean `json:"inhibit_login"` // Application Services place Type in the root of their registration // request, whereas clients place it in the authDict struct. @@ -440,8 +440,8 @@ func validateApplicationService( // http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register func Register( req *http.Request, + userAPI userapi.UserInternalAPI, accountDB accounts.Database, - deviceDB devices.Database, cfg *config.Dendrite, ) util.JSONResponse { var r registerRequest @@ -450,7 +450,7 @@ func Register( return *resErr } if req.URL.Query().Get("kind") == "guest" { - return handleGuestRegistration(req, r, cfg, accountDB, deviceDB) + return handleGuestRegistration(req, r, cfg, userAPI) } // Retrieve or generate the sessionID @@ -506,17 +506,19 @@ func Register( "session_id": r.Auth.Session, }).Info("Processing registration request") - return handleRegistrationFlow(req, r, sessionID, cfg, accountDB, deviceDB) + return handleRegistrationFlow(req, r, sessionID, cfg, userAPI) } func handleGuestRegistration( req *http.Request, r registerRequest, cfg *config.Dendrite, - accountDB accounts.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, ) util.JSONResponse { - acc, err := accountDB.CreateGuestAccount(req.Context()) + var res userapi.PerformAccountCreationResponse + err := userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{ + AccountType: userapi.AccountTypeGuest, + }, &res) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -525,8 +527,8 @@ func handleGuestRegistration( } token, err := tokens.GenerateLoginToken(tokens.TokenOptions{ ServerPrivateKey: cfg.Matrix.PrivateKey.Seed(), - ServerName: string(acc.ServerName), - UserID: acc.UserID, + ServerName: string(res.Account.ServerName), + UserID: res.Account.UserID, }) if err != nil { @@ -536,7 +538,12 @@ func handleGuestRegistration( } } //we don't allow guests to specify their own device_id - dev, err := deviceDB.CreateDevice(req.Context(), acc.Localpart, nil, token, r.InitialDisplayName) + var devRes userapi.PerformDeviceCreationResponse + err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{ + Localpart: res.Account.Localpart, + DeviceDisplayName: r.InitialDisplayName, + AccessToken: token, + }, &devRes) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -546,10 +553,10 @@ func handleGuestRegistration( return util.JSONResponse{ Code: http.StatusOK, JSON: registerResponse{ - UserID: dev.UserID, - AccessToken: dev.AccessToken, - HomeServer: acc.ServerName, - DeviceID: dev.ID, + UserID: devRes.Device.UserID, + AccessToken: devRes.Device.AccessToken, + HomeServer: res.Account.ServerName, + DeviceID: devRes.Device.ID, }, } } @@ -562,8 +569,7 @@ func handleRegistrationFlow( r registerRequest, sessionID string, cfg *config.Dendrite, - accountDB accounts.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, ) util.JSONResponse { // TODO: Shared secret registration (create new user scripts) // TODO: Enable registration config flag @@ -614,7 +620,7 @@ func handleRegistrationFlow( // by whether the request contains an access token. if err == nil { return handleApplicationServiceRegistration( - accessToken, err, req, r, cfg, accountDB, deviceDB, + accessToken, err, req, r, cfg, userAPI, ) } @@ -625,7 +631,7 @@ func handleRegistrationFlow( // don't need a condition on that call since the registration is clearly // stated as being AS-related. return handleApplicationServiceRegistration( - accessToken, err, req, r, cfg, accountDB, deviceDB, + accessToken, err, req, r, cfg, userAPI, ) case authtypes.LoginTypeDummy: @@ -644,7 +650,7 @@ func handleRegistrationFlow( // A response with current registration flow and remaining available methods // will be returned if a flow has not been successfully completed yet return checkAndCompleteFlow(sessions.GetCompletedStages(sessionID), - req, r, sessionID, cfg, accountDB, deviceDB) + req, r, sessionID, cfg, userAPI) } // handleApplicationServiceRegistration handles the registration of an @@ -661,8 +667,7 @@ func handleApplicationServiceRegistration( req *http.Request, r registerRequest, cfg *config.Dendrite, - accountDB accounts.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, ) util.JSONResponse { // Check if we previously had issues extracting the access token from the // request. @@ -686,7 +691,7 @@ func handleApplicationServiceRegistration( // Don't need to worry about appending to registration stages as // application service registration is entirely separate. return completeRegistration( - req.Context(), accountDB, deviceDB, r.Username, "", appserviceID, + req.Context(), userAPI, r.Username, "", appserviceID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, ) } @@ -700,13 +705,12 @@ func checkAndCompleteFlow( r registerRequest, sessionID string, cfg *config.Dendrite, - accountDB accounts.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, ) util.JSONResponse { if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { // This flow was completed, registration can continue return completeRegistration( - req.Context(), accountDB, deviceDB, r.Username, r.Password, "", + req.Context(), userAPI, r.Username, r.Password, "", r.InhibitLogin, r.InitialDisplayName, r.DeviceID, ) } @@ -723,8 +727,7 @@ func checkAndCompleteFlow( // LegacyRegister process register requests from the legacy v1 API func LegacyRegister( req *http.Request, - accountDB accounts.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, cfg *config.Dendrite, ) util.JSONResponse { var r legacyRegisterRequest @@ -759,10 +762,10 @@ func LegacyRegister( return util.MessageResponse(http.StatusForbidden, "HMAC incorrect") } - return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil, nil) + return completeRegistration(req.Context(), userAPI, r.Username, r.Password, "", false, nil, nil) case authtypes.LoginTypeDummy: // there is nothing to do - return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil, nil) + return completeRegistration(req.Context(), userAPI, r.Username, r.Password, "", false, nil, nil) default: return util.JSONResponse{ Code: http.StatusNotImplemented, @@ -808,10 +811,9 @@ func parseAndValidateLegacyLogin(req *http.Request, r *legacyRegisterRequest) *u // not all func completeRegistration( ctx context.Context, - accountDB accounts.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, username, password, appserviceID string, - inhibitLogin common.WeakBoolean, + inhibitLogin eventutil.WeakBoolean, displayName, deviceID *string, ) util.JSONResponse { if username == "" { @@ -828,17 +830,25 @@ func completeRegistration( } } - acc, err := accountDB.CreateAccount(ctx, username, password, appserviceID) + var accRes userapi.PerformAccountCreationResponse + err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ + AppServiceID: appserviceID, + Localpart: username, + Password: password, + AccountType: userapi.AccountTypeUser, + OnConflict: userapi.ConflictAbort, + }, &accRes) if err != nil { + if _, ok := err.(*userapi.ErrorConflict); ok { // user already exists + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.UserInUse("Desired user ID is already taken."), + } + } return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: jsonerror.Unknown("failed to create account: " + err.Error()), } - } else if acc == nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.UserInUse("Desired user ID is already taken."), - } } // Increment prometheus counter for created users @@ -850,8 +860,8 @@ func completeRegistration( return util.JSONResponse{ Code: http.StatusOK, JSON: registerResponse{ - UserID: userutil.MakeUserID(username, acc.ServerName), - HomeServer: acc.ServerName, + UserID: userutil.MakeUserID(username, accRes.Account.ServerName), + HomeServer: accRes.Account.ServerName, }, } } @@ -864,7 +874,13 @@ func completeRegistration( } } - dev, err := deviceDB.CreateDevice(ctx, username, deviceID, token, displayName) + var devRes userapi.PerformDeviceCreationResponse + err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ + Localpart: username, + AccessToken: token, + DeviceDisplayName: displayName, + DeviceID: deviceID, + }, &devRes) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -875,10 +891,10 @@ func completeRegistration( return util.JSONResponse{ Code: http.StatusOK, JSON: registerResponse{ - UserID: dev.UserID, - AccessToken: dev.AccessToken, - HomeServer: acc.ServerName, - DeviceID: dev.ID, + UserID: devRes.Device.UserID, + AccessToken: devRes.Device.AccessToken, + HomeServer: accRes.Account.ServerName, + DeviceID: devRes.Device.ID, }, } } diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index 6fcf0bc39..a44389f94 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -19,7 +19,7 @@ import ( "testing" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" ) var ( diff --git a/clientapi/routing/room_tagging.go b/clientapi/routing/room_tagging.go index 5c68668d0..c683cc949 100644 --- a/clientapi/routing/room_tagging.go +++ b/clientapi/routing/room_tagging.go @@ -20,28 +20,19 @@ import ( "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrix" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) -// newTag creates and returns a new gomatrix.TagContent -func newTag() gomatrix.TagContent { - return gomatrix.TagContent{ - Tags: make(map[string]gomatrix.TagProperties), - } -} - // GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags func GetTags( req *http.Request, - accountDB accounts.Database, - device *authtypes.Device, + userAPI api.UserInternalAPI, + device *api.Device, userID string, roomID string, syncProducer *producers.SyncAPIProducer, @@ -54,22 +45,15 @@ func GetTags( } } - _, data, err := obtainSavedTags(req, userID, roomID, accountDB) + tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") return jsonerror.InternalServerError() } - if data == nil { - return util.JSONResponse{ - Code: http.StatusOK, - JSON: struct{}{}, - } - } - return util.JSONResponse{ Code: http.StatusOK, - JSON: data.Content, + JSON: tagContent, } } @@ -78,8 +62,8 @@ func GetTags( // the tag to the "map" and saving the new "map" to the DB func PutTag( req *http.Request, - accountDB accounts.Database, - device *authtypes.Device, + userAPI api.UserInternalAPI, + device *api.Device, userID string, roomID string, tag string, @@ -98,34 +82,25 @@ func PutTag( return *reqErr } - localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB) + tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") return jsonerror.InternalServerError() } - var tagContent gomatrix.TagContent - if data != nil { - if err = json.Unmarshal(data.Content, &tagContent); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed") - return jsonerror.InternalServerError() - } - } else { - tagContent = newTag() + if tagContent.Tags == nil { + tagContent.Tags = make(map[string]gomatrix.TagProperties) } tagContent.Tags[tag] = properties - if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil { + + if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil { util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") return jsonerror.InternalServerError() } - // Send data to syncProducer in order to inform clients of changes - // Run in a goroutine in order to prevent blocking the tag request response - go func() { - if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { - logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") - } - }() + if err = syncProducer.SendData(userID, roomID, "m.tag"); err != nil { + logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") + } return util.JSONResponse{ Code: http.StatusOK, @@ -138,8 +113,8 @@ func PutTag( // the "map" and then saving the new "map" in the DB func DeleteTag( req *http.Request, - accountDB accounts.Database, - device *authtypes.Device, + userAPI api.UserInternalAPI, + device *api.Device, userID string, roomID string, tag string, @@ -153,28 +128,12 @@ func DeleteTag( } } - localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB) + tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") return jsonerror.InternalServerError() } - // If there are no tags in the database, exit - if data == nil { - // Spec only defines 200 responses for this endpoint so we don't return anything else. - return util.JSONResponse{ - Code: http.StatusOK, - JSON: struct{}{}, - } - } - - var tagContent gomatrix.TagContent - err = json.Unmarshal(data.Content, &tagContent) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed") - return jsonerror.InternalServerError() - } - // Check whether the tag to be deleted exists if _, ok := tagContent.Tags[tag]; ok { delete(tagContent.Tags, tag) @@ -185,18 +144,16 @@ func DeleteTag( JSON: struct{}{}, } } - if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil { + + if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil { util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") return jsonerror.InternalServerError() } - // Send data to syncProducer in order to inform clients of changes - // Run in a goroutine in order to prevent blocking the tag request response - go func() { - if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { - logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") - } - }() + // TODO: user API should do this since it's account data + if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { + logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") + } return util.JSONResponse{ Code: http.StatusOK, @@ -210,32 +167,46 @@ func obtainSavedTags( req *http.Request, userID string, roomID string, - accountDB accounts.Database, -) (string, *gomatrixserverlib.ClientEvent, error) { - localpart, _, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - return "", nil, err + userAPI api.UserInternalAPI, +) (tags gomatrix.TagContent, err error) { + dataReq := api.QueryAccountDataRequest{ + UserID: userID, + RoomID: roomID, + DataType: "m.tag", } - - data, err := accountDB.GetAccountDataByType( - req.Context(), localpart, roomID, "m.tag", - ) - - return localpart, data, err + dataRes := api.QueryAccountDataResponse{} + err = userAPI.QueryAccountData(req.Context(), &dataReq, &dataRes) + if err != nil { + return + } + data, ok := dataRes.RoomAccountData[roomID]["m.tag"] + if !ok { + return + } + if err = json.Unmarshal(data, &tags); err != nil { + return + } + return tags, nil } // saveTagData saves the provided tag data into the database func saveTagData( req *http.Request, - localpart string, + userID string, roomID string, - accountDB accounts.Database, + userAPI api.UserInternalAPI, Tag gomatrix.TagContent, ) error { newTagData, err := json.Marshal(Tag) if err != nil { return err } - - return accountDB.SaveAccountData(req.Context(), localpart, roomID, "m.tag", string(newTagData)) + dataReq := api.InputAccountDataRequest{ + UserID: userID, + RoomID: roomID, + DataType: "m.tag", + AccountData: json.RawMessage(newTagData), + } + dataRes := api.InputAccountDataResponse{} + return userAPI.InputAccountData(req.Context(), &dataReq, &dataRes) } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 5dc6d7db9..825ac50f2 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -21,24 +21,24 @@ import ( "github.com/gorilla/mux" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" - "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" - "github.com/matrix-org/dendrite/common/transactions" + eduServerAPI "github.com/matrix-org/dendrite/eduserver/api" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/transactions" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) -const pathPrefixV1 = "/_matrix/client/api/v1" -const pathPrefixR0 = "/_matrix/client/r0" -const pathPrefixUnstable = "/_matrix/client/unstable" +const pathPrefixV1 = "/client/api/v1" +const pathPrefixR0 = "/client/r0" +const pathPrefixUnstable = "/client/unstable" // Setup registers HTTP handlers with the given ServeMux. It also supplies the given http.Client // to clients which need to make outbound HTTP requests. @@ -47,24 +47,21 @@ const pathPrefixUnstable = "/_matrix/client/unstable" // applied: // nolint: gocyclo func Setup( - apiMux *mux.Router, cfg *config.Dendrite, - producer *producers.RoomserverProducer, - queryAPI roomserverAPI.RoomserverQueryAPI, - aliasAPI roomserverAPI.RoomserverAliasAPI, + publicAPIMux *mux.Router, cfg *config.Dendrite, + eduAPI eduServerAPI.EDUServerInputAPI, + rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, accountDB accounts.Database, deviceDB devices.Database, + userAPI api.UserInternalAPI, federation *gomatrixserverlib.FederationClient, - keyRing gomatrixserverlib.KeyRing, - userUpdateProducer *producers.UserUpdateProducer, syncProducer *producers.SyncAPIProducer, - eduProducer *producers.EDUServerProducer, transactionsCache *transactions.Cache, - federationSender federationSenderAPI.FederationSenderQueryAPI, + federationSender federationSenderAPI.FederationSenderInternalAPI, ) { - apiMux.Handle("/_matrix/client/versions", - common.MakeExternalAPI("versions", func(req *http.Request) util.JSONResponse { + publicAPIMux.Handle("/client/versions", + httputil.MakeExternalAPI("versions", func(req *http.Request) util.JSONResponse { return util.JSONResponse{ Code: http.StatusOK, JSON: struct { @@ -79,104 +76,115 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() - v1mux := apiMux.PathPrefix(pathPrefixV1).Subrouter() - unstableMux := apiMux.PathPrefix(pathPrefixUnstable).Subrouter() - - authData := auth.Data{ - AccountDB: accountDB, - DeviceDB: deviceDB, - AppServices: cfg.Derived.ApplicationServices, - } + r0mux := publicAPIMux.PathPrefix(pathPrefixR0).Subrouter() + v1mux := publicAPIMux.PathPrefix(pathPrefixV1).Subrouter() + unstableMux := publicAPIMux.PathPrefix(pathPrefixUnstable).Subrouter() r0mux.Handle("/createRoom", - common.MakeAuthAPI("createRoom", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - return CreateRoom(req, device, cfg, producer, accountDB, aliasAPI, asAPI) + httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + return CreateRoom(req, device, cfg, accountDB, rsAPI, asAPI) }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/join/{roomIDOrAlias}", - common.MakeAuthAPI(gomatrixserverlib.Join, authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } return JoinRoomByIDOrAlias( - req, device, vars["roomIDOrAlias"], cfg, federation, producer, queryAPI, aliasAPI, keyRing, accountDB, + req, device, rsAPI, accountDB, vars["roomIDOrAlias"], ) }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/joined_rooms", - common.MakeAuthAPI("joined_rooms", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { return GetJoinedRooms(req, device, accountDB) }), ).Methods(http.MethodGet, http.MethodOptions) - - r0mux.Handle("/rooms/{roomID}/{membership:(?:join|kick|ban|unban|leave|invite)}", - common.MakeAuthAPI("membership", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + r0mux.Handle("/rooms/{roomID}/leave", + httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return SendMembership(req, accountDB, device, vars["roomID"], vars["membership"], cfg, queryAPI, asAPI, producer) + return LeaveRoomByID( + req, device, rsAPI, vars["roomID"], + ) + }), + ).Methods(http.MethodPost, http.MethodOptions) + r0mux.Handle("/rooms/{roomID}/{membership:(?:join|kick|ban|unban|invite)}", + httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return SendMembership(req, accountDB, device, vars["roomID"], vars["membership"], cfg, rsAPI, asAPI) }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/send/{eventType}", - common.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, queryAPI, producer, nil) + return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil) }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}", - common.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } txnID := vars["txnID"] return SendEvent(req, device, vars["roomID"], vars["eventType"], &txnID, - nil, cfg, queryAPI, producer, transactionsCache) + nil, cfg, rsAPI, transactionsCache) }), ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/event/{eventID}", - common.MakeAuthAPI("rooms_get_event", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("rooms_get_event", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return GetEvent(req, device, vars["roomID"], vars["eventID"], cfg, queryAPI, federation, keyRing) + return GetEvent(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, federation) }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/state", common.MakeAuthAPI("room_state", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + r0mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return OnIncomingStateRequest(req.Context(), queryAPI, vars["roomID"]) + return OnIncomingStateRequest(req.Context(), rsAPI, vars["roomID"]) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/state/{type}", common.MakeAuthAPI("room_state", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + r0mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return OnIncomingStateTypeRequest(req.Context(), queryAPI, vars["roomID"], vars["type"], "") + // If there's a trailing slash, remove it + eventType := vars["type"] + if strings.HasSuffix(eventType, "/") { + eventType = eventType[:len(eventType)-1] + } + eventFormat := req.URL.Query().Get("format") == "event" + return OnIncomingStateTypeRequest(req.Context(), rsAPI, vars["roomID"], eventType, "", eventFormat) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", common.MakeAuthAPI("room_state", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + r0mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return OnIncomingStateTypeRequest(req.Context(), queryAPI, vars["roomID"], vars["type"], vars["stateKey"]) + eventFormat := req.URL.Query().Get("format") == "event" + return OnIncomingStateTypeRequest(req.Context(), rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat) })).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}", - common.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -186,87 +194,112 @@ func Setup( if strings.HasSuffix(eventType, "/") { eventType = eventType[:len(eventType)-1] } - return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, queryAPI, producer, nil) + return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil) }), ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}", - common.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } stateKey := vars["stateKey"] - return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, queryAPI, producer, nil) + return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil) }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { - return Register(req, accountDB, deviceDB, cfg) + r0mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { + return Register(req, userAPI, accountDB, cfg) })).Methods(http.MethodPost, http.MethodOptions) - v1mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { - return LegacyRegister(req, accountDB, deviceDB, cfg) + v1mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { + return LegacyRegister(req, userAPI, cfg) })).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/register/available", common.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { + r0mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { return RegisterAvailable(req, cfg, accountDB) })).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/directory/room/{roomAlias}", - common.MakeExternalAPI("directory_room", func(req *http.Request) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeExternalAPI("directory_room", func(req *http.Request) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return DirectoryRoom(req, vars["roomAlias"], federation, cfg, aliasAPI, federationSender) + return DirectoryRoom(req, vars["roomAlias"], federation, cfg, rsAPI, federationSender) }), ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/directory/room/{roomAlias}", - common.MakeAuthAPI("directory_room", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return SetLocalAlias(req, device, vars["roomAlias"], cfg, aliasAPI) + return SetLocalAlias(req, device, vars["roomAlias"], cfg, rsAPI) }), ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/directory/room/{roomAlias}", - common.MakeAuthAPI("directory_room", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return RemoveLocalAlias(req, device, vars["roomAlias"], aliasAPI) + return RemoveLocalAlias(req, device, vars["roomAlias"], rsAPI) }), ).Methods(http.MethodDelete, http.MethodOptions) r0mux.Handle("/logout", - common.MakeAuthAPI("logout", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { return Logout(req, deviceDB, device) }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/logout/all", - common.MakeAuthAPI("logout", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { return LogoutAll(req, deviceDB, device) }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/typing/{userID}", - common.MakeAuthAPI("rooms_typing", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return SendTyping(req, device, vars["roomID"], vars["userID"], accountDB, eduProducer) + return SendTyping(req, device, vars["roomID"], vars["userID"], accountDB, eduAPI) + }), + ).Methods(http.MethodPut, http.MethodOptions) + + r0mux.Handle("/sendToDevice/{eventType}/{txnID}", + httputil.MakeAuthAPI("send_to_device", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + txnID := vars["txnID"] + return SendToDevice(req, device, eduAPI, transactionsCache, vars["eventType"], &txnID) + }), + ).Methods(http.MethodPut, http.MethodOptions) + + // This is only here because sytest refers to /unstable for this endpoint + // rather than r0. It's an exact duplicate of the above handler. + // TODO: Remove this if/when sytest is fixed! + unstableMux.Handle("/sendToDevice/{eventType}/{txnID}", + httputil.MakeAuthAPI("send_to_device", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + txnID := vars["txnID"] + return SendToDevice(req, device, eduAPI, transactionsCache, vars["eventType"], &txnID) }), ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/account/whoami", - common.MakeAuthAPI("whoami", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + httputil.MakeAuthAPI("whoami", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { return Whoami(req, device) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -274,20 +307,20 @@ func Setup( // Stub endpoints required by Riot r0mux.Handle("/login", - common.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { + httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { return Login(req, accountDB, deviceDB, cfg) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) r0mux.Handle("/auth/{authType}/fallback/web", - common.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { + httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { vars := mux.Vars(req) return AuthFallback(w, req, vars["authType"], cfg) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) r0mux.Handle("/pushrules/", - common.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse { + httputil.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse { // TODO: Implement push rules API res := json.RawMessage(`{ "global": { @@ -306,8 +339,8 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/user/{userId}/filter", - common.MakeAuthAPI("put_filter", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("put_filter", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -316,8 +349,8 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/user/{userId}/filter/{filterId}", - common.MakeAuthAPI("get_filter", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("get_filter", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -328,8 +361,8 @@ func Setup( // Riot user settings r0mux.Handle("/profile/{userID}", - common.MakeExternalAPI("profile", func(req *http.Request) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeExternalAPI("profile", func(req *http.Request) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -338,8 +371,8 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/profile/{userID}/avatar_url", - common.MakeExternalAPI("profile_avatar_url", func(req *http.Request) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeExternalAPI("profile_avatar_url", func(req *http.Request) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -348,20 +381,20 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/profile/{userID}/avatar_url", - common.MakeAuthAPI("profile_avatar_url", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("profile_avatar_url", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return SetAvatarURL(req, accountDB, device, vars["userID"], userUpdateProducer, cfg, producer, queryAPI) + return SetAvatarURL(req, accountDB, device, vars["userID"], cfg, rsAPI) }), ).Methods(http.MethodPut, http.MethodOptions) // Browsers use the OPTIONS HTTP method to check if the CORS policy allows // PUT requests, so we need to allow this method r0mux.Handle("/profile/{userID}/displayname", - common.MakeExternalAPI("profile_displayname", func(req *http.Request) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeExternalAPI("profile_displayname", func(req *http.Request) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -370,44 +403,44 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/profile/{userID}/displayname", - common.MakeAuthAPI("profile_displayname", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("profile_displayname", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return SetDisplayName(req, accountDB, device, vars["userID"], userUpdateProducer, cfg, producer, queryAPI) + return SetDisplayName(req, accountDB, device, vars["userID"], cfg, rsAPI) }), ).Methods(http.MethodPut, http.MethodOptions) // Browsers use the OPTIONS HTTP method to check if the CORS policy allows // PUT requests, so we need to allow this method r0mux.Handle("/account/3pid", - common.MakeAuthAPI("account_3pid", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { return GetAssociated3PIDs(req, accountDB, device) }), ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/account/3pid", - common.MakeAuthAPI("account_3pid", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { return CheckAndSave3PIDAssociation(req, accountDB, device, cfg) }), ).Methods(http.MethodPost, http.MethodOptions) unstableMux.Handle("/account/3pid/delete", - common.MakeAuthAPI("account_3pid", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { return Forget3PID(req, accountDB) }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken", - common.MakeExternalAPI("account_3pid_request_token", func(req *http.Request) util.JSONResponse { + httputil.MakeExternalAPI("account_3pid_request_token", func(req *http.Request) util.JSONResponse { return RequestEmailToken(req, accountDB, cfg) }), ).Methods(http.MethodPost, http.MethodOptions) // Riot logs get flooded unless this is handled r0mux.Handle("/presence/{userID}/status", - common.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse { + httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse { // TODO: Set presence (probably the responsibility of a presence server not clientapi) return util.JSONResponse{ Code: http.StatusOK, @@ -417,13 +450,13 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/voip/turnServer", - common.MakeAuthAPI("turn_server", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + httputil.MakeAuthAPI("turn_server", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { return RequestTurnServer(req, device, cfg) }), ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/thirdparty/protocols", - common.MakeExternalAPI("thirdparty_protocols", func(req *http.Request) util.JSONResponse { + httputil.MakeExternalAPI("thirdparty_protocols", func(req *http.Request) util.JSONResponse { // TODO: Return the third party protcols return util.JSONResponse{ Code: http.StatusOK, @@ -433,7 +466,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/initialSync", - common.MakeExternalAPI("rooms_initial_sync", func(req *http.Request) util.JSONResponse { + httputil.MakeExternalAPI("rooms_initial_sync", func(req *http.Request) util.JSONResponse { // TODO: Allow people to peek into rooms. return util.JSONResponse{ Code: http.StatusForbidden, @@ -443,81 +476,81 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/user/{userID}/account_data/{type}", - common.MakeAuthAPI("user_account_data", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return SaveAccountData(req, accountDB, device, vars["userID"], "", vars["type"], syncProducer) + return SaveAccountData(req, userAPI, device, vars["userID"], "", vars["type"], syncProducer) }), ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", - common.MakeAuthAPI("user_account_data", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return SaveAccountData(req, accountDB, device, vars["userID"], vars["roomID"], vars["type"], syncProducer) + return SaveAccountData(req, userAPI, device, vars["userID"], vars["roomID"], vars["type"], syncProducer) }), ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/user/{userID}/account_data/{type}", - common.MakeAuthAPI("user_account_data", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return GetAccountData(req, accountDB, device, vars["userID"], "", vars["type"]) + return GetAccountData(req, userAPI, device, vars["userID"], "", vars["type"]) }), ).Methods(http.MethodGet) r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", - common.MakeAuthAPI("user_account_data", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return GetAccountData(req, accountDB, device, vars["userID"], vars["roomID"], vars["type"]) + return GetAccountData(req, userAPI, device, vars["userID"], vars["roomID"], vars["type"]) }), ).Methods(http.MethodGet) r0mux.Handle("/rooms/{roomID}/members", - common.MakeAuthAPI("rooms_members", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return GetMemberships(req, device, vars["roomID"], false, cfg, queryAPI) + return GetMemberships(req, device, vars["roomID"], false, cfg, rsAPI) }), ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/joined_members", - common.MakeAuthAPI("rooms_members", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return GetMemberships(req, device, vars["roomID"], true, cfg, queryAPI) + return GetMemberships(req, device, vars["roomID"], true, cfg, rsAPI) }), ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/read_markers", - common.MakeExternalAPI("rooms_read_markers", func(req *http.Request) util.JSONResponse { + httputil.MakeExternalAPI("rooms_read_markers", func(req *http.Request) util.JSONResponse { // TODO: return the read_markers. return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}} }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/devices", - common.MakeAuthAPI("get_devices", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { return GetDevicesByLocalpart(req, deviceDB, device) }), ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/devices/{deviceID}", - common.MakeAuthAPI("get_device", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("get_device", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -526,8 +559,8 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/devices/{deviceID}", - common.MakeAuthAPI("device_data", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("device_data", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -536,8 +569,8 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/devices/{deviceID}", - common.MakeAuthAPI("delete_device", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("delete_device", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -546,14 +579,14 @@ func Setup( ).Methods(http.MethodDelete, http.MethodOptions) r0mux.Handle("/delete_devices", - common.MakeAuthAPI("delete_devices", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + httputil.MakeAuthAPI("delete_devices", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { return DeleteDevices(req, deviceDB, device) }), ).Methods(http.MethodPost, http.MethodOptions) // Stub implementations for sytest r0mux.Handle("/events", - common.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse { + httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse { return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ "chunk": []interface{}{}, "start": "", @@ -563,7 +596,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/initialSync", - common.MakeExternalAPI("initial_sync", func(req *http.Request) util.JSONResponse { + httputil.MakeExternalAPI("initial_sync", func(req *http.Request) util.JSONResponse { return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ "end": "", }} @@ -571,38 +604,38 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/user/{userId}/rooms/{roomId}/tags", - common.MakeAuthAPI("get_tags", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("get_tags", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return GetTags(req, accountDB, device, vars["userId"], vars["roomId"], syncProducer) + return GetTags(req, userAPI, device, vars["userId"], vars["roomId"], syncProducer) }), ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", - common.MakeAuthAPI("put_tag", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("put_tag", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return PutTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) + return PutTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) }), ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", - common.MakeAuthAPI("delete_tag", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("delete_tag", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return DeleteTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) + return DeleteTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) }), ).Methods(http.MethodDelete, http.MethodOptions) r0mux.Handle("/capabilities", - common.MakeAuthAPI("capabilities", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - return GetCapabilities(req, queryAPI) + httputil.MakeAuthAPI("capabilities", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + return GetCapabilities(req, rsAPI) }), ).Methods(http.MethodGet) } diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 5b2cd8ad4..d8936f750 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -17,14 +17,13 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/clientapi/producers" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" - "github.com/matrix-org/dendrite/common/transactions" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/transactions" "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -42,16 +41,15 @@ type sendEventResponse struct { // /rooms/{roomID}/state/{eventType}/{stateKey} func SendEvent( req *http.Request, - device *authtypes.Device, + device *userapi.Device, roomID, eventType string, txnID, stateKey *string, cfg *config.Dendrite, - queryAPI api.RoomserverQueryAPI, - producer *producers.RoomserverProducer, + rsAPI api.RoomserverInternalAPI, txnCache *transactions.Cache, ) util.JSONResponse { verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} verRes := api.QueryRoomVersionForRoomResponse{} - if err := queryAPI.QueryRoomVersionForRoom(req.Context(), &verReq, &verRes); err != nil { + if err := rsAPI.QueryRoomVersionForRoom(req.Context(), &verReq, &verRes); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.UnsupportedRoomVersion(err.Error()), @@ -65,7 +63,7 @@ func SendEvent( } } - e, resErr := generateSendEvent(req, device, roomID, eventType, stateKey, cfg, queryAPI) + e, resErr := generateSendEvent(req, device, roomID, eventType, stateKey, cfg, rsAPI) if resErr != nil { return *resErr } @@ -80,8 +78,8 @@ func SendEvent( // pass the new event to the roomserver and receive the correct event ID // event ID in case of duplicate transaction is discarded - eventID, err := producer.SendEvents( - req.Context(), + eventID, err := api.SendEvents( + req.Context(), rsAPI, []gomatrixserverlib.HeaderedEvent{ e.Headered(verRes.RoomVersion), }, @@ -89,7 +87,7 @@ func SendEvent( txnAndSessionID, ) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("producer.SendEvents failed") + util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } util.GetLogger(req.Context()).WithFields(logrus.Fields{ @@ -112,10 +110,10 @@ func SendEvent( func generateSendEvent( req *http.Request, - device *authtypes.Device, + device *userapi.Device, roomID, eventType string, stateKey *string, cfg *config.Dendrite, - queryAPI api.RoomserverQueryAPI, + rsAPI api.RoomserverInternalAPI, ) (*gomatrixserverlib.Event, *util.JSONResponse) { // parse the incoming http request userID := device.UserID @@ -148,14 +146,19 @@ func generateSendEvent( } var queryRes api.QueryLatestEventsAndStateResponse - e, err := common.BuildEvent(req.Context(), &builder, cfg, evTime, queryAPI, &queryRes) - if err == common.ErrRoomNoExists { + e, err := eventutil.BuildEvent(req.Context(), &builder, cfg, evTime, rsAPI, &queryRes) + if err == eventutil.ErrRoomNoExists { return nil, &util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound("Room does not exist"), } + } else if e, ok := err.(gomatrixserverlib.BadJSONError); ok { + return nil, &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(e.Error()), + } } else if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("common.BuildEvent failed") + util.GetLogger(req.Context()).WithError(err).Error("eventutil.BuildEvent failed") resErr := jsonerror.InternalServerError() return nil, &resErr } diff --git a/clientapi/routing/sendtodevice.go b/clientapi/routing/sendtodevice.go new file mode 100644 index 000000000..768e8e0e7 --- /dev/null +++ b/clientapi/routing/sendtodevice.go @@ -0,0 +1,70 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "encoding/json" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal/transactions" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +// SendToDevice handles PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId} +// sends the device events to the EDU Server +func SendToDevice( + req *http.Request, device *userapi.Device, + eduAPI api.EDUServerInputAPI, + txnCache *transactions.Cache, + eventType string, txnID *string, +) util.JSONResponse { + if txnID != nil { + if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { + return *res + } + } + + var httpReq struct { + Messages map[string]map[string]json.RawMessage `json:"messages"` + } + resErr := httputil.UnmarshalJSONRequest(req, &httpReq) + if resErr != nil { + return *resErr + } + + for userID, byUser := range httpReq.Messages { + for deviceID, message := range byUser { + if err := api.SendToDevice( + req.Context(), eduAPI, device.UserID, userID, deviceID, eventType, message, + ); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("eduProducer.SendToDevice failed") + return jsonerror.InternalServerError() + } + } + } + + res := util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } + + if txnID != nil { + txnCache.AddTransaction(device.AccessToken, *txnID, &res) + } + + return res +} diff --git a/clientapi/routing/sendtyping.go b/clientapi/routing/sendtyping.go index ffaa0e662..9b6a0b39b 100644 --- a/clientapi/routing/sendtyping.go +++ b/clientapi/routing/sendtyping.go @@ -16,12 +16,12 @@ import ( "database/sql" "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/eduserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/util" ) @@ -33,9 +33,9 @@ type typingContentJSON struct { // SendTyping handles PUT /rooms/{roomID}/typing/{userID} // sends the typing events to client API typingProducer func SendTyping( - req *http.Request, device *authtypes.Device, roomID string, + req *http.Request, device *userapi.Device, roomID string, userID string, accountDB accounts.Database, - eduProducer *producers.EDUServerProducer, + eduAPI api.EDUServerInputAPI, ) util.JSONResponse { if device.UserID != userID { return util.JSONResponse{ @@ -69,8 +69,8 @@ func SendTyping( return *resErr } - if err = eduProducer.SendTyping( - req.Context(), userID, roomID, r.Typing, r.Timeout, + if err = api.SendTyping( + req.Context(), eduAPI, userID, roomID, r.Typing, r.Timeout, ); err != nil { util.GetLogger(req.Context()).WithError(err).Error("eduProducer.Send failed") return jsonerror.InternalServerError() diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index c243eec0f..2ec7a33f3 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -40,7 +40,7 @@ type stateEventInStateResp struct { // TODO: Check if the user is in the room. If not, check if the room's history // is publicly visible. Current behaviour is returning an empty array if the // user cannot see the room's history. -func OnIncomingStateRequest(ctx context.Context, queryAPI api.RoomserverQueryAPI, roomID string) util.JSONResponse { +func OnIncomingStateRequest(ctx context.Context, rsAPI api.RoomserverInternalAPI, roomID string) util.JSONResponse { // TODO(#287): Auth request and handle the case where the user has left (where // we should return the state at the poin they left) stateReq := api.QueryLatestEventsAndStateRequest{ @@ -48,7 +48,7 @@ func OnIncomingStateRequest(ctx context.Context, queryAPI api.RoomserverQueryAPI } stateRes := api.QueryLatestEventsAndStateResponse{} - if err := queryAPI.QueryLatestEventsAndState(ctx, &stateReq, &stateRes); err != nil { + if err := rsAPI.QueryLatestEventsAndState(ctx, &stateReq, &stateRes); err != nil { util.GetLogger(ctx).WithError(err).Error("queryAPI.QueryLatestEventsAndState failed") return jsonerror.InternalServerError() } @@ -98,7 +98,8 @@ func OnIncomingStateRequest(ctx context.Context, queryAPI api.RoomserverQueryAPI // /rooms/{roomID}/state/{type}/{statekey} request. It will look in current // state to see if there is an event with that type and state key, if there // is then (by default) we return the content, otherwise a 404. -func OnIncomingStateTypeRequest(ctx context.Context, queryAPI api.RoomserverQueryAPI, roomID string, evType, stateKey string) util.JSONResponse { +// If eventFormat=true, sends the whole event else just the content. +func OnIncomingStateTypeRequest(ctx context.Context, rsAPI api.RoomserverInternalAPI, roomID, evType, stateKey string, eventFormat bool) util.JSONResponse { // TODO(#287): Auth request and handle the case where the user has left (where // we should return the state at the poin they left) util.GetLogger(ctx).WithFields(log.Fields{ @@ -118,7 +119,7 @@ func OnIncomingStateTypeRequest(ctx context.Context, queryAPI api.RoomserverQuer } stateRes := api.QueryLatestEventsAndStateResponse{} - if err := queryAPI.QueryLatestEventsAndState(ctx, &stateReq, &stateRes); err != nil { + if err := rsAPI.QueryLatestEventsAndState(ctx, &stateReq, &stateRes); err != nil { util.GetLogger(ctx).WithError(err).Error("queryAPI.QueryLatestEventsAndState failed") return jsonerror.InternalServerError() } @@ -134,8 +135,15 @@ func OnIncomingStateTypeRequest(ctx context.Context, queryAPI api.RoomserverQuer ClientEvent: gomatrixserverlib.HeaderedToClientEvent(stateRes.StateEvents[0], gomatrixserverlib.FormatAll), } + var res interface{} + if eventFormat { + res = stateEvent + } else { + res = stateEvent.Content + } + return util.JSONResponse{ Code: http.StatusOK, - JSON: stateEvent.Content, + JSON: res, } } diff --git a/clientapi/routing/threepid.go b/clientapi/routing/threepid.go index fed9ae32e..e7aaadf54 100644 --- a/clientapi/routing/threepid.go +++ b/clientapi/routing/threepid.go @@ -18,11 +18,12 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/threepid" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -84,7 +85,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf // CheckAndSave3PIDAssociation implements POST /account/3pid func CheckAndSave3PIDAssociation( - req *http.Request, accountDB accounts.Database, device *authtypes.Device, + req *http.Request, accountDB accounts.Database, device *api.Device, cfg *config.Dendrite, ) util.JSONResponse { var body threepid.EmailAssociationCheckRequest @@ -148,7 +149,7 @@ func CheckAndSave3PIDAssociation( // GetAssociated3PIDs implements GET /account/3pid func GetAssociated3PIDs( - req *http.Request, accountDB accounts.Database, device *authtypes.Device, + req *http.Request, accountDB accounts.Database, device *api.Device, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { diff --git a/clientapi/routing/voip.go b/clientapi/routing/voip.go index c1fd741c9..046e87811 100644 --- a/clientapi/routing/voip.go +++ b/clientapi/routing/voip.go @@ -22,16 +22,16 @@ import ( "net/http" "time" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrix" "github.com/matrix-org/util" ) // RequestTurnServer implements: // GET /voip/turnServer -func RequestTurnServer(req *http.Request, device *authtypes.Device, cfg *config.Dendrite) util.JSONResponse { +func RequestTurnServer(req *http.Request, device *api.Device, cfg *config.Dendrite) util.JSONResponse { turnConfig := cfg.TURN // TODO Guest Support diff --git a/clientapi/routing/whoami.go b/clientapi/routing/whoami.go index 840bcb5f2..26280f6cc 100644 --- a/clientapi/routing/whoami.go +++ b/clientapi/routing/whoami.go @@ -15,7 +15,7 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" ) @@ -26,7 +26,7 @@ type whoamiResponse struct { // Whoami implements `/account/whoami` which enables client to query their account user id. // https://matrix.org/docs/spec/client_server/r0.3.0.html#get-matrix-client-r0-account-whoami -func Whoami(req *http.Request, device *authtypes.Device) util.JSONResponse { +func Whoami(req *http.Request, device *api.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusOK, JSON: whoamiResponse{UserID: device.UserID}, diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index e34e91b56..c308cb1f4 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -25,11 +25,11 @@ import ( "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/producers" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" ) @@ -86,9 +86,9 @@ var ( // can be emitted. func CheckAndProcessInvite( ctx context.Context, - device *authtypes.Device, body *MembershipRequest, cfg *config.Dendrite, - queryAPI api.RoomserverQueryAPI, db accounts.Database, - producer *producers.RoomserverProducer, membership string, roomID string, + device *userapi.Device, body *MembershipRequest, cfg *config.Dendrite, + rsAPI api.RoomserverInternalAPI, db accounts.Database, + membership string, roomID string, evTime time.Time, ) (inviteStoredOnIDServer bool, err error) { if membership != gomatrixserverlib.Invite || (body.Address == "" && body.IDServer == "" && body.Medium == "") { @@ -112,7 +112,7 @@ func CheckAndProcessInvite( // "m.room.third_party_invite" have to be emitted from the data in // storeInviteRes. err = emit3PIDInviteEvent( - ctx, body, storeInviteRes, device, roomID, cfg, queryAPI, producer, evTime, + ctx, body, storeInviteRes, device, roomID, cfg, rsAPI, evTime, ) inviteStoredOnIDServer = err == nil @@ -137,7 +137,7 @@ func CheckAndProcessInvite( // Returns an error if a check or a request failed. func queryIDServer( ctx context.Context, - db accounts.Database, cfg *config.Dendrite, device *authtypes.Device, + db accounts.Database, cfg *config.Dendrite, device *userapi.Device, body *MembershipRequest, roomID string, ) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) { if err = isTrusted(body.IDServer, cfg); err != nil { @@ -206,7 +206,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe // Returns an error if the request failed to send or if the response couldn't be parsed. func queryIDServerStoreInvite( ctx context.Context, - db accounts.Database, cfg *config.Dendrite, device *authtypes.Device, + db accounts.Database, cfg *config.Dendrite, device *userapi.Device, body *MembershipRequest, roomID string, ) (*idServerStoreInviteResponse, error) { // Retrieve the sender's profile to get their display name @@ -279,7 +279,7 @@ func queryIDServerPubKey(ctx context.Context, idServerName string, keyID string) } var pubKeyRes struct { - PublicKey gomatrixserverlib.Base64String `json:"public_key"` + PublicKey gomatrixserverlib.Base64Bytes `json:"public_key"` } if resp.StatusCode != http.StatusOK { @@ -330,8 +330,8 @@ func checkIDServerSignatures( func emit3PIDInviteEvent( ctx context.Context, body *MembershipRequest, res *idServerStoreInviteResponse, - device *authtypes.Device, roomID string, cfg *config.Dendrite, - queryAPI api.RoomserverQueryAPI, producer *producers.RoomserverProducer, + device *userapi.Device, roomID string, cfg *config.Dendrite, + rsAPI api.RoomserverInternalAPI, evTime time.Time, ) error { builder := &gomatrixserverlib.EventBuilder{ @@ -354,13 +354,13 @@ func emit3PIDInviteEvent( } queryRes := api.QueryLatestEventsAndStateResponse{} - event, err := common.BuildEvent(ctx, builder, cfg, evTime, queryAPI, &queryRes) + event, err := eventutil.BuildEvent(ctx, builder, cfg, evTime, rsAPI, &queryRes) if err != nil { return err } - _, err = producer.SendEvents( - ctx, + _, err = api.SendEvents( + ctx, rsAPI, []gomatrixserverlib.HeaderedEvent{ (*event).Headered(queryRes.RoomVersion), }, diff --git a/clientapi/threepid/threepid.go b/clientapi/threepid/threepid.go index a7f26c295..bffe31adc 100644 --- a/clientapi/threepid/threepid.go +++ b/clientapi/threepid/threepid.go @@ -24,7 +24,7 @@ import ( "strconv" "strings" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" ) // EmailAssociationRequest represents the request defined at https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-client-r0-register-email-requesttoken diff --git a/cmd/client-api-proxy/main.go b/cmd/client-api-proxy/main.go index 27991c109..979b0b042 100644 --- a/cmd/client-api-proxy/main.go +++ b/cmd/client-api-proxy/main.go @@ -75,7 +75,6 @@ func makeProxy(targetURL string) (*httputil.ReverseProxy, error) { // Pratically this means that any distinction between '%2F' and '/' // in the URL will be lost by the time it reaches the target. path := req.URL.Path - path = "api" + path log.WithFields(log.Fields{ "path": path, "url": targetURL, diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index fc51a5bb6..ff022ec3c 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -20,8 +20,8 @@ import ( "fmt" "os" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" ) @@ -63,22 +63,19 @@ func main() { serverName := gomatrixserverlib.ServerName(*serverNameStr) - accountDB, err := accounts.NewDatabase(*database, serverName) + accountDB, err := accounts.NewDatabase(*database, nil, serverName) if err != nil { fmt.Println(err.Error()) os.Exit(1) } - account, err := accountDB.CreateAccount(context.Background(), *username, *password, "") + _, err = accountDB.CreateAccount(context.Background(), *username, *password, "") if err != nil { fmt.Println(err.Error()) os.Exit(1) - } else if account == nil { - fmt.Println("Username already exists") - os.Exit(1) } - deviceDB, err := devices.NewDatabase(*database, serverName) + deviceDB, err := devices.NewDatabase(*database, nil, serverName) if err != nil { fmt.Println(err.Error()) os.Exit(1) diff --git a/cmd/create-room-events/main.go b/cmd/create-room-events/main.go index ebce953ce..afe974643 100644 --- a/cmd/create-room-events/main.go +++ b/cmd/create-room-events/main.go @@ -47,6 +47,7 @@ var ( userID = flag.String("user-id", "@userid:$SERVER_NAME", "The user ID to use as the event sender") messageCount = flag.Int("message-count", 10, "The number of m.room.messsage events to generate") format = flag.String("Format", "InputRoomEvent", "The output format to use for the messages: InputRoomEvent or Event") + ver = flag.String("version", string(gomatrixserverlib.RoomVersionV1), "Room version to generate events as") ) // By default we use a private key of 0. @@ -109,7 +110,7 @@ func buildAndOutput() gomatrixserverlib.EventReference { event, err := b.Build( now, name, key, privateKey, - gomatrixserverlib.RoomVersionV1, + gomatrixserverlib.RoomVersion(*ver), ) if err != nil { panic(err) @@ -127,7 +128,7 @@ func writeEvent(event gomatrixserverlib.Event) { if *format == "InputRoomEvent" { var ire api.InputRoomEvent ire.Kind = api.KindNew - ire.Event = event.Headered(gomatrixserverlib.RoomVersionV1) + ire.Event = event.Headered(gomatrixserverlib.RoomVersion(*ver)) authEventIDs := []string{} for _, ref := range b.AuthEvents.([]gomatrixserverlib.EventReference) { authEventIDs = append(authEventIDs, ref.EventID) diff --git a/cmd/dendrite-appservice-server/main.go b/cmd/dendrite-appservice-server/main.go index f203969f4..6719d0471 100644 --- a/cmd/dendrite-appservice-server/main.go +++ b/cmd/dendrite-appservice-server/main.go @@ -16,24 +16,19 @@ package main import ( "github.com/matrix-org/dendrite/appservice" - "github.com/matrix-org/dendrite/common/basecomponent" - "github.com/matrix-org/dendrite/common/transactions" + "github.com/matrix-org/dendrite/internal/setup" ) func main() { - cfg := basecomponent.ParseFlags() - base := basecomponent.NewBaseDendrite(cfg, "AppServiceAPI") + cfg := setup.ParseFlags(false) + base := setup.NewBaseDendrite(cfg, "AppServiceAPI", true) defer base.Close() // nolint: errcheck - accountDB := base.CreateAccountsDB() - deviceDB := base.CreateDeviceDB() - federation := base.CreateFederationClient() - alias, _, query := base.CreateHTTPRoomserverAPIs() - cache := transactions.New() + userAPI := base.UserAPIClient() + rsAPI := base.RoomserverHTTPClient() - appservice.SetupAppServiceAPIComponent( - base, accountDB, deviceDB, federation, alias, query, cache, - ) + intAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) + appservice.AddInternalRoutes(base.InternalAPIMux, intAPI) base.SetupAndServeHTTP(string(base.Cfg.Bind.AppServiceAPI), string(base.Cfg.Listen.AppServiceAPI)) diff --git a/cmd/dendrite-client-api-server/main.go b/cmd/dendrite-client-api-server/main.go index 815a978a8..fe5f30a0e 100644 --- a/cmd/dendrite-client-api-server/main.go +++ b/cmd/dendrite-client-api-server/main.go @@ -16,33 +16,29 @@ package main import ( "github.com/matrix-org/dendrite/clientapi" - "github.com/matrix-org/dendrite/common/basecomponent" - "github.com/matrix-org/dendrite/common/keydb" - "github.com/matrix-org/dendrite/common/transactions" - "github.com/matrix-org/dendrite/eduserver" - "github.com/matrix-org/dendrite/eduserver/cache" + "github.com/matrix-org/dendrite/internal/setup" + "github.com/matrix-org/dendrite/internal/transactions" ) func main() { - cfg := basecomponent.ParseFlags() + cfg := setup.ParseFlags(false) - base := basecomponent.NewBaseDendrite(cfg, "ClientAPI") + base := setup.NewBaseDendrite(cfg, "ClientAPI", true) defer base.Close() // nolint: errcheck accountDB := base.CreateAccountsDB() deviceDB := base.CreateDeviceDB() - keyDB := base.CreateKeyDB() federation := base.CreateFederationClient() - keyRing := keydb.CreateKeyRing(federation.Client, keyDB, cfg.Matrix.KeyPerspectives) - asQuery := base.CreateHTTPAppServiceAPIs() - alias, input, query := base.CreateHTTPRoomserverAPIs() - fedSenderAPI := base.CreateHTTPFederationSenderAPIs() - eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New()) + asQuery := base.AppserviceHTTPClient() + rsAPI := base.RoomserverHTTPClient() + fsAPI := base.FederationSenderHTTPClient() + eduInputAPI := base.EDUServerClient() + userAPI := base.UserAPIClient() - clientapi.SetupClientAPIComponent( - base, deviceDB, accountDB, federation, &keyRing, - alias, input, query, eduInputAPI, asQuery, transactions.New(), fedSenderAPI, + clientapi.AddPublicRoutes( + base.PublicAPIMux, base.Cfg, base.KafkaConsumer, base.KafkaProducer, deviceDB, accountDB, federation, + rsAPI, eduInputAPI, asQuery, transactions.New(), fsAPI, userAPI, ) base.SetupAndServeHTTP(string(base.Cfg.Bind.ClientAPI), string(base.Cfg.Listen.ClientAPI)) diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index f280c7483..356ab5a7f 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -29,40 +29,26 @@ import ( p2phttp "github.com/libp2p/go-libp2p-http" p2pdisc "github.com/libp2p/go-libp2p/p2p/discovery" "github.com/matrix-org/dendrite/appservice" - "github.com/matrix-org/dendrite/clientapi" - "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/cmd/dendrite-demo-libp2p/storage" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" - "github.com/matrix-org/dendrite/common/keydb" - "github.com/matrix-org/dendrite/common/transactions" "github.com/matrix-org/dendrite/eduserver" - "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationsender" - "github.com/matrix-org/dendrite/mediaapi" - "github.com/matrix-org/dendrite/publicroomsapi" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/roomserver" - "github.com/matrix-org/dendrite/syncapi" + "github.com/matrix-org/dendrite/serverkeyapi" + "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/eduserver/cache" - "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/sirupsen/logrus" ) func createKeyDB( base *P2PDendrite, -) keydb.Database { - db, err := keydb.NewDatabase( - string(base.Base.Cfg.Database.ServerKey), - base.Base.Cfg.Matrix.ServerName, - base.Base.Cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey), - base.Base.Cfg.Matrix.KeyID, - ) - if err != nil { - logrus.WithError(err).Panicf("failed to connect to keys db") - } + db gomatrixserverlib.KeyDatabase, +) { mdns := mDNSListener{ host: base.LibP2P, keydb: db, @@ -77,7 +63,6 @@ func createKeyDB( panic(err) } serv.RegisterNotifee(&mdns) - return db } func createFederationClient( @@ -95,6 +80,17 @@ func createFederationClient( ) } +func createClient( + base *P2PDendrite, +) *gomatrixserverlib.Client { + tr := &http.Transport{} + tr.RegisterProtocol( + "matrix", + p2phttp.NewTransport(base.LibP2P, p2phttp.ProtocolOption("/matrix")), + ) + return gomatrixserverlib.NewClientWithTransport(tr) +} + func main() { instanceName := flag.String("name", "dendrite-p2p", "the name of this P2P demo instance") instancePort := flag.Int("port", 8080, "the port that the client API will listen on") @@ -117,6 +113,7 @@ func main() { } cfg := config.Dendrite{} + cfg.SetDefaults() cfg.Matrix.ServerName = "p2p" cfg.Matrix.PrivateKey = privKey cfg.Matrix.KeyID = gomatrixserverlib.KeyID(fmt.Sprintf("ed25519:%s", *instanceName)) @@ -124,7 +121,6 @@ func main() { cfg.Kafka.Topics.OutputRoomEvent = "roomserverOutput" cfg.Kafka.Topics.OutputClientData = "clientapiOutput" cfg.Kafka.Topics.OutputTypingEvent = "typingServerOutput" - cfg.Kafka.Topics.UserUpdates = "userUpdates" cfg.Database.Account = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) cfg.Database.Device = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName)) cfg.Database.MediaAPI = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) @@ -144,44 +140,67 @@ func main() { accountDB := base.Base.CreateAccountsDB() deviceDB := base.Base.CreateDeviceDB() - keyDB := createKeyDB(base) federation := createFederationClient(base) - keyRing := keydb.CreateKeyRing(federation.Client, keyDB, cfg.Matrix.KeyPerspectives) + userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Matrix.ServerName, nil) - alias, input, query := roomserver.SetupRoomServerComponent(&base.Base) - eduInputAPI := eduserver.SetupEDUServerComponent(&base.Base, cache.New()) - asQuery := appservice.SetupAppServiceAPIComponent( - &base.Base, accountDB, deviceDB, federation, alias, query, transactions.New(), + serverKeyAPI := serverkeyapi.NewInternalAPI( + base.Base.Cfg, federation, base.Base.Caches, + ) + keyRing := serverKeyAPI.KeyRing() + createKeyDB( + base, serverKeyAPI, ) - fedSenderAPI := federationsender.SetupFederationSenderComponent(&base.Base, federation, query) - clientapi.SetupClientAPIComponent( - &base.Base, deviceDB, accountDB, - federation, &keyRing, alias, input, query, - eduInputAPI, asQuery, transactions.New(), fedSenderAPI, + rsAPI := roomserver.NewInternalAPI( + &base.Base, keyRing, federation, ) - eduProducer := producers.NewEDUServerProducer(eduInputAPI) - federationapi.SetupFederationAPIComponent(&base.Base, accountDB, deviceDB, federation, &keyRing, alias, input, query, asQuery, fedSenderAPI, eduProducer) - mediaapi.SetupMediaAPIComponent(&base.Base, deviceDB) - publicRoomsDB, err := storage.NewPublicRoomsServerDatabaseWithPubSub(string(base.Base.Cfg.Database.PublicRoomsAPI), base.LibP2PPubsub) + eduInputAPI := eduserver.NewInternalAPI( + &base.Base, cache.New(), userAPI, + ) + asAPI := appservice.NewInternalAPI(&base.Base, userAPI, rsAPI) + fsAPI := federationsender.NewInternalAPI( + &base.Base, federation, rsAPI, keyRing, + ) + rsAPI.SetFederationSenderAPI(fsAPI) + publicRoomsDB, err := storage.NewPublicRoomsServerDatabaseWithPubSub(string(base.Base.Cfg.Database.PublicRoomsAPI), base.LibP2PPubsub, cfg.Matrix.ServerName) if err != nil { logrus.WithError(err).Panicf("failed to connect to public rooms db") } - publicroomsapi.SetupPublicRoomsAPIComponent(&base.Base, deviceDB, publicRoomsDB, query, federation, nil) // Check this later - syncapi.SetupSyncAPIComponent(&base.Base, deviceDB, accountDB, query, federation, &cfg) - httpHandler := common.WrapHandlerInCORS(base.Base.APIMux) + monolith := setup.Monolith{ + Config: base.Base.Cfg, + AccountDB: accountDB, + DeviceDB: deviceDB, + Client: createClient(base), + FedClient: federation, + KeyRing: keyRing, + KafkaConsumer: base.Base.KafkaConsumer, + KafkaProducer: base.Base.KafkaProducer, - // Set up the API endpoints we handle. /metrics is for prometheus, and is - // not wrapped by CORS, while everything else is - http.Handle("/metrics", promhttp.Handler()) - http.Handle("/", httpHandler) + AppserviceAPI: asAPI, + EDUInternalAPI: eduInputAPI, + FederationSenderAPI: fsAPI, + RoomserverAPI: rsAPI, + ServerKeyAPI: serverKeyAPI, + UserAPI: userAPI, + + PublicRoomsDB: publicRoomsDB, + } + monolith.AddAllPublicRoutes(base.Base.PublicAPIMux) + + httputil.SetupHTTPAPI( + base.Base.BaseMux, + base.Base.PublicAPIMux, + base.Base.InternalAPIMux, + &cfg, + base.Base.UseHTTPAPIs, + ) // Expose the matrix APIs directly rather than putting them under a /api path. go func() { httpBindAddr := fmt.Sprintf(":%d", *instancePort) logrus.Info("Listening on ", httpBindAddr) - logrus.Fatal(http.ListenAndServe(httpBindAddr, nil)) + logrus.Fatal(http.ListenAndServe(httpBindAddr, base.Base.BaseMux)) }() // Expose the matrix APIs also via libp2p if base.LibP2P != nil { @@ -194,7 +213,7 @@ func main() { defer func() { logrus.Fatal(listener.Close()) }() - logrus.Fatal(http.Serve(listener, nil)) + logrus.Fatal(http.Serve(listener, base.Base.BaseMux)) }() } diff --git a/cmd/dendrite-demo-libp2p/mdnslistener.go b/cmd/dendrite-demo-libp2p/mdnslistener.go index 3fefbec2c..c30aaa331 100644 --- a/cmd/dendrite-demo-libp2p/mdnslistener.go +++ b/cmd/dendrite-demo-libp2p/mdnslistener.go @@ -21,12 +21,11 @@ import ( "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/peer" - "github.com/matrix-org/dendrite/common/keydb" "github.com/matrix-org/gomatrixserverlib" ) type mDNSListener struct { - keydb keydb.Database + keydb gomatrixserverlib.KeyDatabase host host.Host } @@ -44,7 +43,7 @@ func (n *mDNSListener) HandlePeerFound(p peer.AddrInfo) { KeyID: "ed25519:p2pdemo", }: { VerifyKey: gomatrixserverlib.VerifyKey{ - Key: gomatrixserverlib.Base64String(raw), + Key: gomatrixserverlib.Base64Bytes(raw), }, ValidUntilTS: math.MaxUint64 >> 1, ExpiredTS: gomatrixserverlib.PublicKeyNotExpired, diff --git a/cmd/dendrite-demo-libp2p/p2pdendrite.go b/cmd/dendrite-demo-libp2p/p2pdendrite.go index a9db3b39c..4270143f5 100644 --- a/cmd/dendrite-demo-libp2p/p2pdendrite.go +++ b/cmd/dendrite-demo-libp2p/p2pdendrite.go @@ -22,7 +22,7 @@ import ( pstore "github.com/libp2p/go-libp2p-core/peerstore" record "github.com/libp2p/go-libp2p-record" - "github.com/matrix-org/dendrite/common/basecomponent" + "github.com/matrix-org/dendrite/internal/setup" "github.com/libp2p/go-libp2p" circuit "github.com/libp2p/go-libp2p-circuit" @@ -34,12 +34,12 @@ import ( pubsub "github.com/libp2p/go-libp2p-pubsub" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" ) // P2PDendrite is a Peer-to-Peer variant of BaseDendrite. type P2PDendrite struct { - Base basecomponent.BaseDendrite + Base setup.BaseDendrite // Store our libp2p object so that we can make outgoing connections from it // later @@ -54,7 +54,7 @@ type P2PDendrite struct { // The componentName is used for logging purposes, and should be a friendly name // of the component running, e.g. SyncAPI. func NewP2PDendrite(cfg *config.Dendrite, componentName string) *P2PDendrite { - baseDendrite := basecomponent.NewBaseDendrite(cfg, componentName) + baseDendrite := setup.NewBaseDendrite(cfg, componentName, false) ctx, cancel := context.WithCancel(context.Background()) diff --git a/cmd/dendrite-demo-libp2p/storage/postgreswithdht/storage.go b/cmd/dendrite-demo-libp2p/storage/postgreswithdht/storage.go index 819469ee8..d2cb36a8b 100644 --- a/cmd/dendrite-demo-libp2p/storage/postgreswithdht/storage.go +++ b/cmd/dendrite-demo-libp2p/storage/postgreswithdht/storage.go @@ -44,8 +44,8 @@ type PublicRoomsServerDatabase struct { } // NewPublicRoomsServerDatabase creates a new public rooms server database. -func NewPublicRoomsServerDatabase(dataSourceName string, dht *dht.IpfsDHT) (*PublicRoomsServerDatabase, error) { - pg, err := postgres.NewPublicRoomsServerDatabase(dataSourceName) +func NewPublicRoomsServerDatabase(dataSourceName string, dht *dht.IpfsDHT, localServerName gomatrixserverlib.ServerName) (*PublicRoomsServerDatabase, error) { + pg, err := postgres.NewPublicRoomsServerDatabase(dataSourceName, nil, localServerName) if err != nil { return nil, err } diff --git a/cmd/dendrite-demo-libp2p/storage/postgreswithpubsub/storage.go b/cmd/dendrite-demo-libp2p/storage/postgreswithpubsub/storage.go index 661192243..cf642eb38 100644 --- a/cmd/dendrite-demo-libp2p/storage/postgreswithpubsub/storage.go +++ b/cmd/dendrite-demo-libp2p/storage/postgreswithpubsub/storage.go @@ -47,8 +47,8 @@ type PublicRoomsServerDatabase struct { } // NewPublicRoomsServerDatabase creates a new public rooms server database. -func NewPublicRoomsServerDatabase(dataSourceName string, pubsub *pubsub.PubSub) (*PublicRoomsServerDatabase, error) { - pg, err := postgres.NewPublicRoomsServerDatabase(dataSourceName) +func NewPublicRoomsServerDatabase(dataSourceName string, pubsub *pubsub.PubSub, localServerName gomatrixserverlib.ServerName) (*PublicRoomsServerDatabase, error) { + pg, err := postgres.NewPublicRoomsServerDatabase(dataSourceName, nil, localServerName) if err != nil { return nil, err } diff --git a/cmd/dendrite-demo-libp2p/storage/storage.go b/cmd/dendrite-demo-libp2p/storage/storage.go index 668edbaa3..2d8dc1817 100644 --- a/cmd/dendrite-demo-libp2p/storage/storage.go +++ b/cmd/dendrite-demo-libp2p/storage/storage.go @@ -23,39 +23,40 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-libp2p/storage/postgreswithpubsub" "github.com/matrix-org/dendrite/publicroomsapi/storage" "github.com/matrix-org/dendrite/publicroomsapi/storage/sqlite3" + "github.com/matrix-org/gomatrixserverlib" ) const schemePostgres = "postgres" const schemeFile = "file" // NewPublicRoomsServerDatabase opens a database connection. -func NewPublicRoomsServerDatabaseWithDHT(dataSourceName string, dht *dht.IpfsDHT) (storage.Database, error) { +func NewPublicRoomsServerDatabaseWithDHT(dataSourceName string, dht *dht.IpfsDHT, localServerName gomatrixserverlib.ServerName) (storage.Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { - return postgreswithdht.NewPublicRoomsServerDatabase(dataSourceName, dht) + return postgreswithdht.NewPublicRoomsServerDatabase(dataSourceName, dht, localServerName) } switch uri.Scheme { case schemePostgres: - return postgreswithdht.NewPublicRoomsServerDatabase(dataSourceName, dht) + return postgreswithdht.NewPublicRoomsServerDatabase(dataSourceName, dht, localServerName) case schemeFile: - return sqlite3.NewPublicRoomsServerDatabase(dataSourceName) + return sqlite3.NewPublicRoomsServerDatabase(dataSourceName, localServerName) default: - return postgreswithdht.NewPublicRoomsServerDatabase(dataSourceName, dht) + return postgreswithdht.NewPublicRoomsServerDatabase(dataSourceName, dht, localServerName) } } // NewPublicRoomsServerDatabase opens a database connection. -func NewPublicRoomsServerDatabaseWithPubSub(dataSourceName string, pubsub *pubsub.PubSub) (storage.Database, error) { +func NewPublicRoomsServerDatabaseWithPubSub(dataSourceName string, pubsub *pubsub.PubSub, localServerName gomatrixserverlib.ServerName) (storage.Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { - return postgreswithpubsub.NewPublicRoomsServerDatabase(dataSourceName, pubsub) + return postgreswithpubsub.NewPublicRoomsServerDatabase(dataSourceName, pubsub, localServerName) } switch uri.Scheme { case schemePostgres: - return postgreswithpubsub.NewPublicRoomsServerDatabase(dataSourceName, pubsub) + return postgreswithpubsub.NewPublicRoomsServerDatabase(dataSourceName, pubsub, localServerName) case schemeFile: - return sqlite3.NewPublicRoomsServerDatabase(dataSourceName) + return sqlite3.NewPublicRoomsServerDatabase(dataSourceName, localServerName) default: - return postgreswithpubsub.NewPublicRoomsServerDatabase(dataSourceName, pubsub) + return postgreswithpubsub.NewPublicRoomsServerDatabase(dataSourceName, pubsub, localServerName) } } diff --git a/cmd/dendrite-demo-yggdrasil/convert/25519.go b/cmd/dendrite-demo-yggdrasil/convert/25519.go new file mode 100644 index 000000000..97f053ec0 --- /dev/null +++ b/cmd/dendrite-demo-yggdrasil/convert/25519.go @@ -0,0 +1,53 @@ +// Copyright 2019 Google LLC +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file or at +// https://developers.google.com/open-source/licenses/bsd +// +// Original code from https://github.com/FiloSottile/age/blob/bbab440e198a4d67ba78591176c7853e62d29e04/internal/age/ssh.go + +package convert + +import ( + "crypto/ed25519" + "crypto/sha512" + "math/big" + + "golang.org/x/crypto/curve25519" +) + +var curve25519P, _ = new(big.Int).SetString("57896044618658097711785492504343953926634992332820282019728792003956564819949", 10) + +func Ed25519PrivateKeyToCurve25519(pk ed25519.PrivateKey) []byte { + h := sha512.New() + _, _ = h.Write(pk.Seed()) + out := h.Sum(nil) + return out[:curve25519.ScalarSize] +} + +func Ed25519PublicKeyToCurve25519(pk ed25519.PublicKey) []byte { + // ed25519.PublicKey is a little endian representation of the y-coordinate, + // with the most significant bit set based on the sign of the x-coordinate. + bigEndianY := make([]byte, ed25519.PublicKeySize) + for i, b := range pk { + bigEndianY[ed25519.PublicKeySize-i-1] = b + } + bigEndianY[0] &= 0b0111_1111 + + // The Montgomery u-coordinate is derived through the bilinear map + // u = (1 + y) / (1 - y) + // See https://blog.filippo.io/using-ed25519-keys-for-encryption. + y := new(big.Int).SetBytes(bigEndianY) + denom := big.NewInt(1) + denom.ModInverse(denom.Sub(denom, y), curve25519P) // 1 / (1 - y) + u := y.Mul(y.Add(y, big.NewInt(1)), denom) + u.Mod(u, curve25519P) + + out := make([]byte, curve25519.PointSize) + uBytes := u.Bytes() + for i, b := range uBytes { + out[len(uBytes)-i-1] = b + } + + return out +} diff --git a/cmd/dendrite-demo-yggdrasil/convert/25519_test.go b/cmd/dendrite-demo-yggdrasil/convert/25519_test.go new file mode 100644 index 000000000..22177b8b4 --- /dev/null +++ b/cmd/dendrite-demo-yggdrasil/convert/25519_test.go @@ -0,0 +1,51 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package convert + +import ( + "bytes" + "crypto/ed25519" + "encoding/hex" + "testing" + + "golang.org/x/crypto/curve25519" +) + +func TestKeyConversion(t *testing.T) { + edPub, edPriv, err := ed25519.GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + t.Log("Signing public:", hex.EncodeToString(edPub)) + t.Log("Signing private:", hex.EncodeToString(edPriv)) + + cuPriv := Ed25519PrivateKeyToCurve25519(edPriv) + t.Log("Encryption private:", hex.EncodeToString(cuPriv)) + + cuPub := Ed25519PublicKeyToCurve25519(edPub) + t.Log("Converted encryption public:", hex.EncodeToString(cuPub)) + + var realPub, realPriv [32]byte + copy(realPriv[:32], cuPriv[:32]) + curve25519.ScalarBaseMult(&realPub, &realPriv) + t.Log("Scalar-multed encryption public:", hex.EncodeToString(realPub[:])) + + if !bytes.Equal(realPriv[:], cuPriv[:]) { + t.Fatal("Private keys should be equal (this means the test is broken)") + } + if !bytes.Equal(realPub[:], cuPub[:]) { + t.Fatal("Public keys should be equal") + } +} diff --git a/cmd/dendrite-demo-yggdrasil/embed/embed_other.go b/cmd/dendrite-demo-yggdrasil/embed/embed_other.go new file mode 100644 index 000000000..598881148 --- /dev/null +++ b/cmd/dendrite-demo-yggdrasil/embed/embed_other.go @@ -0,0 +1,9 @@ +// +build !riotweb + +package embed + +import "github.com/gorilla/mux" + +func Embed(_ *mux.Router, _ int, _ string) { + +} diff --git a/cmd/dendrite-demo-yggdrasil/embed/embed_riotweb.go b/cmd/dendrite-demo-yggdrasil/embed/embed_riotweb.go new file mode 100644 index 000000000..a9e04a312 --- /dev/null +++ b/cmd/dendrite-demo-yggdrasil/embed/embed_riotweb.go @@ -0,0 +1,62 @@ +// +build riotweb + +package embed + +import ( + "fmt" + "io" + "net/http" + + "github.com/gorilla/mux" + "github.com/tidwall/sjson" +) + +// From within the Riot Web directory: +// go run github.com/mjibson/esc -o /path/to/dendrite/internal/embed/fs_riotweb.go -private -pkg embed . + +func Embed(rootMux *mux.Router, listenPort int, serverName string) { + url := fmt.Sprintf("http://localhost:%d", listenPort) + embeddedFS := _escFS(false) + embeddedServ := http.FileServer(embeddedFS) + + rootMux.Handle("/", embeddedServ) + rootMux.HandleFunc("/config.json", func(w http.ResponseWriter, _ *http.Request) { + configFile, err := embeddedFS.Open("/config.sample.json") + if err != nil { + w.WriteHeader(500) + io.WriteString(w, "Couldn't open the file: "+err.Error()) + return + } + configFileInfo, err := configFile.Stat() + if err != nil { + w.WriteHeader(500) + io.WriteString(w, "Couldn't stat the file: "+err.Error()) + return + } + buf := make([]byte, configFileInfo.Size()) + n, err := configFile.Read(buf) + if err != nil { + w.WriteHeader(500) + io.WriteString(w, "Couldn't read the file: "+err.Error()) + return + } + if int64(n) != configFileInfo.Size() { + w.WriteHeader(500) + io.WriteString(w, "The returned file size didn't match what we expected") + return + } + js, _ := sjson.SetBytes(buf, "default_server_config.m\\.homeserver.base_url", url) + js, _ = sjson.SetBytes(js, "default_server_config.m\\.homeserver.server_name", serverName) + js, _ = sjson.SetBytes(js, "brand", fmt.Sprintf("Riot %s", serverName)) + js, _ = sjson.SetBytes(js, "disable_guests", true) + js, _ = sjson.SetBytes(js, "disable_3pid_login", true) + js, _ = sjson.DeleteBytes(js, "welcomeUserId") + _, _ = w.Write(js) + }) + + fmt.Println("*-------------------------------*") + fmt.Println("| This build includes Riot Web! |") + fmt.Println("*-------------------------------*") + fmt.Println("Point your browser to:", url) + fmt.Println() +} diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go new file mode 100644 index 000000000..db05ecb76 --- /dev/null +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -0,0 +1,171 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "crypto/tls" + "flag" + "fmt" + "net" + "net/http" + "time" + + "github.com/matrix-org/dendrite/appservice" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/embed" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/yggconn" + "github.com/matrix-org/dendrite/eduserver" + "github.com/matrix-org/dendrite/eduserver/cache" + "github.com/matrix-org/dendrite/federationsender" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/setup" + "github.com/matrix-org/dendrite/publicroomsapi/storage" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/userapi" + "github.com/matrix-org/gomatrixserverlib" + + "github.com/sirupsen/logrus" +) + +var ( + instanceName = flag.String("name", "dendrite-p2p-ygg", "the name of this P2P demo instance") + instancePort = flag.Int("port", 8008, "the port that the client API will listen on") + instancePeer = flag.String("peer", "", "an internet Yggdrasil peer to connect to") +) + +// nolint:gocyclo +func main() { + flag.Parse() + + ygg, err := yggconn.Setup(*instanceName, *instancePeer, ".") + if err != nil { + panic(err) + } + + cfg := &config.Dendrite{} + cfg.SetDefaults() + cfg.Matrix.ServerName = gomatrixserverlib.ServerName(ygg.DerivedServerName()) + cfg.Matrix.PrivateKey = ygg.SigningPrivateKey() + cfg.Matrix.KeyID = gomatrixserverlib.KeyID(signing.KeyID) + cfg.Kafka.UseNaffka = true + cfg.Kafka.Topics.OutputRoomEvent = "roomserverOutput" + cfg.Kafka.Topics.OutputClientData = "clientapiOutput" + cfg.Kafka.Topics.OutputTypingEvent = "typingServerOutput" + cfg.Database.Account = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) + cfg.Database.Device = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName)) + cfg.Database.MediaAPI = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) + cfg.Database.SyncAPI = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) + cfg.Database.RoomServer = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) + cfg.Database.ServerKey = config.DataSource(fmt.Sprintf("file:%s-serverkey.db", *instanceName)) + cfg.Database.FederationSender = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", *instanceName)) + cfg.Database.AppService = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName)) + cfg.Database.PublicRoomsAPI = config.DataSource(fmt.Sprintf("file:%s-publicroomsa.db", *instanceName)) + cfg.Database.Naffka = config.DataSource(fmt.Sprintf("file:%s-naffka.db", *instanceName)) + if err = cfg.Derive(); err != nil { + panic(err) + } + + base := setup.NewBaseDendrite(cfg, "Monolith", false) + defer base.Close() // nolint: errcheck + + accountDB := base.CreateAccountsDB() + deviceDB := base.CreateDeviceDB() + federation := ygg.CreateFederationClient(base) + + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + + userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Matrix.ServerName, nil) + + rsComponent := roomserver.NewInternalAPI( + base, keyRing, federation, + ) + rsAPI := rsComponent + + eduInputAPI := eduserver.NewInternalAPI( + base, cache.New(), userAPI, + ) + + asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) + + fsAPI := federationsender.NewInternalAPI( + base, federation, rsAPI, keyRing, + ) + + rsComponent.SetFederationSenderAPI(fsAPI) + + publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI), base.Cfg.DbProperties(), cfg.Matrix.ServerName) + if err != nil { + logrus.WithError(err).Panicf("failed to connect to public rooms db") + } + + embed.Embed(base.BaseMux, *instancePort, "Yggdrasil Demo") + + monolith := setup.Monolith{ + Config: base.Cfg, + AccountDB: accountDB, + DeviceDB: deviceDB, + Client: ygg.CreateClient(base), + FedClient: federation, + KeyRing: keyRing, + KafkaConsumer: base.KafkaConsumer, + KafkaProducer: base.KafkaProducer, + + AppserviceAPI: asAPI, + EDUInternalAPI: eduInputAPI, + FederationSenderAPI: fsAPI, + RoomserverAPI: rsAPI, + UserAPI: userAPI, + //ServerKeyAPI: serverKeyAPI, + + PublicRoomsDB: publicRoomsDB, + } + monolith.AddAllPublicRoutes(base.PublicAPIMux) + + httputil.SetupHTTPAPI( + base.BaseMux, + base.PublicAPIMux, + base.InternalAPIMux, + cfg, + base.UseHTTPAPIs, + ) + + // Build both ends of a HTTP multiplex. + httpServer := &http.Server{ + Addr: ":0", + TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){}, + ReadTimeout: 15 * time.Second, + WriteTimeout: 45 * time.Second, + IdleTimeout: 60 * time.Second, + BaseContext: func(_ net.Listener) context.Context { + return context.Background() + }, + Handler: base.BaseMux, + } + + go func() { + logrus.Info("Listening on ", ygg.DerivedServerName()) + logrus.Fatal(httpServer.Serve(ygg)) + }() + go func() { + httpBindAddr := fmt.Sprintf(":%d", *instancePort) + logrus.Info("Listening on ", httpBindAddr) + logrus.Fatal(http.ListenAndServe(httpBindAddr, base.BaseMux)) + }() + + select {} +} diff --git a/cmd/dendrite-demo-yggdrasil/signing/fetcher.go b/cmd/dendrite-demo-yggdrasil/signing/fetcher.go new file mode 100644 index 000000000..bcec0cbec --- /dev/null +++ b/cmd/dendrite-demo-yggdrasil/signing/fetcher.go @@ -0,0 +1,69 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package signing + +import ( + "context" + "encoding/hex" + "fmt" + "time" + + "github.com/matrix-org/gomatrixserverlib" +) + +const KeyID = "ed25519:dendrite-demo-yggdrasil" + +type YggdrasilKeys struct { +} + +func (f *YggdrasilKeys) KeyRing() *gomatrixserverlib.KeyRing { + return &gomatrixserverlib.KeyRing{ + KeyDatabase: f, + } +} + +func (f *YggdrasilKeys) FetchKeys( + ctx context.Context, + requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, +) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { + res := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) + for req := range requests { + if req.KeyID != KeyID { + return nil, fmt.Errorf("FetchKeys: cannot fetch key with ID %s, should be %s", req.KeyID, KeyID) + } + + hexkey, err := hex.DecodeString(string(req.ServerName)) + if err != nil { + return nil, fmt.Errorf("FetchKeys: can't decode server name %q: %w", req.ServerName, err) + } + + res[req] = gomatrixserverlib.PublicKeyLookupResult{ + VerifyKey: gomatrixserverlib.VerifyKey{ + Key: hexkey, + }, + ExpiredTS: gomatrixserverlib.PublicKeyNotExpired, + ValidUntilTS: gomatrixserverlib.AsTimestamp(time.Now().Add(24 * time.Hour * 365)), + } + } + return res, nil +} + +func (f *YggdrasilKeys) FetcherName() string { + return "YggdrasilKeys" +} + +func (f *YggdrasilKeys) StoreKeys(ctx context.Context, results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) error { + return nil +} diff --git a/cmd/dendrite-demo-yggdrasil/yggconn/client.go b/cmd/dendrite-demo-yggdrasil/yggconn/client.go new file mode 100644 index 000000000..399993e3e --- /dev/null +++ b/cmd/dendrite-demo-yggdrasil/yggconn/client.go @@ -0,0 +1,74 @@ +package yggconn + +import ( + "context" + "crypto/ed25519" + "encoding/hex" + "fmt" + "net" + "net/http" + "strings" + "time" + + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/convert" + "github.com/matrix-org/dendrite/internal/setup" + "github.com/matrix-org/gomatrixserverlib" +) + +func (n *Node) yggdialer(_, address string) (net.Conn, error) { + tokens := strings.Split(address, ":") + raw, err := hex.DecodeString(tokens[0]) + if err != nil { + return nil, fmt.Errorf("hex.DecodeString: %w", err) + } + converted := convert.Ed25519PublicKeyToCurve25519(ed25519.PublicKey(raw)) + convhex := hex.EncodeToString(converted) + return n.Dial("curve25519", convhex) +} + +func (n *Node) yggdialerctx(ctx context.Context, network, address string) (net.Conn, error) { + return n.yggdialer(network, address) +} + +type yggroundtripper struct { + inner *http.Transport +} + +func (y *yggroundtripper) RoundTrip(req *http.Request) (*http.Response, error) { + req.URL.Scheme = "http" + return y.inner.RoundTrip(req) +} + +func (n *Node) CreateClient( + base *setup.BaseDendrite, +) *gomatrixserverlib.Client { + tr := &http.Transport{} + tr.RegisterProtocol( + "matrix", &yggroundtripper{ + inner: &http.Transport{ + ResponseHeaderTimeout: 15 * time.Second, + IdleConnTimeout: 60 * time.Second, + DialContext: n.yggdialerctx, + }, + }, + ) + return gomatrixserverlib.NewClientWithTransport(tr) +} + +func (n *Node) CreateFederationClient( + base *setup.BaseDendrite, +) *gomatrixserverlib.FederationClient { + tr := &http.Transport{} + tr.RegisterProtocol( + "matrix", &yggroundtripper{ + inner: &http.Transport{ + ResponseHeaderTimeout: 15 * time.Second, + IdleConnTimeout: 60 * time.Second, + DialContext: n.yggdialerctx, + }, + }, + ) + return gomatrixserverlib.NewFederationClientWithTransport( + base.Cfg.Matrix.ServerName, base.Cfg.Matrix.KeyID, base.Cfg.Matrix.PrivateKey, tr, + ) +} diff --git a/cmd/dendrite-demo-yggdrasil/yggconn/node.go b/cmd/dendrite-demo-yggdrasil/yggconn/node.go new file mode 100644 index 000000000..c335f2eac --- /dev/null +++ b/cmd/dendrite-demo-yggdrasil/yggconn/node.go @@ -0,0 +1,176 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package yggconn + +import ( + "context" + "crypto/ed25519" + "encoding/hex" + "encoding/json" + "fmt" + "io/ioutil" + "log" + "net" + "os" + "strings" + "sync" + + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/convert" + + "github.com/libp2p/go-yamux" + yggdrasiladmin "github.com/yggdrasil-network/yggdrasil-go/src/admin" + yggdrasilconfig "github.com/yggdrasil-network/yggdrasil-go/src/config" + yggdrasilmulticast "github.com/yggdrasil-network/yggdrasil-go/src/multicast" + "github.com/yggdrasil-network/yggdrasil-go/src/yggdrasil" + + gologme "github.com/gologme/log" +) + +type Node struct { + core *yggdrasil.Core + config *yggdrasilconfig.NodeConfig + state *yggdrasilconfig.NodeState + admin *yggdrasiladmin.AdminSocket + multicast *yggdrasilmulticast.Multicast + log *gologme.Logger + listener *yggdrasil.Listener + dialer *yggdrasil.Dialer + sessions sync.Map // string -> yamux.Session + incoming chan *yamux.Stream +} + +func (n *Node) Dialer(_, address string) (net.Conn, error) { + tokens := strings.Split(address, ":") + raw, err := hex.DecodeString(tokens[0]) + if err != nil { + return nil, fmt.Errorf("hex.DecodeString: %w", err) + } + converted := convert.Ed25519PublicKeyToCurve25519(ed25519.PublicKey(raw)) + convhex := hex.EncodeToString(converted) + return n.Dial("curve25519", convhex) +} + +func (n *Node) DialerContext(ctx context.Context, network, address string) (net.Conn, error) { + return n.Dialer(network, address) +} + +// nolint:gocyclo +func Setup(instanceName, instancePeer, storageDirectory string) (*Node, error) { + n := &Node{ + core: &yggdrasil.Core{}, + config: yggdrasilconfig.GenerateConfig(), + admin: &yggdrasiladmin.AdminSocket{}, + multicast: &yggdrasilmulticast.Multicast{}, + log: gologme.New(os.Stdout, "YGG ", log.Flags()), + incoming: make(chan *yamux.Stream), + } + + yggfile := fmt.Sprintf("%s/%s-yggdrasil.conf", storageDirectory, instanceName) + if _, err := os.Stat(yggfile); !os.IsNotExist(err) { + yggconf, e := ioutil.ReadFile(yggfile) + if e != nil { + panic(err) + } + if err := json.Unmarshal([]byte(yggconf), &n.config); err != nil { + panic(err) + } + } else { + n.config.AdminListen = "none" // fmt.Sprintf("unix://%s/%s-yggdrasil.sock", storageDirectory, instanceName) + n.config.MulticastInterfaces = []string{".*"} + n.config.EncryptionPrivateKey = hex.EncodeToString(n.EncryptionPrivateKey()) + n.config.EncryptionPublicKey = hex.EncodeToString(n.EncryptionPublicKey()) + + j, err := json.MarshalIndent(n.config, "", " ") + if err != nil { + panic(err) + } + if e := ioutil.WriteFile(yggfile, j, 0600); e != nil { + n.log.Printf("Couldn't write private key to file '%s': %s\n", yggfile, e) + } + } + + var err error + n.log.EnableLevel("error") + n.log.EnableLevel("warn") + n.log.EnableLevel("info") + n.state, err = n.core.Start(n.config, n.log) + if err != nil { + panic(err) + } + if instancePeer != "" { + if err = n.core.AddPeer(instancePeer, ""); err != nil { + panic(err) + } + } + /* + if err = n.admin.Init(n.core, n.state, n.log, nil); err != nil { + panic(err) + } + if err = n.admin.Start(); err != nil { + panic(err) + } + */ + if err = n.multicast.Init(n.core, n.state, n.log, nil); err != nil { + panic(err) + } + if err = n.multicast.Start(); err != nil { + panic(err) + } + //n.admin.SetupAdminHandlers(n.admin) + //n.multicast.SetupAdminHandlers(n.admin) + n.listener, err = n.core.ConnListen() + if err != nil { + panic(err) + } + n.dialer, err = n.core.ConnDialer() + if err != nil { + panic(err) + } + + n.log.Println("Public curve25519:", n.core.EncryptionPublicKey()) + n.log.Println("Public ed25519:", n.core.SigningPublicKey()) + + go n.listenFromYgg() + + return n, nil +} + +func (n *Node) DerivedServerName() string { + return hex.EncodeToString(n.SigningPublicKey()) +} + +func (n *Node) DerivedSessionName() string { + return hex.EncodeToString(n.EncryptionPublicKey()) +} + +func (n *Node) EncryptionPublicKey() []byte { + edkey := n.SigningPublicKey() + return convert.Ed25519PublicKeyToCurve25519(edkey) +} + +func (n *Node) EncryptionPrivateKey() []byte { + edkey := n.SigningPrivateKey() + return convert.Ed25519PrivateKeyToCurve25519(edkey) +} + +func (n *Node) SigningPublicKey() ed25519.PublicKey { + pubBytes, _ := hex.DecodeString(n.config.SigningPublicKey) + return ed25519.PublicKey(pubBytes) +} + +func (n *Node) SigningPrivateKey() ed25519.PrivateKey { + privBytes, _ := hex.DecodeString(n.config.SigningPrivateKey) + return ed25519.PrivateKey(privBytes) +} diff --git a/cmd/dendrite-demo-yggdrasil/yggconn/session.go b/cmd/dendrite-demo-yggdrasil/yggconn/session.go new file mode 100644 index 000000000..c50b6b73c --- /dev/null +++ b/cmd/dendrite-demo-yggdrasil/yggconn/session.go @@ -0,0 +1,124 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package yggconn + +import ( + "context" + "net" + "strings" + "time" + + "github.com/libp2p/go-yamux" +) + +func (n *Node) yamuxConfig() *yamux.Config { + cfg := yamux.DefaultConfig() + cfg.EnableKeepAlive = false + cfg.ConnectionWriteTimeout = time.Second * 15 + cfg.MaxMessageSize = 65535 + cfg.ReadBufSize = 655350 + return cfg +} + +func (n *Node) listenFromYgg() { + for { + conn, err := n.listener.Accept() + if err != nil { + n.log.Println("n.listener.Accept:", err) + return + } + var session *yamux.Session + // If the remote address is lower than ours then we'll be the + // server. Otherwse we'll be the client. + if strings.Compare(conn.RemoteAddr().String(), n.DerivedSessionName()) < 0 { + session, err = yamux.Server(conn, n.yamuxConfig()) + } else { + session, err = yamux.Client(conn, n.yamuxConfig()) + } + if err != nil { + return + } + go n.listenFromYggConn(session) + } +} + +func (n *Node) listenFromYggConn(session *yamux.Session) { + n.sessions.Store(session.RemoteAddr().String(), session) + defer n.sessions.Delete(session.RemoteAddr()) + defer func() { + if err := session.Close(); err != nil { + n.log.Println("session.Close:", err) + } + }() + + for { + st, err := session.AcceptStream() + if err != nil { + n.log.Println("session.AcceptStream:", err) + return + } + n.incoming <- st + } +} + +// Implements net.Listener +func (n *Node) Accept() (net.Conn, error) { + return <-n.incoming, nil +} + +// Implements net.Listener +func (n *Node) Close() error { + return n.listener.Close() +} + +// Implements net.Listener +func (n *Node) Addr() net.Addr { + return n.listener.Addr() +} + +// Implements http.Transport.Dial +func (n *Node) Dial(network, address string) (net.Conn, error) { + return n.DialContext(context.TODO(), network, address) +} + +// Implements http.Transport.DialContext +func (n *Node) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + s, ok1 := n.sessions.Load(address) + session, ok2 := s.(*yamux.Session) + if !ok1 || !ok2 || (ok1 && ok2 && session.IsClosed()) { + conn, err := n.dialer.DialContext(ctx, network, address) + if err != nil { + n.log.Println("n.dialer.DialContext:", err) + return nil, err + } + // If the remote address is lower than ours then we will be the + // server. Otherwise we'll be the client. + if strings.Compare(conn.RemoteAddr().String(), n.DerivedSessionName()) < 0 { + session, err = yamux.Server(conn, n.yamuxConfig()) + } else { + session, err = yamux.Client(conn, n.yamuxConfig()) + } + if err != nil { + return nil, err + } + go n.listenFromYggConn(session) + } + st, err := session.OpenStream() + if err != nil { + n.log.Println("session.OpenStream:", err) + return nil, err + } + return st, nil +} diff --git a/cmd/dendrite-edu-server/main.go b/cmd/dendrite-edu-server/main.go index a4511f1ba..6704ebd09 100644 --- a/cmd/dendrite-edu-server/main.go +++ b/cmd/dendrite-edu-server/main.go @@ -15,22 +15,23 @@ package main import ( _ "net/http/pprof" - "github.com/matrix-org/dendrite/common/basecomponent" "github.com/matrix-org/dendrite/eduserver" "github.com/matrix-org/dendrite/eduserver/cache" + "github.com/matrix-org/dendrite/internal/setup" "github.com/sirupsen/logrus" ) func main() { - cfg := basecomponent.ParseFlags() - base := basecomponent.NewBaseDendrite(cfg, "EDUServerAPI") + cfg := setup.ParseFlags(false) + base := setup.NewBaseDendrite(cfg, "EDUServerAPI", true) defer func() { if err := base.Close(); err != nil { logrus.WithError(err).Warn("BaseDendrite close failed") } }() - eduserver.SetupEDUServerComponent(base, cache.New()) + intAPI := eduserver.NewInternalAPI(base, cache.New(), base.UserAPIClient()) + eduserver.AddInternalRoutes(base.InternalAPIMux, intAPI) base.SetupAndServeHTTP(string(base.Cfg.Bind.EDUServer), string(base.Cfg.Listen.EDUServer)) diff --git a/cmd/dendrite-federation-api-server/main.go b/cmd/dendrite-federation-api-server/main.go index dd06cd3f9..e3bf5edc8 100644 --- a/cmd/dendrite-federation-api-server/main.go +++ b/cmd/dendrite-federation-api-server/main.go @@ -15,34 +15,25 @@ package main import ( - "github.com/matrix-org/dendrite/clientapi/producers" - "github.com/matrix-org/dendrite/common/basecomponent" - "github.com/matrix-org/dendrite/common/keydb" - "github.com/matrix-org/dendrite/eduserver" - "github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/federationapi" + "github.com/matrix-org/dendrite/internal/setup" ) func main() { - cfg := basecomponent.ParseFlags() - base := basecomponent.NewBaseDendrite(cfg, "FederationAPI") + cfg := setup.ParseFlags(false) + base := setup.NewBaseDendrite(cfg, "FederationAPI", true) defer base.Close() // nolint: errcheck - accountDB := base.CreateAccountsDB() - deviceDB := base.CreateDeviceDB() - keyDB := base.CreateKeyDB() + userAPI := base.UserAPIClient() federation := base.CreateFederationClient() - federationSender := base.CreateHTTPFederationSenderAPIs() - keyRing := keydb.CreateKeyRing(federation.Client, keyDB, cfg.Matrix.KeyPerspectives) + serverKeyAPI := base.ServerKeyAPIClient() + keyRing := serverKeyAPI.KeyRing() + fsAPI := base.FederationSenderHTTPClient() + rsAPI := base.RoomserverHTTPClient() - alias, input, query := base.CreateHTTPRoomserverAPIs() - asQuery := base.CreateHTTPAppServiceAPIs() - eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New()) - eduProducer := producers.NewEDUServerProducer(eduInputAPI) - - federationapi.SetupFederationAPIComponent( - base, accountDB, deviceDB, federation, &keyRing, - alias, input, query, asQuery, federationSender, eduProducer, + federationapi.AddPublicRoutes( + base.PublicAPIMux, base.Cfg, userAPI, federation, keyRing, + rsAPI, fsAPI, base.EDUServerClient(), ) base.SetupAndServeHTTP(string(base.Cfg.Bind.FederationAPI), string(base.Cfg.Listen.FederationAPI)) diff --git a/cmd/dendrite-federation-sender-server/main.go b/cmd/dendrite-federation-sender-server/main.go index 71fc0b015..20bc1070f 100644 --- a/cmd/dendrite-federation-sender-server/main.go +++ b/cmd/dendrite-federation-sender-server/main.go @@ -15,22 +15,25 @@ package main import ( - "github.com/matrix-org/dendrite/common/basecomponent" "github.com/matrix-org/dendrite/federationsender" + "github.com/matrix-org/dendrite/internal/setup" ) func main() { - cfg := basecomponent.ParseFlags() - base := basecomponent.NewBaseDendrite(cfg, "FederationSender") + cfg := setup.ParseFlags(false) + base := setup.NewBaseDendrite(cfg, "FederationSender", true) defer base.Close() // nolint: errcheck federation := base.CreateFederationClient() - _, _, query := base.CreateHTTPRoomserverAPIs() + serverKeyAPI := base.ServerKeyAPIClient() + keyRing := serverKeyAPI.KeyRing() - federationsender.SetupFederationSenderComponent( - base, federation, query, + rsAPI := base.RoomserverHTTPClient() + fsAPI := federationsender.NewInternalAPI( + base, federation, rsAPI, keyRing, ) + federationsender.AddInternalRoutes(base.InternalAPIMux, fsAPI) base.SetupAndServeHTTP(string(base.Cfg.Bind.FederationSender), string(base.Cfg.Listen.FederationSender)) diff --git a/cmd/dendrite-key-server/main.go b/cmd/dendrite-key-server/main.go new file mode 100644 index 000000000..b557cbd9e --- /dev/null +++ b/cmd/dendrite-key-server/main.go @@ -0,0 +1,33 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "github.com/matrix-org/dendrite/internal/setup" + "github.com/matrix-org/dendrite/keyserver" +) + +func main() { + cfg := setup.ParseFlags(false) + base := setup.NewBaseDendrite(cfg, "KeyServer", true) + defer base.Close() // nolint: errcheck + + userAPI := base.UserAPIClient() + + keyserver.AddPublicRoutes(base.PublicAPIMux, base.Cfg, userAPI) + + base.SetupAndServeHTTP(string(base.Cfg.Bind.KeyServer), string(base.Cfg.Listen.KeyServer)) + +} diff --git a/cmd/dendrite-media-api-server/main.go b/cmd/dendrite-media-api-server/main.go index a818db73a..1582a33a8 100644 --- a/cmd/dendrite-media-api-server/main.go +++ b/cmd/dendrite-media-api-server/main.go @@ -15,18 +15,20 @@ package main import ( - "github.com/matrix-org/dendrite/common/basecomponent" + "github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/mediaapi" + "github.com/matrix-org/gomatrixserverlib" ) func main() { - cfg := basecomponent.ParseFlags() - base := basecomponent.NewBaseDendrite(cfg, "MediaAPI") + cfg := setup.ParseFlags(false) + base := setup.NewBaseDendrite(cfg, "MediaAPI", true) defer base.Close() // nolint: errcheck - deviceDB := base.CreateDeviceDB() + userAPI := base.UserAPIClient() + client := gomatrixserverlib.NewClient() - mediaapi.SetupMediaAPIComponent(base, deviceDB) + mediaapi.AddPublicRoutes(base.PublicAPIMux, base.Cfg, userAPI, client) base.SetupAndServeHTTP(string(base.Cfg.Bind.MediaAPI), string(base.Cfg.Listen.MediaAPI)) diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index 6b0d83ae1..339bbe699 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -17,82 +17,146 @@ package main import ( "flag" "net/http" + "os" "github.com/matrix-org/dendrite/appservice" - "github.com/matrix-org/dendrite/clientapi" - "github.com/matrix-org/dendrite/clientapi/producers" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/basecomponent" - "github.com/matrix-org/dendrite/common/keydb" - "github.com/matrix-org/dendrite/common/transactions" "github.com/matrix-org/dendrite/eduserver" "github.com/matrix-org/dendrite/eduserver/cache" - "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationsender" - "github.com/matrix-org/dendrite/mediaapi" - "github.com/matrix-org/dendrite/publicroomsapi" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/publicroomsapi/storage" "github.com/matrix-org/dendrite/roomserver" - "github.com/matrix-org/dendrite/syncapi" - "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/serverkeyapi" + "github.com/matrix-org/dendrite/userapi" + "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) var ( - httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server") - httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server") - certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS") - keyFile = flag.String("tls-key", "", "The PEM private key to use for TLS") + httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server") + httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server") + certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS") + keyFile = flag.String("tls-key", "", "The PEM private key to use for TLS") + enableHTTPAPIs = flag.Bool("api", false, "Use HTTP APIs instead of short-circuiting (warning: exposes API endpoints!)") + traceInternal = os.Getenv("DENDRITE_TRACE_INTERNAL") == "1" ) func main() { - cfg := basecomponent.ParseMonolithFlags() - base := basecomponent.NewBaseDendrite(cfg, "Monolith") + cfg := setup.ParseFlags(true) + if *enableHTTPAPIs { + // If the HTTP APIs are enabled then we need to update the Listen + // statements in the configuration so that we know where to find + // the API endpoints. They'll listen on the same port as the monolith + // itself. + addr := config.Address(*httpBindAddr) + cfg.Listen.RoomServer = addr + cfg.Listen.EDUServer = addr + cfg.Listen.AppServiceAPI = addr + cfg.Listen.FederationSender = addr + cfg.Listen.ServerKeyAPI = addr + } + + base := setup.NewBaseDendrite(cfg, "Monolith", *enableHTTPAPIs) defer base.Close() // nolint: errcheck accountDB := base.CreateAccountsDB() deviceDB := base.CreateDeviceDB() - keyDB := base.CreateKeyDB() federation := base.CreateFederationClient() - keyRing := keydb.CreateKeyRing(federation.Client, keyDB, cfg.Matrix.KeyPerspectives) - alias, input, query := roomserver.SetupRoomServerComponent(base) - eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New()) - asQuery := appservice.SetupAppServiceAPIComponent( - base, accountDB, deviceDB, federation, alias, query, transactions.New(), + serverKeyAPI := serverkeyapi.NewInternalAPI( + base.Cfg, federation, base.Caches, ) - fedSenderAPI := federationsender.SetupFederationSenderComponent(base, federation, query) + if base.UseHTTPAPIs { + serverkeyapi.AddInternalRoutes(base.InternalAPIMux, serverKeyAPI, base.Caches) + serverKeyAPI = base.ServerKeyAPIClient() + } + keyRing := serverKeyAPI.KeyRing() + userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Matrix.ServerName, cfg.Derived.ApplicationServices) - clientapi.SetupClientAPIComponent( - base, deviceDB, accountDB, - federation, &keyRing, alias, input, query, - eduInputAPI, asQuery, transactions.New(), fedSenderAPI, + rsImpl := roomserver.NewInternalAPI( + base, keyRing, federation, ) - eduProducer := producers.NewEDUServerProducer(eduInputAPI) - federationapi.SetupFederationAPIComponent(base, accountDB, deviceDB, federation, &keyRing, alias, input, query, asQuery, fedSenderAPI, eduProducer) - mediaapi.SetupMediaAPIComponent(base, deviceDB) - publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI)) + // call functions directly on the impl unless running in HTTP mode + rsAPI := rsImpl + if base.UseHTTPAPIs { + roomserver.AddInternalRoutes(base.InternalAPIMux, rsImpl) + rsAPI = base.RoomserverHTTPClient() + } + if traceInternal { + rsAPI = &api.RoomserverInternalAPITrace{ + Impl: rsAPI, + } + } + + eduInputAPI := eduserver.NewInternalAPI( + base, cache.New(), userAPI, + ) + if base.UseHTTPAPIs { + eduserver.AddInternalRoutes(base.InternalAPIMux, eduInputAPI) + eduInputAPI = base.EDUServerClient() + } + + asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) + if base.UseHTTPAPIs { + appservice.AddInternalRoutes(base.InternalAPIMux, asAPI) + asAPI = base.AppserviceHTTPClient() + } + + fsAPI := federationsender.NewInternalAPI( + base, federation, rsAPI, keyRing, + ) + if base.UseHTTPAPIs { + federationsender.AddInternalRoutes(base.InternalAPIMux, fsAPI) + fsAPI = base.FederationSenderHTTPClient() + } + // The underlying roomserver implementation needs to be able to call the fedsender. + // This is different to rsAPI which can be the http client which doesn't need this dependency + rsImpl.SetFederationSenderAPI(fsAPI) + + publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI), base.Cfg.DbProperties(), cfg.Matrix.ServerName) if err != nil { logrus.WithError(err).Panicf("failed to connect to public rooms db") } - publicroomsapi.SetupPublicRoomsAPIComponent(base, deviceDB, publicRoomsDB, query, federation, nil) - syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, query, federation, cfg) - httpHandler := common.WrapHandlerInCORS(base.APIMux) + monolith := setup.Monolith{ + Config: base.Cfg, + AccountDB: accountDB, + DeviceDB: deviceDB, + Client: gomatrixserverlib.NewClient(), + FedClient: federation, + KeyRing: keyRing, + KafkaConsumer: base.KafkaConsumer, + KafkaProducer: base.KafkaProducer, - // Set up the API endpoints we handle. /metrics is for prometheus, and is - // not wrapped by CORS, while everything else is - if cfg.Metrics.Enabled { - http.Handle("/metrics", common.WrapHandlerInBasicAuth(promhttp.Handler(), cfg.Metrics.BasicAuth)) + AppserviceAPI: asAPI, + EDUInternalAPI: eduInputAPI, + FederationSenderAPI: fsAPI, + RoomserverAPI: rsAPI, + ServerKeyAPI: serverKeyAPI, + UserAPI: userAPI, + + PublicRoomsDB: publicRoomsDB, } - http.Handle("/", httpHandler) + monolith.AddAllPublicRoutes(base.PublicAPIMux) + + httputil.SetupHTTPAPI( + base.BaseMux, + base.PublicAPIMux, + base.InternalAPIMux, + cfg, + base.UseHTTPAPIs, + ) // Expose the matrix APIs directly rather than putting them under a /api path. go func() { serv := http.Server{ Addr: *httpBindAddr, - WriteTimeout: basecomponent.HTTPServerTimeout, + WriteTimeout: setup.HTTPServerTimeout, + Handler: base.BaseMux, } logrus.Info("Listening on ", serv.Addr) @@ -103,7 +167,8 @@ func main() { go func() { serv := http.Server{ Addr: *httpsBindAddr, - WriteTimeout: basecomponent.HTTPServerTimeout, + WriteTimeout: setup.HTTPServerTimeout, + Handler: base.BaseMux, } logrus.Info("Listening on ", serv.Addr) diff --git a/cmd/dendrite-monolith-server/main_test.go b/cmd/dendrite-monolith-server/main_test.go new file mode 100644 index 000000000..efa1a926c --- /dev/null +++ b/cmd/dendrite-monolith-server/main_test.go @@ -0,0 +1,50 @@ +package main + +import ( + "os" + "os/signal" + "strings" + "syscall" + "testing" +) + +// This is an instrumented main, used when running integration tests (sytest) with code coverage. +// Compile: go test -c -race -cover -covermode=atomic -o monolith.debug -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server +// Run the monolith: ./monolith.debug -test.coverprofile=/somewhere/to/dump/integrationcover.out DEVEL --config dendrite.yaml +// Generate HTML with coverage: go tool cover -html=/somewhere/where/there/is/integrationcover.out -o cover.html +// Source: https://dzone.com/articles/measuring-integration-test-coverage-rate-in-pouchc +func TestMain(_ *testing.T) { + var ( + args []string + ) + + for _, arg := range os.Args { + switch { + case strings.HasPrefix(arg, "DEVEL"): + case strings.HasPrefix(arg, "-test"): + default: + args = append(args, arg) + } + } + // only run the tests if there are args to be passed + if len(args) <= 1 { + return + } + + waitCh := make(chan int, 1) + os.Args = args + go func() { + main() + close(waitCh) + }() + + signalCh := make(chan os.Signal, 1) + signal.Notify(signalCh, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGHUP) + + select { + case <-signalCh: + return + case <-waitCh: + return + } +} diff --git a/cmd/dendrite-public-rooms-api-server/main.go b/cmd/dendrite-public-rooms-api-server/main.go index f6a782f66..23866b757 100644 --- a/cmd/dendrite-public-rooms-api-server/main.go +++ b/cmd/dendrite-public-rooms-api-server/main.go @@ -15,25 +15,26 @@ package main import ( - "github.com/matrix-org/dendrite/common/basecomponent" + "github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/publicroomsapi" "github.com/matrix-org/dendrite/publicroomsapi/storage" "github.com/sirupsen/logrus" ) func main() { - cfg := basecomponent.ParseFlags() - base := basecomponent.NewBaseDendrite(cfg, "PublicRoomsAPI") + cfg := setup.ParseFlags(false) + base := setup.NewBaseDendrite(cfg, "PublicRoomsAPI", true) defer base.Close() // nolint: errcheck - deviceDB := base.CreateDeviceDB() + userAPI := base.UserAPIClient() - _, _, query := base.CreateHTTPRoomserverAPIs() - publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI)) + rsAPI := base.RoomserverHTTPClient() + + publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI), base.Cfg.DbProperties(), cfg.Matrix.ServerName) if err != nil { logrus.WithError(err).Panicf("failed to connect to public rooms db") } - publicroomsapi.SetupPublicRoomsAPIComponent(base, deviceDB, publicRoomsDB, query, nil, nil) + publicroomsapi.AddPublicRoutes(base.PublicAPIMux, base.Cfg, base.KafkaConsumer, userAPI, publicRoomsDB, rsAPI, nil, nil) base.SetupAndServeHTTP(string(base.Cfg.Bind.PublicRoomsAPI), string(base.Cfg.Listen.PublicRoomsAPI)) diff --git a/cmd/dendrite-room-server/main.go b/cmd/dendrite-room-server/main.go index 41b705755..627a68677 100644 --- a/cmd/dendrite-room-server/main.go +++ b/cmd/dendrite-room-server/main.go @@ -15,18 +15,23 @@ package main import ( - _ "net/http/pprof" - - "github.com/matrix-org/dendrite/common/basecomponent" + "github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/roomserver" ) func main() { - cfg := basecomponent.ParseFlags() - base := basecomponent.NewBaseDendrite(cfg, "RoomServerAPI") + cfg := setup.ParseFlags(false) + base := setup.NewBaseDendrite(cfg, "RoomServerAPI", true) defer base.Close() // nolint: errcheck + federation := base.CreateFederationClient() - roomserver.SetupRoomServerComponent(base) + serverKeyAPI := base.ServerKeyAPIClient() + keyRing := serverKeyAPI.KeyRing() + + fsAPI := base.FederationSenderHTTPClient() + rsAPI := roomserver.NewInternalAPI(base, keyRing, federation) + rsAPI.SetFederationSenderAPI(fsAPI) + roomserver.AddInternalRoutes(base.InternalAPIMux, rsAPI) base.SetupAndServeHTTP(string(base.Cfg.Bind.RoomServer), string(base.Cfg.Listen.RoomServer)) diff --git a/cmd/dendrite-server-key-api-server/main.go b/cmd/dendrite-server-key-api-server/main.go new file mode 100644 index 000000000..9ffaeee31 --- /dev/null +++ b/cmd/dendrite-server-key-api-server/main.go @@ -0,0 +1,33 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "github.com/matrix-org/dendrite/internal/setup" + "github.com/matrix-org/dendrite/serverkeyapi" +) + +func main() { + cfg := setup.ParseFlags(false) + base := setup.NewBaseDendrite(cfg, "ServerKeyAPI", true) + defer base.Close() // nolint: errcheck + + federation := base.CreateFederationClient() + + intAPI := serverkeyapi.NewInternalAPI(base.Cfg, federation, base.Caches) + serverkeyapi.AddInternalRoutes(base.InternalAPIMux, intAPI, base.Caches) + + base.SetupAndServeHTTP(string(base.Cfg.Bind.ServerKeyAPI), string(base.Cfg.Listen.ServerKeyAPI)) +} diff --git a/cmd/dendrite-sync-api-server/main.go b/cmd/dendrite-sync-api-server/main.go index 55e9faeef..d67395fb3 100644 --- a/cmd/dendrite-sync-api-server/main.go +++ b/cmd/dendrite-sync-api-server/main.go @@ -15,22 +15,21 @@ package main import ( - "github.com/matrix-org/dendrite/common/basecomponent" + "github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/syncapi" ) func main() { - cfg := basecomponent.ParseFlags() - base := basecomponent.NewBaseDendrite(cfg, "SyncAPI") + cfg := setup.ParseFlags(false) + base := setup.NewBaseDendrite(cfg, "SyncAPI", true) defer base.Close() // nolint: errcheck - deviceDB := base.CreateDeviceDB() - accountDB := base.CreateAccountsDB() + userAPI := base.UserAPIClient() federation := base.CreateFederationClient() - _, _, query := base.CreateHTTPRoomserverAPIs() + rsAPI := base.RoomserverHTTPClient() - syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, query, federation, cfg) + syncapi.AddPublicRoutes(base.PublicAPIMux, base.KafkaConsumer, userAPI, rsAPI, federation, cfg) base.SetupAndServeHTTP(string(base.Cfg.Bind.SyncAPI), string(base.Cfg.Listen.SyncAPI)) diff --git a/clientapi/auth/authtypes/device.go b/cmd/dendrite-user-api-server/main.go similarity index 50% rename from clientapi/auth/authtypes/device.go rename to cmd/dendrite-user-api-server/main.go index 299eff036..4257da3f3 100644 --- a/clientapi/auth/authtypes/device.go +++ b/cmd/dendrite-user-api-server/main.go @@ -12,19 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -package authtypes +package main -// Device represents a client's device (mobile, web, etc) -type Device struct { - ID string - UserID string - // The access_token granted to this device. - // This uniquely identifies the device from all other devices and clients. - AccessToken string - // The unique ID of the session identified by the access token. - // Can be used as a secure substitution in places where data needs to be - // associated with access tokens. - SessionID int64 - // TODO: display name, last used timestamp, keys, etc - DisplayName string +import ( + "github.com/matrix-org/dendrite/internal/setup" + "github.com/matrix-org/dendrite/userapi" +) + +func main() { + cfg := setup.ParseFlags(false) + base := setup.NewBaseDendrite(cfg, "UserAPI", true) + defer base.Close() // nolint: errcheck + + accountDB := base.CreateAccountsDB() + deviceDB := base.CreateDeviceDB() + + userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Matrix.ServerName, cfg.Derived.ApplicationServices) + + userapi.AddInternalRoutes(base.InternalAPIMux, userAPI) + + base.SetupAndServeHTTP(string(base.Cfg.Bind.UserAPI), string(base.Cfg.Listen.UserAPI)) } diff --git a/cmd/dendritejs/jsServer.go b/cmd/dendritejs/jsServer.go index a5ac574d8..074d20cba 100644 --- a/cmd/dendritejs/jsServer.go +++ b/cmd/dendritejs/jsServer.go @@ -28,7 +28,7 @@ import ( // JSServer exposes an HTTP-like server interface which allows JS to 'send' requests to it. type JSServer struct { // The router which will service requests - Mux *http.ServeMux + Mux http.Handler } // OnRequestFromJS is the function that JS will invoke when there is a new request. @@ -49,15 +49,11 @@ func (h *JSServer) OnRequestFromJS(this js.Value, args []js.Value) interface{} { // we need to put this in an immediately invoked goroutine. go func() { resolve := pargs[0] - fmt.Println("Received request:") - fmt.Printf("%s\n", httpStr) resStr, err := h.handle(httpStr) errStr := "" if err != nil { errStr = err.Error() } - fmt.Println("Sending response:") - fmt.Printf("%s\n", resStr) resolve.Invoke(map[string]interface{}{ "result": resStr, "error": errStr, diff --git a/cmd/dendritejs/keyfetcher.go b/cmd/dendritejs/keyfetcher.go index ee4905d4f..cef045372 100644 --- a/cmd/dendritejs/keyfetcher.go +++ b/cmd/dendritejs/keyfetcher.go @@ -23,7 +23,6 @@ import ( "github.com/libp2p/go-libp2p-core/peer" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" ) const libp2pMatrixKeyID = "ed25519:libp2p-dendrite" @@ -63,9 +62,7 @@ func (f *libp2pKeyFetcher) FetchKeys( if err != nil { return nil, fmt.Errorf("Failed to extract raw bytes from public key: %w", err) } - util.GetLogger(ctx).Info("libp2pKeyFetcher.FetchKeys: Using public key %v for server name %s", pubKeyBytes, req.ServerName) - - b64Key := gomatrixserverlib.Base64String(pubKeyBytes) + b64Key := gomatrixserverlib.Base64Bytes(pubKeyBytes) res[req] = gomatrixserverlib.PublicKeyLookupResult{ VerifyKey: gomatrixserverlib.VerifyKey{ Key: b64Key, @@ -82,3 +79,8 @@ func (f *libp2pKeyFetcher) FetchKeys( func (f *libp2pKeyFetcher) FetcherName() string { return "libp2pKeyFetcher" } + +// no-op function for storing keys - we don't do any work to fetch them so don't bother storing. +func (f *libp2pKeyFetcher) StoreKeys(ctx context.Context, results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) error { + return nil +} diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go index 9bd8f2ee2..883b0fad0 100644 --- a/cmd/dendritejs/main.go +++ b/cmd/dendritejs/main.go @@ -19,25 +19,20 @@ package main import ( "crypto/ed25519" "fmt" - "net/http" + "syscall/js" "github.com/matrix-org/dendrite/appservice" - "github.com/matrix-org/dendrite/clientapi" - "github.com/matrix-org/dendrite/clientapi/producers" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/basecomponent" - "github.com/matrix-org/dendrite/common/config" - "github.com/matrix-org/dendrite/common/transactions" "github.com/matrix-org/dendrite/eduserver" "github.com/matrix-org/dendrite/eduserver/cache" - "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationsender" - "github.com/matrix-org/dendrite/mediaapi" - "github.com/matrix-org/dendrite/publicroomsapi" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/publicroomsapi/storage" "github.com/matrix-org/dendrite/roomserver" - "github.com/matrix-org/dendrite/syncapi" - "github.com/matrix-org/go-http-js-libp2p/go_http_js_libp2p" + "github.com/matrix-org/dendrite/userapi" + go_http_js_libp2p "github.com/matrix-org/go-http-js-libp2p" + "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" @@ -45,15 +40,95 @@ import ( _ "github.com/matrix-org/go-sqlite3-js" ) +var GitCommit string + func init() { - fmt.Println("dendrite.js starting...") + fmt.Printf("[%s] dendrite.js starting...\n", GitCommit) +} + +const keyNameEd25519 = "_go_ed25519_key" + +func readKeyFromLocalStorage() (key ed25519.PrivateKey, err error) { + localforage := js.Global().Get("localforage") + if !localforage.Truthy() { + err = fmt.Errorf("readKeyFromLocalStorage: no localforage") + return + } + // https://localforage.github.io/localForage/ + item, ok := await(localforage.Call("getItem", keyNameEd25519)) + if !ok || !item.Truthy() { + err = fmt.Errorf("readKeyFromLocalStorage: no key in localforage") + return + } + fmt.Println("Found key in localforage") + // extract []byte and make an ed25519 key + seed := make([]byte, 32, 32) + js.CopyBytesToGo(seed, item) + + return ed25519.NewKeyFromSeed(seed), nil +} + +func writeKeyToLocalStorage(key ed25519.PrivateKey) error { + localforage := js.Global().Get("localforage") + if !localforage.Truthy() { + return fmt.Errorf("writeKeyToLocalStorage: no localforage") + } + + // make a Uint8Array from the key's seed + seed := key.Seed() + jsSeed := js.Global().Get("Uint8Array").New(len(seed)) + js.CopyBytesToJS(jsSeed, seed) + // write it + localforage.Call("setItem", keyNameEd25519, jsSeed) + return nil +} + +// taken from https://go-review.googlesource.com/c/go/+/150917 + +// await waits until the promise v has been resolved or rejected and returns the promise's result value. +// The boolean value ok is true if the promise has been resolved, false if it has been rejected. +// If v is not a promise, v itself is returned as the value and ok is true. +func await(v js.Value) (result js.Value, ok bool) { + if v.Type() != js.TypeObject || v.Get("then").Type() != js.TypeFunction { + return v, true + } + done := make(chan struct{}) + onResolve := js.FuncOf(func(this js.Value, args []js.Value) interface{} { + result = args[0] + ok = true + close(done) + return nil + }) + defer onResolve.Release() + onReject := js.FuncOf(func(this js.Value, args []js.Value) interface{} { + result = args[0] + ok = false + close(done) + return nil + }) + defer onReject.Release() + v.Call("then", onResolve, onReject) + <-done + return } func generateKey() ed25519.PrivateKey { - _, priv, err := ed25519.GenerateKey(nil) + // attempt to look for a seed in JS-land and if it exists use it. + priv, err := readKeyFromLocalStorage() + if err == nil { + fmt.Println("Read key from localStorage") + return priv + } + // generate a new key + fmt.Println(err, " : Generating new ed25519 key") + _, priv, err = ed25519.GenerateKey(nil) if err != nil { logrus.Fatalf("Failed to generate ed25519 key: %s", err) } + if err := writeKeyToLocalStorage(priv); err != nil { + fmt.Println("failed to write key to localStorage: ", err) + // non-fatal, we'll just have amnesia for a while + } return priv } @@ -70,9 +145,14 @@ func createFederationClient(cfg *config.Dendrite, node *go_http_js_libp2p.P2pLoc return fed } +func createClient(node *go_http_js_libp2p.P2pLocalNode) *gomatrixserverlib.Client { + tr := go_http_js_libp2p.NewP2pTransport(node) + return gomatrixserverlib.NewClientWithTransport(tr) +} + func createP2PNode(privKey ed25519.PrivateKey) (serverName string, node *go_http_js_libp2p.P2pLocalNode) { hosted := "/dns4/rendezvous.matrix.org/tcp/8443/wss/p2p-websocket-star/" - node = go_http_js_libp2p.NewP2pLocalNode("org.matrix.p2p.experiment", privKey.Seed(), []string{hosted}) + node = go_http_js_libp2p.NewP2pLocalNode("org.matrix.p2p.experiment", privKey.Seed(), []string{hosted}, "p2p") serverName = node.Id fmt.Println("p2p assigned ServerName: ", serverName) return @@ -82,18 +162,18 @@ func main() { cfg := &config.Dendrite{} cfg.SetDefaults() cfg.Kafka.UseNaffka = true - cfg.Database.Account = "file:dendritejs_account.db" - cfg.Database.AppService = "file:dendritejs_appservice.db" - cfg.Database.Device = "file:dendritejs_device.db" - cfg.Database.FederationSender = "file:dendritejs_fedsender.db" - cfg.Database.MediaAPI = "file:dendritejs_mediaapi.db" - cfg.Database.Naffka = "file:dendritejs_naffka.db" - cfg.Database.PublicRoomsAPI = "file:dendritejs_publicrooms.db" - cfg.Database.RoomServer = "file:dendritejs_roomserver.db" - cfg.Database.ServerKey = "file:dendritejs_serverkey.db" - cfg.Database.SyncAPI = "file:dendritejs_syncapi.db" - cfg.Kafka.Topics.UserUpdates = "user_updates" + cfg.Database.Account = "file:/idb/dendritejs_account.db" + cfg.Database.AppService = "file:/idb/dendritejs_appservice.db" + cfg.Database.Device = "file:/idb/dendritejs_device.db" + cfg.Database.FederationSender = "file:/idb/dendritejs_fedsender.db" + cfg.Database.MediaAPI = "file:/idb/dendritejs_mediaapi.db" + cfg.Database.Naffka = "file:/idb/dendritejs_naffka.db" + cfg.Database.PublicRoomsAPI = "file:/idb/dendritejs_publicrooms.db" + cfg.Database.RoomServer = "file:/idb/dendritejs_roomserver.db" + cfg.Database.ServerKey = "file:/idb/dendritejs_serverkey.db" + cfg.Database.SyncAPI = "file:/idb/dendritejs_syncapi.db" cfg.Kafka.Topics.OutputTypingEvent = "output_typing_event" + cfg.Kafka.Topics.OutputSendToDeviceEvent = "output_send_to_device_event" cfg.Kafka.Topics.OutputClientData = "output_client_data" cfg.Kafka.Topics.OutputRoomEvent = "output_room_event" cfg.Matrix.TrustedIDServers = []string{ @@ -108,53 +188,72 @@ func main() { if err := cfg.Derive(); err != nil { logrus.Fatalf("Failed to derive values from config: %s", err) } - base := basecomponent.NewBaseDendrite(cfg, "Monolith") + base := setup.NewBaseDendrite(cfg, "Monolith", false) defer base.Close() // nolint: errcheck accountDB := base.CreateAccountsDB() deviceDB := base.CreateDeviceDB() - keyDB := base.CreateKeyDB() federation := createFederationClient(cfg, node) + userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Matrix.ServerName, nil) + + fetcher := &libp2pKeyFetcher{} keyRing := gomatrixserverlib.KeyRing{ KeyFetchers: []gomatrixserverlib.KeyFetcher{ - &libp2pKeyFetcher{}, + fetcher, }, - KeyDatabase: keyDB, + KeyDatabase: fetcher, } - p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node) - alias, input, query := roomserver.SetupRoomServerComponent(base) - eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New()) - asQuery := appservice.SetupAppServiceAPIComponent( - base, accountDB, deviceDB, federation, alias, query, transactions.New(), + rsAPI := roomserver.NewInternalAPI(base, keyRing, federation) + eduInputAPI := eduserver.NewInternalAPI(base, cache.New(), userAPI) + asQuery := appservice.NewInternalAPI( + base, userAPI, rsAPI, ) - fedSenderAPI := federationsender.SetupFederationSenderComponent(base, federation, query) + fedSenderAPI := federationsender.NewInternalAPI(base, federation, rsAPI, &keyRing) + rsAPI.SetFederationSenderAPI(fedSenderAPI) + p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node, fedSenderAPI) - clientapi.SetupClientAPIComponent( - base, deviceDB, accountDB, - federation, &keyRing, alias, input, query, - eduInputAPI, asQuery, transactions.New(), fedSenderAPI, - ) - eduProducer := producers.NewEDUServerProducer(eduInputAPI) - federationapi.SetupFederationAPIComponent(base, accountDB, deviceDB, federation, &keyRing, alias, input, query, asQuery, fedSenderAPI, eduProducer) - mediaapi.SetupMediaAPIComponent(base, deviceDB) - publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI)) + publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI), cfg.Matrix.ServerName) if err != nil { logrus.WithError(err).Panicf("failed to connect to public rooms db") } - publicroomsapi.SetupPublicRoomsAPIComponent(base, deviceDB, publicRoomsDB, query, federation, p2pPublicRoomProvider) - syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, query, federation, cfg) - httpHandler := common.WrapHandlerInCORS(base.APIMux) + monolith := setup.Monolith{ + Config: base.Cfg, + AccountDB: accountDB, + DeviceDB: deviceDB, + Client: createClient(node), + FedClient: federation, + KeyRing: &keyRing, + KafkaConsumer: base.KafkaConsumer, + KafkaProducer: base.KafkaProducer, - http.Handle("/", httpHandler) + AppserviceAPI: asQuery, + EDUInternalAPI: eduInputAPI, + FederationSenderAPI: fedSenderAPI, + RoomserverAPI: rsAPI, + UserAPI: userAPI, + //ServerKeyAPI: serverKeyAPI, + + PublicRoomsDB: publicRoomsDB, + ExtPublicRoomsProvider: p2pPublicRoomProvider, + } + monolith.AddAllPublicRoutes(base.PublicAPIMux) + + httputil.SetupHTTPAPI( + base.BaseMux, + base.PublicAPIMux, + base.InternalAPIMux, + cfg, + base.UseHTTPAPIs, + ) // Expose the matrix APIs via libp2p-js - for federation traffic if node != nil { go func() { logrus.Info("Listening on libp2p-js host ID ", node.Id) s := JSServer{ - Mux: http.DefaultServeMux, + Mux: base.BaseMux, } s.ListenAndServe("p2p") }() @@ -164,7 +263,7 @@ func main() { go func() { logrus.Info("Listening for service-worker fetch traffic") s := JSServer{ - Mux: http.DefaultServeMux, + Mux: base.BaseMux, } s.ListenAndServe("fetch") }() diff --git a/cmd/dendritejs/publicrooms.go b/cmd/dendritejs/publicrooms.go index 17822e7ad..5032bc15f 100644 --- a/cmd/dendritejs/publicrooms.go +++ b/cmd/dendritejs/publicrooms.go @@ -17,23 +17,48 @@ package main import ( - "github.com/matrix-org/go-http-js-libp2p/go_http_js_libp2p" + "context" + + "github.com/matrix-org/dendrite/federationsender/api" + go_http_js_libp2p "github.com/matrix-org/go-http-js-libp2p" + "github.com/matrix-org/gomatrixserverlib" ) type libp2pPublicRoomsProvider struct { node *go_http_js_libp2p.P2pLocalNode providers []go_http_js_libp2p.PeerInfo + fedSender api.FederationSenderInternalAPI } -func NewLibP2PPublicRoomsProvider(node *go_http_js_libp2p.P2pLocalNode) *libp2pPublicRoomsProvider { +func NewLibP2PPublicRoomsProvider(node *go_http_js_libp2p.P2pLocalNode, fedSender api.FederationSenderInternalAPI) *libp2pPublicRoomsProvider { p := &libp2pPublicRoomsProvider{ - node: node, + node: node, + fedSender: fedSender, } node.RegisterFoundProviders(p.foundProviders) return p } func (p *libp2pPublicRoomsProvider) foundProviders(peerInfos []go_http_js_libp2p.PeerInfo) { + // work out the diff then poke for new ones + seen := make(map[string]bool, len(p.providers)) + for _, pr := range p.providers { + seen[pr.Id] = true + } + var newPeers []gomatrixserverlib.ServerName + for _, pi := range peerInfos { + if !seen[pi.Id] { + newPeers = append(newPeers, gomatrixserverlib.ServerName(pi.Id)) + } + } + if len(newPeers) > 0 { + var res api.PerformServersAliveResponse + // ignore errors, we don't care. + p.fedSender.PerformServersAlive(context.Background(), &api.PerformServersAliveRequest{ + Servers: newPeers, + }, &res) + } + p.providers = peerInfos } diff --git a/cmd/federation-api-proxy/main.go b/cmd/federation-api-proxy/main.go index fa90482d5..7324de148 100644 --- a/cmd/federation-api-proxy/main.go +++ b/cmd/federation-api-proxy/main.go @@ -73,7 +73,6 @@ func makeProxy(targetURL string) (*httputil.ReverseProxy, error) { // Pratically this means that any distinction between '%2F' and '/' // in the URL will be lost by the time it reaches the target. path := req.URL.Path - path = "api" + path log.WithFields(log.Fields{ "path": path, "url": targetURL, diff --git a/cmd/generate-keys/main.go b/cmd/generate-keys/main.go index b807c2673..ceeca5672 100644 --- a/cmd/generate-keys/main.go +++ b/cmd/generate-keys/main.go @@ -20,7 +20,7 @@ import ( "log" "os" - "github.com/matrix-org/dendrite/common/test" + "github.com/matrix-org/dendrite/internal/test" ) const usage = `Usage: %s diff --git a/cmd/kafka-producer/main.go b/cmd/kafka-producer/main.go index f5f243e4e..18ee3cdf2 100644 --- a/cmd/kafka-producer/main.go +++ b/cmd/kafka-producer/main.go @@ -21,7 +21,7 @@ import ( "os" "strings" - sarama "gopkg.in/Shopify/sarama.v1" + sarama "github.com/Shopify/sarama" ) const usage = `Usage: %s diff --git a/cmd/mediaapi-integration-tests/main.go b/cmd/mediaapi-integration-tests/main.go index fd6982fd3..e6ce14d28 100644 --- a/cmd/mediaapi-integration-tests/main.go +++ b/cmd/mediaapi-integration-tests/main.go @@ -24,7 +24,7 @@ import ( "path/filepath" "time" - "github.com/matrix-org/dendrite/common/test" + "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/gomatrixserverlib" "gopkg.in/yaml.v2" ) diff --git a/cmd/roomserver-integration-tests/main.go b/cmd/roomserver-integration-tests/main.go index 682fc6224..3860ca1f7 100644 --- a/cmd/roomserver-integration-tests/main.go +++ b/cmd/roomserver-integration-tests/main.go @@ -28,9 +28,10 @@ import ( "net/http" - "github.com/matrix-org/dendrite/common/caching" - "github.com/matrix-org/dendrite/common/test" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/inthttp" "github.com/matrix-org/gomatrixserverlib" ) @@ -209,7 +210,7 @@ func writeToRoomServer(input []string, roomserverURL string) error { return err } } - x, err := api.NewRoomserverInputAPIHTTP(roomserverURL, &http.Client{Timeout: timeoutHTTP}) + x, err := inthttp.NewRoomserverClient(roomserverURL, &http.Client{Timeout: timeoutHTTP}, nil) if err != nil { return err } @@ -225,7 +226,7 @@ func writeToRoomServer(input []string, roomserverURL string) error { // Once those messages have been written it runs the checkQueries function passing // a api.RoomserverQueryAPI client. The caller can use this function to check the // behaviour of the query API. -func testRoomserver(input []string, wantOutput []string, checkQueries func(api.RoomserverQueryAPI)) { +func testRoomserver(input []string, wantOutput []string, checkQueries func(api.RoomserverInternalAPI)) { dir, err := ioutil.TempDir("", "room-server-test") if err != nil { panic(err) @@ -254,7 +255,7 @@ func testRoomserver(input []string, wantOutput []string, checkQueries func(api.R panic(err) } - cache, err := caching.NewImmutableInMemoryLRUCache() + cache, err := caching.NewInMemoryLRUCache(false) if err != nil { panic(err) } @@ -276,7 +277,7 @@ func testRoomserver(input []string, wantOutput []string, checkQueries func(api.R cmd.Args = []string{"dendrite-room-server", "--config", filepath.Join(dir, test.ConfigFile)} gotOutput, err := runAndReadFromTopic(cmd, cfg.RoomServerURL()+"/metrics", doInput, outputTopic, len(wantOutput), func() { - queryAPI, _ := api.NewRoomserverQueryAPIHTTP("http://"+string(cfg.Listen.RoomServer), &http.Client{Timeout: timeoutHTTP}, cache) + queryAPI, _ := inthttp.NewRoomserverClient("http://"+string(cfg.Listen.RoomServer), &http.Client{Timeout: timeoutHTTP}, cache) checkQueries(queryAPI) }) if err != nil { @@ -410,7 +411,7 @@ func main() { }}`, } - testRoomserver(input, want, func(q api.RoomserverQueryAPI) { + testRoomserver(input, want, func(q api.RoomserverInternalAPI) { var response api.QueryLatestEventsAndStateResponse if err := q.QueryLatestEventsAndState( context.Background(), diff --git a/cmd/syncserver-integration-tests/main.go b/cmd/syncserver-integration-tests/main.go index d14e854c3..cfe8cc165 100644 --- a/cmd/syncserver-integration-tests/main.go +++ b/cmd/syncserver-integration-tests/main.go @@ -24,8 +24,8 @@ import ( "path/filepath" "time" - "github.com/matrix-org/dendrite/common/config" - "github.com/matrix-org/dendrite/common/test" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/common/caching/immutablecache.go b/common/caching/immutablecache.go deleted file mode 100644 index 9620667a2..000000000 --- a/common/caching/immutablecache.go +++ /dev/null @@ -1,12 +0,0 @@ -package caching - -import "github.com/matrix-org/gomatrixserverlib" - -const ( - RoomVersionMaxCacheEntries = 128 -) - -type ImmutableCache interface { - GetRoomVersion(roomId string) (gomatrixserverlib.RoomVersion, bool) - StoreRoomVersion(roomId string, roomVersion gomatrixserverlib.RoomVersion) -} diff --git a/common/caching/immutableinmemorylru.go b/common/caching/immutableinmemorylru.go deleted file mode 100644 index 3e8f4aadb..000000000 --- a/common/caching/immutableinmemorylru.go +++ /dev/null @@ -1,43 +0,0 @@ -package caching - -import ( - "fmt" - - lru "github.com/hashicorp/golang-lru" - "github.com/matrix-org/gomatrixserverlib" -) - -type ImmutableInMemoryLRUCache struct { - roomVersions *lru.Cache -} - -func NewImmutableInMemoryLRUCache() (*ImmutableInMemoryLRUCache, error) { - roomVersionCache, rvErr := lru.New(RoomVersionMaxCacheEntries) - if rvErr != nil { - return nil, rvErr - } - return &ImmutableInMemoryLRUCache{ - roomVersions: roomVersionCache, - }, nil -} - -func checkForInvalidMutation(cache *lru.Cache, key string, value interface{}) { - if peek, ok := cache.Peek(key); ok && peek != value { - panic(fmt.Sprintf("invalid use of immutable cache tries to mutate existing value of %q", key)) - } -} - -func (c *ImmutableInMemoryLRUCache) GetRoomVersion(roomID string) (gomatrixserverlib.RoomVersion, bool) { - val, found := c.roomVersions.Get(roomID) - if found && val != nil { - if roomVersion, ok := val.(gomatrixserverlib.RoomVersion); ok { - return roomVersion, true - } - } - return "", false -} - -func (c *ImmutableInMemoryLRUCache) StoreRoomVersion(roomID string, roomVersion gomatrixserverlib.RoomVersion) { - checkForInvalidMutation(c.roomVersions, roomID, roomVersion) - c.roomVersions.Add(roomID, roomVersion) -} diff --git a/common/keydb/keyring.go b/common/keydb/keyring.go deleted file mode 100644 index e9cc7903e..000000000 --- a/common/keydb/keyring.go +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2017 New Vector Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package keydb - -import ( - "encoding/base64" - - "github.com/matrix-org/dendrite/common/config" - "github.com/matrix-org/gomatrixserverlib" - "github.com/sirupsen/logrus" - "golang.org/x/crypto/ed25519" -) - -// CreateKeyRing creates and configures a KeyRing object. -// -// It creates the necessary key fetchers and collects them into a KeyRing -// backed by the given KeyDatabase. -func CreateKeyRing(client gomatrixserverlib.Client, - keyDB gomatrixserverlib.KeyDatabase, - cfg config.KeyPerspectives) gomatrixserverlib.KeyRing { - - fetchers := gomatrixserverlib.KeyRing{ - KeyFetchers: []gomatrixserverlib.KeyFetcher{ - &gomatrixserverlib.DirectKeyFetcher{ - Client: client, - }, - }, - KeyDatabase: keyDB, - } - - logrus.Info("Enabled direct key fetcher") - - var b64e = base64.StdEncoding.WithPadding(base64.NoPadding) - for _, ps := range cfg { - perspective := &gomatrixserverlib.PerspectiveKeyFetcher{ - PerspectiveServerName: ps.ServerName, - PerspectiveServerKeys: map[gomatrixserverlib.KeyID]ed25519.PublicKey{}, - Client: client, - } - - for _, key := range ps.Keys { - rawkey, err := b64e.DecodeString(key.PublicKey) - if err != nil { - logrus.WithError(err).WithFields(logrus.Fields{ - "server_name": ps.ServerName, - "public_key": key.PublicKey, - }).Warn("Couldn't parse perspective key") - continue - } - perspective.PerspectiveServerKeys[key.KeyID] = rawkey - } - - fetchers.KeyFetchers = append(fetchers.KeyFetchers, perspective) - - logrus.WithFields(logrus.Fields{ - "server_name": ps.ServerName, - "num_public_keys": len(ps.Keys), - }).Info("Enabled perspective key fetcher") - } - - return fetchers -} diff --git a/dendrite-config.yaml b/dendrite-config.yaml index bed78a5af..73bfec247 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -25,6 +25,8 @@ matrix: # public_key: Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw # - key_id: ed25519:a_RXGa # public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ + # Disables new users from registering (except via shared secrets) + registration_disabled: false # The media repository config media: @@ -102,10 +104,13 @@ kafka: topics: output_room_event: roomserverOutput output_client_data: clientapiOutput - output_typing_event: eduServerOutput + output_typing_event: eduServerTypingOutput + output_send_to_device_event: eduServerSendToDeviceOutput user_updates: userUpdates -# The postgres connection configs for connecting to the databases e.g a postgres:// URI +# The postgres connection configs for connecting to the databases, e.g. +# for Postgres: postgres://username:password@hostname/database +# for SQLite: file:filename.db or file:///path/to/filename.db database: account: "postgres://dendrite:itsasecret@localhost/dendrite_account?sslmode=disable" device: "postgres://dendrite:itsasecret@localhost/dendrite_device?sslmode=disable" @@ -116,7 +121,10 @@ database: federation_sender: "postgres://dendrite:itsasecret@localhost/dendrite_federationsender?sslmode=disable" appservice: "postgres://dendrite:itsasecret@localhost/dendrite_appservice?sslmode=disable" public_rooms_api: "postgres://dendrite:itsasecret@localhost/dendrite_publicroomsapi?sslmode=disable" - # If using naffka you need to specify a naffka database + max_open_conns: 100 + max_idle_conns: 2 + conn_max_lifetime: -1 + # If 'use_naffka: true' set above then you need to specify a naffka database # naffka: "postgres://dendrite:itsasecret@localhost/dendrite_naffka?sslmode=disable" # The TCP host:port pairs to bind the internal HTTP APIs to. @@ -132,6 +140,9 @@ listen: federation_sender: "localhost:7776" appservice_api: "localhost:7777" edu_server: "localhost:7778" + key_server: "localhost:7779" + server_key_api: "localhost:7780" + user_api: "localhost:7781" # The configuration for tracing the dendrite components. tracing: diff --git a/docker/Dockerfile b/docker/Dockerfile deleted file mode 100644 index 5810825a4..000000000 --- a/docker/Dockerfile +++ /dev/null @@ -1,9 +0,0 @@ -FROM docker.io/golang:1.13.7-alpine3.11 - -RUN mkdir /build - -WORKDIR /build - -RUN apk --update --no-cache add openssl bash git build-base - -CMD ["bash", "docker/build.sh"] diff --git a/docker/README.md b/docker/README.md deleted file mode 100644 index 83d0b6a87..000000000 --- a/docker/README.md +++ /dev/null @@ -1,100 +0,0 @@ -Development with Docker ---- - -With `docker` and `docker-compose` you can easily spin up a development environment -and start working on dendrite. - -### Requirements - -- docker -- docker-compose (version 3+) - -### Configuration - -Create a directory named `cfg` in the root of the project. Copy the -`dendrite-docker.yaml` file into that directory and rename it to `dendrite.yaml`. -It already contains the defaults used in `docker-compose` for networking so you will -only have to change things like the `server_name` or to toggle `naffka`. - -You can run the following `docker-compose` commands either from the top directory -specifying the `docker-compose` file -``` -docker-compose -f docker/docker-compose.yml -``` -or from within the `docker` directory - -``` -docker-compose -``` - -### Starting a monolith server - -For the monolith server you would need a postgres instance - -``` -docker-compose up postgres -``` - -and the dendrite component from `bin/dendrite-monolith-server` - -``` -docker-compose up monolith -``` - -The monolith will be listening on `http://localhost:8008`. - -You would also have to make the following adjustments to `dendrite.yaml`. - - Set `use_naffka: true` - - Uncomment the `database/naffka` postgres url. - -### Starting a multiprocess server - -The multiprocess server requires kafka, zookeeper and postgres - -``` -docker-compose up kafka zookeeper postgres -``` - -and the following dendrite components - -``` -docker-compose up client_api media_api sync_api room_server public_rooms_api edu_server -docker-compose up client_api_proxy -``` - -The `client-api-proxy` will be listening on `http://localhost:8008`. - -You would also have to make the following adjustments to `dendrite.yaml`. - - Set `use_naffka: false` - - Comment out the `database/naffka` postgres url. - -### Starting federation - -``` -docker-compose up federation_api federation_sender -docker-compose up federation_api_proxy -``` - -You can point other Matrix servers to `http://localhost:8448`. - -### Creating a new component - -You can create a new dendrite component by adding an entry to the `docker-compose.yml` -file and creating a startup script for the component in `docker/services`. -For more information refer to the official docker-compose [documentation](https://docs.docker.com/compose/). - -```yaml - new_component: - container_name: dendrite_room_server - hostname: new_component - # Start up script. - entrypoint: ["bash", "./docker/services/new-component.sh"] - # Use the common Dockerfile for all the dendrite components. - build: ./ - volumes: - - ..:/build - depends_on: - - another_component - networks: - - internal -``` diff --git a/docker/build.sh b/docker/build.sh deleted file mode 100644 index a3e3ca24d..000000000 --- a/docker/build.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash - -./build.sh - -# Generate the keys if they don't already exist. -if [ ! -f server.key ] || [ ! -f server.crt ] || [ ! -f matrix_key.pem ]; then - echo "Generating keys ..." - - rm -f server.key server.crt matrix_key.pem - - test -f server.key || openssl req -x509 -newkey rsa:4096 \ - -keyout server.key \ - -out server.crt \ - -days 3650 -nodes \ - -subj /CN=localhost - - test -f matrix_key.pem || /build/bin/generate-keys -private-key matrix_key.pem -fi diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml deleted file mode 100644 index 957c3bf3f..000000000 --- a/docker/docker-compose.yml +++ /dev/null @@ -1,192 +0,0 @@ -version: "3.4" -services: - riot: - image: vectorim/riot-web - networks: - - internal - ports: - - "8500:80" - - monolith: - container_name: dendrite_monolith - hostname: monolith - entrypoint: ["bash", "./docker/services/monolith.sh", "--config", "/etc/dendrite/dendrite.yaml"] - build: ./ - volumes: - - ..:/build - - ./build/bin:/build/bin - - ../cfg:/etc/dendrite - networks: - - internal - depends_on: - - postgres - ports: - - "8008:8008" - - "8448:8448" - - client_api_proxy: - container_name: dendrite_client_api_proxy - hostname: client_api_proxy - entrypoint: ["bash", "./docker/services/client-api-proxy.sh"] - build: ./ - volumes: - - ..:/build - networks: - - internal - depends_on: - - postgres - - sync_api - - client_api - - media_api - - public_rooms_api - ports: - - "8008:8008" - - client_api: - container_name: dendrite_client_api - hostname: client_api - entrypoint: ["bash", "./docker/services/client-api.sh"] - build: ./ - volumes: - - ..:/build - depends_on: - - postgres - - room_server - networks: - - internal - - media_api: - container_name: dendrite_media_api - hostname: media_api - entrypoint: ["bash", "./docker/services/media-api.sh"] - build: ./ - volumes: - - ..:/build - depends_on: - - postgres - networks: - - internal - - public_rooms_api: - container_name: dendrite_public_rooms_api - hostname: public_rooms_api - entrypoint: ["bash", "./docker/services/public-rooms-api.sh"] - build: ./ - volumes: - - ..:/build - depends_on: - - postgres - networks: - - internal - - sync_api: - container_name: dendrite_sync_api - hostname: sync_api - entrypoint: ["bash", "./docker/services/sync-api.sh"] - build: ./ - volumes: - - ..:/build - depends_on: - - postgres - networks: - - internal - - room_server: - container_name: dendrite_room_server - hostname: room_server - entrypoint: ["bash", "./docker/services/room-server.sh"] - build: ./ - volumes: - - ..:/build - depends_on: - - postgres - networks: - - internal - - edu_server: - container_name: dendrite_edu_server - hostname: edu_server - entrypoint: ["bash", "./docker/services/edu-server.sh"] - build: ./ - volumes: - - ..:/build - networks: - - internal - - federation_api_proxy: - container_name: dendrite_federation_api_proxy - hostname: federation_api_proxy - entrypoint: ["bash", "./docker/services/federation-api-proxy.sh"] - build: ./ - volumes: - - ..:/build - depends_on: - - postgres - - federation_api - - federation_sender - - media_api - networks: - - internal - ports: - - "8448:8448" - - federation_api: - container_name: dendrite_federation_api - hostname: federation_api - entrypoint: ["bash", "./docker/services/federation-api.sh"] - build: ./ - volumes: - - ..:/build - depends_on: - - postgres - networks: - - internal - - federation_sender: - container_name: dendrite_federation_sender - hostname: federation_sender - entrypoint: ["bash", "./docker/services/federation-sender.sh"] - build: ./ - volumes: - - ..:/build - depends_on: - - postgres - networks: - - internal - - postgres: - container_name: dendrite_postgres - hostname: postgres - image: postgres:9.5 - restart: always - volumes: - - ./postgres/create_db.sh:/docker-entrypoint-initdb.d/20-create_db.sh - environment: - POSTGRES_PASSWORD: itsasecret - POSTGRES_USER: dendrite - networks: - - internal - - zookeeper: - container_name: dendrite_zk - hostname: zookeeper - image: zookeeper - networks: - - internal - - kafka: - container_name: dendrite_kafka - hostname: kafka - image: wurstmeister/kafka - environment: - KAFKA_ADVERTISED_HOST_NAME: "kafka" - KAFKA_DELETE_TOPIC_ENABLE: "true" - KAFKA_ZOOKEEPER_CONNECT: "zookeeper:2181" - depends_on: - - zookeeper - networks: - - internal - -networks: - internal: - attachable: true diff --git a/docker/services/client-api-proxy.sh b/docker/services/client-api-proxy.sh deleted file mode 100644 index 931f7abbc..000000000 --- a/docker/services/client-api-proxy.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash - -bash ./docker/build.sh - -./bin/client-api-proxy --bind-address ":8008" \ - --client-api-server-url "http://client_api:7771" \ - --sync-api-server-url "http://sync_api:7773" \ - --media-api-server-url "http://media_api:7774" \ - --public-rooms-api-server-url "http://public_rooms_api:7775" \ diff --git a/docker/services/client-api.sh b/docker/services/client-api.sh deleted file mode 100644 index 8dc822421..000000000 --- a/docker/services/client-api.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash - -bash ./docker/build.sh - -./bin/dendrite-client-api-server --config=dendrite.yaml diff --git a/docker/services/edu-server.sh b/docker/services/edu-server.sh deleted file mode 100644 index d40b9fa7e..000000000 --- a/docker/services/edu-server.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash - -bash ./docker/build.sh - -./bin/dendrite-edu-server --config=dendrite.yaml diff --git a/docker/services/federation-api-proxy.sh b/docker/services/federation-api-proxy.sh deleted file mode 100644 index 6ea75c95a..000000000 --- a/docker/services/federation-api-proxy.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash - -bash ./docker/build.sh - -./bin/federation-api-proxy --bind-address ":8448" \ - --federation-api-url "http://federation_api_server:7772" \ - --media-api-server-url "http://media_api:7774" \ diff --git a/docker/services/federation-api.sh b/docker/services/federation-api.sh deleted file mode 100644 index 807a7cf83..000000000 --- a/docker/services/federation-api.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash - -bash ./docker/build.sh - -./bin/dendrite-federation-api-server --config dendrite.yaml diff --git a/docker/services/federation-sender.sh b/docker/services/federation-sender.sh deleted file mode 100644 index ea116ef3c..000000000 --- a/docker/services/federation-sender.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash - -bash ./docker/build.sh - -./bin/dendrite-federation-sender-server --config dendrite.yaml diff --git a/docker/services/media-api.sh b/docker/services/media-api.sh deleted file mode 100644 index 876b3aa8d..000000000 --- a/docker/services/media-api.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash - -bash ./docker/build.sh - -./bin/dendrite-media-api-server --config dendrite.yaml diff --git a/docker/services/monolith.sh b/docker/services/monolith.sh deleted file mode 100644 index 2287555cd..000000000 --- a/docker/services/monolith.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash - -bash ./docker/build.sh - -./bin/dendrite-monolith-server --tls-cert=server.crt --tls-key=server.key $@ diff --git a/docker/services/public-rooms-api.sh b/docker/services/public-rooms-api.sh deleted file mode 100644 index 652afcfec..000000000 --- a/docker/services/public-rooms-api.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash - -bash ./docker/build.sh - -./bin/dendrite-public-rooms-api-server --config dendrite.yaml diff --git a/docker/services/room-server.sh b/docker/services/room-server.sh deleted file mode 100644 index 473b5f5d3..000000000 --- a/docker/services/room-server.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash - -bash ./docker/build.sh - -./bin/dendrite-room-server --config=dendrite.yaml diff --git a/docker/services/sync-api.sh b/docker/services/sync-api.sh deleted file mode 100644 index ac6433fa5..000000000 --- a/docker/services/sync-api.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash - -bash ./docker/build.sh - -./bin/dendrite-sync-api-server --config=dendrite.yaml diff --git a/CODE_STYLE.md b/docs/CODE_STYLE.md similarity index 100% rename from CODE_STYLE.md rename to docs/CODE_STYLE.md diff --git a/CONTRIBUTING.md b/docs/CONTRIBUTING.md similarity index 100% rename from CONTRIBUTING.md rename to docs/CONTRIBUTING.md diff --git a/DESIGN.md b/docs/DESIGN.md similarity index 100% rename from DESIGN.md rename to docs/DESIGN.md diff --git a/INSTALL.md b/docs/INSTALL.md similarity index 63% rename from INSTALL.md rename to docs/INSTALL.md index 4173e705e..b4c81a42b 100644 --- a/INSTALL.md +++ b/docs/INSTALL.md @@ -2,38 +2,66 @@ Dendrite can be run in one of two configurations: - * A cluster of individual components, dealing with different aspects of the - Matrix protocol (see [WIRING.md](./WIRING.md)). Components communicate with - one another via [Apache Kafka](https://kafka.apache.org). +* **Polylith mode**: A cluster of individual components, dealing with different + aspects of the Matrix protocol (see [WIRING.md](WIRING.md)). Components communicate with each other using internal HTTP APIs and [Apache Kafka](https://kafka.apache.org). This will almost certainly be the preferred model + for large-scale deployments. - * A monolith server, in which all components run in the same process. In this - configuration, Kafka can be replaced with an in-process implementation - called [naffka](https://github.com/matrix-org/naffka). +* **Monolith mode**: All components run in the same process. In this mode, + Kafka is completely optional and can instead be replaced with an in-process + lightweight implementation called [Naffka](https://github.com/matrix-org/naffka). This will usually be the preferred model for low-volume, low-user + or experimental deployments. + +Regardless of whether you are running in polylith or monolith mode, each Dendrite component that requires storage has its own database. Both Postgres +and SQLite are supported and can be mixed-and-matched across components as +needed in the configuration file. + +Be advised that Dendrite is still developmental and it's not recommended for +use in production environments yet! ## Requirements - - Go 1.13+ - - Postgres 9.5+ - - For Kafka (optional if using the monolith server): - - Unix-based system (https://kafka.apache.org/documentation/#os) - - JDK 1.8+ / OpenJDK 1.8+ - - Apache Kafka 0.10.2+ (see [scripts/install-local-kafka.sh](scripts/install-local-kafka.sh) for up-to-date version numbers) +Dendrite requires: +* Go 1.13 or higher +* Postgres 9.5 or higher (if using Postgres databases, not needed for SQLite) -## Setting up a development environment +If you want to run a polylith deployment, you also need: -Assumes Go 1.13+ and JDK 1.8+ are already installed and are on PATH. +* Apache Kafka 0.10.2+ + +## Building up a monolith deploment + +Start by cloning the code: ```bash -# Get the code git clone https://github.com/matrix-org/dendrite cd dendrite +``` -# Build it +Then build it: + +```bash +go build -o bin/dendrite-monolith-server ./cmd/dendrite-monolith-server +go build -o bin/generate-keys ./cmd/generate-keys +``` + +## Building up a polylith deployment + +Start by cloning the code: + +```bash +git clone https://github.com/matrix-org/dendrite +cd dendrite +``` + +Then build it: + +```bash ./build.sh ``` -If using Kafka, install and start it (c.f. [scripts/install-local-kafka.sh](scripts/install-local-kafka.sh)): +Install and start Kafka (c.f. [scripts/install-local-kafka.sh](scripts/install-local-kafka.sh)): + ```bash KAFKA_URL=http://archive.apache.org/dist/kafka/2.1.0/kafka_2.11-2.1.0.tgz @@ -51,7 +79,7 @@ kafka/bin/zookeeper-server-start.sh -daemon kafka/config/zookeeper.properties kafka/bin/kafka-server-start.sh -daemon kafka/config/server.properties ``` -On MacOS, you can use [homebrew](https://brew.sh/) for easier setup of kafka +On macOS, you can use [Homebrew](https://brew.sh/) for easier setup of Kafka: ```bash brew install kafka @@ -61,15 +89,24 @@ brew services start kafka ## Configuration +### SQLite database setup + +Dendrite can use the built-in SQLite database engine for small setups. +The SQLite databases do not need to be preconfigured - Dendrite will +create them automatically at startup. + ### Postgres database setup -Dendrite requires a postgres database engine, version 9.5 or later. +Assuming that Postgres 9.5 (or later) is installed: + +* Create role, choosing a new password when prompted: -* Create role: ```bash - sudo -u postgres createuser -P dendrite # prompts for password + sudo -u postgres createuser -P dendrite ``` -* Create databases: + +* Create the component databases: + ```bash for i in account device mediaapi syncapi roomserver serverkey federationsender publicroomsapi appservice naffka; do sudo -u postgres createdb -O dendrite dendrite_$i @@ -78,42 +115,56 @@ Dendrite requires a postgres database engine, version 9.5 or later. (On macOS, omit `sudo -u postgres` from the above commands.) -### Crypto key generation +### Server key generation -Generate the keys: +Each Dendrite server requires unique server keys. + +Generate the self-signed SSL certificate for federation: ```bash -# Generate a self-signed SSL cert for federation: test -f server.key || openssl req -x509 -newkey rsa:4096 -keyout server.key -out server.crt -days 3650 -nodes -subj /CN=localhost +``` -# generate ed25519 signing key +Generate the server signing key: + +``` test -f matrix_key.pem || ./bin/generate-keys -private-key matrix_key.pem ``` -### Configuration +### Configuration file Create config file, based on `dendrite-config.yaml`. Call it `dendrite.yaml`. Things that will need editing include *at least*: -* `server_name` -* `database/*` (All lines in the database section must have the username and password of the user created with the `createuser` command above. eg:`dendrite:password@localhost`) +* The `server_name` entry to reflect the hostname of your Dendrite server +* The `database` lines with an updated connection string based on your + desired setup, e.g. replacing `component` with the name of the component: + * For Postgres: `postgres://dendrite:password@localhost/component` + * For SQLite on disk: `file:component.db` or `file:///path/to/component.db` + * Postgres and SQLite can be mixed and matched. +* The `use_naffka` option if using Naffka in a monolith deployment + +There are other options which may be useful so review them all. In particular, +if you are trying to federate from your Dendrite instance into public rooms +then configuring `key_perspectives` (like `matrix.org` in the sample) can +help to improve reliability considerably by allowing your homeserver to fetch +public keys for dead homeservers from somewhere else. ## Starting a monolith server -It is possible to use 'naffka' as an in-process replacement to Kafka when using -the monolith server. To do this, set `use_naffka: true` in `dendrite.yaml` and uncomment -the necessary line related to naffka in the `database` section. Be sure to update the -database username and password if needed. +It is possible to use Naffka as an in-process replacement to Kafka when using +the monolith server. To do this, set `use_naffka: true` in your `dendrite.yaml` configuration and uncomment the relevant Naffka line in the `database` section. +Be sure to update the database username and password if needed. The monolith server can be started as shown below. By default it listens for -HTTP connections on port 8008, so point your client at -`http://localhost:8008`. If you set `--tls-cert` and `--tls-key` as shown -below, it will also listen for HTTPS connections on port 8448. +HTTP connections on port 8008, so you can configure your Matrix client to use +`http://localhost:8008` as the server. If you set `--tls-cert` and `--tls-key` +as shown below, it will also listen for HTTPS connections on port 8448. ```bash ./bin/dendrite-monolith-server --tls-cert=server.crt --tls-key=server.key ``` -## Starting a multiprocess server +## Starting a polylith deployment The following contains scripts which will run all the required processes in order to point a Matrix client at Dendrite. Conceptually, you are wiring together to form the following diagram: @@ -170,9 +221,10 @@ Servers --->| federation-api-proxy |--------->| dendrite-federation-api-server | A ==> B = Kafka (A = producer, B = consumer) ``` -### Run a client api proxy +### Client proxy -This is what Matrix clients will talk to. If you use the script below, point your client at `http://localhost:8008`. +This is what Matrix clients will talk to. If you use the script below, point +your client at `http://localhost:8008`. ```bash ./bin/client-api-proxy \ @@ -183,51 +235,10 @@ This is what Matrix clients will talk to. If you use the script below, point you --public-rooms-api-server-url "http://localhost:7775" \ ``` -### Run a client api +### Federation proxy -This is what implements message sending. Clients talk to this via the proxy in order to send messages. - -```bash -./bin/dendrite-client-api-server --config=dendrite.yaml -``` - -(If this fails with `pq: syntax error at or near "ON"`, check you are using at least postgres 9.5.) - -### Run a room server - -This is what implements the room DAG. Clients do not talk to this. - -```bash -./bin/dendrite-room-server --config=dendrite.yaml -``` - -### Run a sync server - -This is what implements `/sync` requests. Clients talk to this via the proxy in order to receive messages. - -```bash -./bin/dendrite-sync-api-server --config dendrite.yaml -``` - -### Run a media server - -This implements `/media` requests. Clients talk to this via the proxy in order to upload and retrieve media. - -```bash -./bin/dendrite-media-api-server --config dendrite.yaml -``` - -### Run public room server - -This implements `/directory` requests. Clients talk to this via the proxy in order to retrieve room directory listings. - -```bash -./bin/dendrite-public-rooms-api-server --config dendrite.yaml -``` - -### Run a federation api proxy - -This is what Matrix servers will talk to. This is only required if you want to support federation. +This is what Matrix servers will talk to. This is only required if you want +to support federation. ```bash ./bin/federation-api-proxy \ @@ -236,7 +247,51 @@ This is what Matrix servers will talk to. This is only required if you want to s --media-api-server-url "http://localhost:7774" \ ``` -### Run a federation api server +### Client API server + +This is what implements message sending. Clients talk to this via the proxy in +order to send messages. + +```bash +./bin/dendrite-client-api-server --config=dendrite.yaml +``` + +### Room server + +This is what implements the room DAG. Clients do not talk to this. + +```bash +./bin/dendrite-room-server --config=dendrite.yaml +``` + +### Sync server + +This is what implements `/sync` requests. Clients talk to this via the proxy +in order to receive messages. + +```bash +./bin/dendrite-sync-api-server --config dendrite.yaml +``` + +### Media server + +This implements `/media` requests. Clients talk to this via the proxy in +order to upload and retrieve media. + +```bash +./bin/dendrite-media-api-server --config dendrite.yaml +``` + +### Public room server + +This implements `/directory` requests. Clients talk to this via the proxy +in order to retrieve room directory listings. + +```bash +./bin/dendrite-public-rooms-api-server --config dendrite.yaml +``` + +### Federation API server This implements federation requests. Servers talk to this via the proxy in order to send transactions. This is only required if you want to support @@ -246,7 +301,7 @@ federation. ./bin/dendrite-federation-api-server --config dendrite.yaml ``` -### Run a federation sender server +### Federation sender This sends events from our users to other servers. This is only required if you want to support federation. @@ -255,7 +310,7 @@ you want to support federation. ./bin/dendrite-federation-sender-server --config dendrite.yaml ``` -### Run an appservice server +### Appservice server This sends events from the network to [application services](https://matrix.org/docs/spec/application_service/unstable.html) @@ -265,3 +320,22 @@ application services on your homeserver. ```bash ./bin/dendrite-appservice-server --config dendrite.yaml ``` + +### Key server + +This manages end-to-end encryption keys (or rather, it will do when it's +finished). + +```bash +./bin/dendrite-key-server --config dendrite.yaml +``` + +### User server + +This manages user accounts, device access tokens and user account data, +amongst other things. + +```bash +./bin/dendrite-user-api-server --config dendrite.yaml +``` + diff --git a/docs/WIRING-Current.md b/docs/WIRING-Current.md new file mode 100644 index 000000000..ec539d4e9 --- /dev/null +++ b/docs/WIRING-Current.md @@ -0,0 +1,73 @@ +This document details how various components communicate with each other. There are two kinds of components: + - Public-facing: exposes CS/SS API endpoints and need to be routed to via client-api-proxy or equivalent. + - Internal-only: exposes internal APIs and produces Kafka events. + +## Internal HTTP APIs + +Not everything can be done using Kafka logs. For example, requesting the latest events in a room is much better suited to +a request/response model like HTTP or RPC. Therefore, components can expose "internal APIs" which sit outside of Kafka logs. +Note in Monolith mode these are actually direct function calls and are not serialised HTTP requests. + +``` + Tier 1 Sync PublicRooms FederationAPI ClientAPI MediaAPI +Public Facing | .-----1------` | | | | | | | | | + 2 | .-------3-----------------` | | | `--------|-|-|-|--11--------------------. + | | | .--------4----------------------------------` | | | | + | | | | .---5-----------` | | | | | | + | | | | | .---6----------------------------` | | | + | | | | | | | .-----7----------` | | + | | | | | | 8 | | 10 | + | | | | | | | | `---9----. | | + V V V V V V V V V V V + Tier 2 Roomserver EDUServer FedSender AppService KeyServer ServerKeyAPI +Internal only | `------------------------12----------^ ^ + `------------------------------------------------------------13----------` + + Client ---> Server +``` +- 1 (PublicRooms -> Roomserver): Calculating current auth for changing visibility +- 2 (Sync -> Roomserver): When making backfill requests +- 3 (FedAPI -> Roomserver): Calculating (prev/auth events) and sending new events, processing backfill/state/state_ids requests +- 4 (ClientAPI -> Roomserver): Calculating (prev/auth events) and sending new events, processing /state requests +- 5 (FedAPI -> EDUServer): Sending typing/send-to-device events +- 6 (ClientAPI -> EDUServer): Sending typing/send-to-device events +- 7 (ClientAPI -> FedSender): Handling directory lookups +- 8 (FedAPI -> FedSender): Resetting backoffs when receiving traffic from a server. Querying joined hosts when handling alias lookup requests +- 9 (FedAPI -> AppService): Working out if the client is an appservice user +- 10 (ClientAPI -> AppService): Working out if the client is an appservice user +- 11 (FedAPI -> ServerKeyAPI): Verifying incoming event signatures +- 12 (FedSender -> ServerKeyAPI): Verifying event signatures of responses (e.g from send_join) +- 13 (Roomserver -> ServerKeyAPI): Verifying event signatures of backfilled events + +In addition to this, all public facing components (Tier 1) talk to the `UserAPI` to verify access tokens and extract profile information where needed. + +## Kafka logs + +``` + .----1--------------------------------------------. + V | + Tier 1 Sync PublicRooms FederationAPI ClientAPI MediaAPI +Public Facing ^ ^ ^ ^ + | | | | + 2 | | | + | `-3------------. | + | | | | + | | | | + | .------4------` | | + | | .--------5-----|------------------------------` + | | | | + Tier 2 Roomserver EDUServer FedSender AppService KeyServer ServerKeyAPI +Internal only | | ^ ^ + | `-----6----------` | + `--------------------7--------` + + +Producer ----> Consumer +``` +- 1 (ClientAPI -> Sync): For tracking account data +- 2 (Roomserver -> Sync): For all data to send to clients +- 3 (EDUServer -> Sync): For typing/send-to-device data to send to clients +- 4 (Roomserver -> PublicRooms): For tracking the current room name/topic/joined count/etc. +- 5 (Roomserver -> ClientAPI): For tracking memberships for profile updates. +- 6 (EDUServer -> FedSender): For sending EDUs over federation +- 7 (Roomserver -> FedSender): For sending PDUs over federation, for tracking joined hosts. diff --git a/WIRING.md b/docs/WIRING.md similarity index 100% rename from WIRING.md rename to docs/WIRING.md diff --git a/p2p.md b/docs/p2p.md similarity index 69% rename from p2p.md rename to docs/p2p.md index 141aaa1fc..d69b47bea 100644 --- a/p2p.md +++ b/docs/p2p.md @@ -1,6 +1,6 @@ ## Peer-to-peer Matrix -These are the instructions for setting up P2P Dendrite, current as of March 2020. There's both Go stuff and JS stuff to do to set this up. +These are the instructions for setting up P2P Dendrite, current as of May 2020. There's both Go stuff and JS stuff to do to set this up. ### Dendrite @@ -28,14 +28,13 @@ Then use `/ip4/127.0.0.1/tcp/9090/ws/p2p-websocket-star/`. ### Riot-web -You need to check out these repos: +You need to check out this repo: ``` $ git clone git@github.com:matrix-org/go-http-js-libp2p.git -$ git clone git@github.com:matrix-org/go-sqlite3-js.git ``` -Make sure to `yarn install` in both of these repos. Then: +Make sure to `yarn install` in the repo. Then: - `$ cp "$(go env GOROOT)/misc/wasm/wasm_exec.js" ./src/vector/` - Comment out the lines in `wasm_exec.js` which contains: @@ -49,7 +48,6 @@ if (!global.fs && global.require) { - Add the following symlinks: they HAVE to be symlinks as the diff in `webpack.config.js` references specific paths. ``` $ cd node_modules -$ ln -s ../../go-sqlite-js # NB: NOT go-sqlite3-js $ ln -s ../../go-http-js-libp2p ``` @@ -65,14 +63,7 @@ You need a Chrome and a Firefox running to test locally as service workers don't Assuming you've `yarn start`ed Riot-Web, go to `http://localhost:8080` and register with `http://localhost:8080` as your HS URL. -You can join rooms by room alias e.g `/join #foo:bar`. - -### Known issues - -- When registering you may be unable to find the server, it'll seem flakey. This happens because the SW, particularly in Firefox, - gets killed after 30s of inactivity. When you are not registered, you aren't doing `/sync` calls to keep the SW alive, so if you - don't register for a while and idle on the page, the HS will disappear. To fix, unregister the SW, and then refresh the page. - -- The libp2p layer has rate limits, so frequent Federation traffic may cause the connection to drop and messages to not be transferred. - I guess in other words, don't send too much traffic? - +You can: + - join rooms by room alias e.g `/join #foo:bar`. + - invite specific users to a room. + - explore the published room list. All members of the room can re-publish aliases (unlike Synapse). diff --git a/eduserver/api/input.go b/eduserver/api/input.go index be2d4c56a..0d0d21f33 100644 --- a/eduserver/api/input.go +++ b/eduserver/api/input.go @@ -1,3 +1,7 @@ +// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -15,12 +19,8 @@ package api import ( "context" - "errors" - "net/http" - commonHTTP "github.com/matrix-org/dendrite/common/http" "github.com/matrix-org/gomatrixserverlib" - opentracing "github.com/opentracing/opentracing-go" ) // InputTypingEvent is an event for notifying the typing server about typing updates. @@ -37,6 +37,12 @@ type InputTypingEvent struct { OriginServerTS gomatrixserverlib.Timestamp `json:"origin_server_ts"` } +type InputSendToDeviceEvent struct { + UserID string `json:"user_id"` + DeviceID string `json:"device_id"` + gomatrixserverlib.SendToDeviceEvent +} + // InputTypingEventRequest is a request to EDUServerInputAPI type InputTypingEventRequest struct { InputTypingEvent InputTypingEvent `json:"input_typing_event"` @@ -45,6 +51,14 @@ type InputTypingEventRequest struct { // InputTypingEventResponse is a response to InputTypingEvents type InputTypingEventResponse struct{} +// InputSendToDeviceEventRequest is a request to EDUServerInputAPI +type InputSendToDeviceEventRequest struct { + InputSendToDeviceEvent InputSendToDeviceEvent `json:"input_send_to_device_event"` +} + +// InputSendToDeviceEventResponse is a response to InputSendToDeviceEventRequest +type InputSendToDeviceEventResponse struct{} + // EDUServerInputAPI is used to write events to the typing server. type EDUServerInputAPI interface { InputTypingEvent( @@ -52,33 +66,10 @@ type EDUServerInputAPI interface { request *InputTypingEventRequest, response *InputTypingEventResponse, ) error -} - -// EDUServerInputTypingEventPath is the HTTP path for the InputTypingEvent API. -const EDUServerInputTypingEventPath = "/api/eduserver/input" - -// NewEDUServerInputAPIHTTP creates a EDUServerInputAPI implemented by talking to a HTTP POST API. -func NewEDUServerInputAPIHTTP(eduServerURL string, httpClient *http.Client) (EDUServerInputAPI, error) { - if httpClient == nil { - return nil, errors.New("NewTypingServerInputAPIHTTP: httpClient is ") - } - return &httpEDUServerInputAPI{eduServerURL, httpClient}, nil -} - -type httpEDUServerInputAPI struct { - eduServerURL string - httpClient *http.Client -} - -// InputRoomEvents implements EDUServerInputAPI -func (h *httpEDUServerInputAPI) InputTypingEvent( - ctx context.Context, - request *InputTypingEventRequest, - response *InputTypingEventResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "InputTypingEvent") - defer span.Finish() - - apiURL := h.eduServerURL + EDUServerInputTypingEventPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + + InputSendToDeviceEvent( + ctx context.Context, + request *InputSendToDeviceEventRequest, + response *InputSendToDeviceEventResponse, + ) error } diff --git a/eduserver/api/output.go b/eduserver/api/output.go index 8696acf49..e6ded8413 100644 --- a/eduserver/api/output.go +++ b/eduserver/api/output.go @@ -1,3 +1,7 @@ +// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -12,7 +16,11 @@ package api -import "time" +import ( + "time" + + "github.com/matrix-org/gomatrixserverlib" +) // OutputTypingEvent is an entry in typing server output kafka log. // This contains the event with extra fields used to create 'm.typing' event @@ -32,3 +40,12 @@ type TypingEvent struct { UserID string `json:"user_id"` Typing bool `json:"typing"` } + +// OutputSendToDeviceEvent is an entry in the send-to-device output kafka log. +// This contains the full event content, along with the user ID and device ID +// to which it is destined. +type OutputSendToDeviceEvent struct { + UserID string `json:"user_id"` + DeviceID string `json:"device_id"` + gomatrixserverlib.SendToDeviceEvent +} diff --git a/eduserver/api/wrapper.go b/eduserver/api/wrapper.go new file mode 100644 index 000000000..c2c4596de --- /dev/null +++ b/eduserver/api/wrapper.go @@ -0,0 +1,69 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + "encoding/json" + "time" + + "github.com/matrix-org/gomatrixserverlib" +) + +// SendTyping sends a typing event to EDU server +func SendTyping( + ctx context.Context, eduAPI EDUServerInputAPI, userID, roomID string, + typing bool, timeoutMS int64, +) error { + requestData := InputTypingEvent{ + UserID: userID, + RoomID: roomID, + Typing: typing, + TimeoutMS: timeoutMS, + OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), + } + + var response InputTypingEventResponse + err := eduAPI.InputTypingEvent( + ctx, &InputTypingEventRequest{InputTypingEvent: requestData}, &response, + ) + + return err +} + +// SendToDevice sends a typing event to EDU server +func SendToDevice( + ctx context.Context, eduAPI EDUServerInputAPI, sender, userID, deviceID, eventType string, + message interface{}, +) error { + js, err := json.Marshal(message) + if err != nil { + return err + } + requestData := InputSendToDeviceEvent{ + UserID: userID, + DeviceID: deviceID, + SendToDeviceEvent: gomatrixserverlib.SendToDeviceEvent{ + Sender: sender, + Type: eventType, + Content: js, + }, + } + request := InputSendToDeviceEventRequest{ + InputSendToDeviceEvent: requestData, + } + response := InputSendToDeviceEventResponse{} + return eduAPI.InputSendToDeviceEvent(ctx, &request, &response) +} diff --git a/eduserver/cache/cache.go b/eduserver/cache/cache.go index 46f7a2b13..dd535a6d2 100644 --- a/eduserver/cache/cache.go +++ b/eduserver/cache/cache.go @@ -1,3 +1,7 @@ +// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -109,6 +113,19 @@ func (t *EDUCache) AddTypingUser( return t.GetLatestSyncPosition() } +// AddSendToDeviceMessage increases the sync position for +// send-to-device updates. +// Returns the sync position before update, as the caller +// will use this to record the current stream position +// at the time that the send-to-device message was sent. +func (t *EDUCache) AddSendToDeviceMessage() int64 { + t.Lock() + defer t.Unlock() + latestSyncPosition := t.latestSyncPosition + t.latestSyncPosition++ + return latestSyncPosition +} + // addUser with mutex lock & replace the previous timer. // Returns the latest typing sync position after update. func (t *EDUCache) addUser( diff --git a/eduserver/cache/cache_test.go b/eduserver/cache/cache_test.go index 8a1b6f797..c7d01879f 100644 --- a/eduserver/cache/cache_test.go +++ b/eduserver/cache/cache_test.go @@ -1,3 +1,7 @@ +// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -16,7 +20,7 @@ import ( "testing" "time" - "github.com/matrix-org/dendrite/common/test" + "github.com/matrix-org/dendrite/internal/test" ) func TestEDUCache(t *testing.T) { diff --git a/eduserver/eduserver.go b/eduserver/eduserver.go index 8ddd2c527..2e6ba0c85 100644 --- a/eduserver/eduserver.go +++ b/eduserver/eduserver.go @@ -1,3 +1,7 @@ +// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -13,28 +17,34 @@ package eduserver import ( - "net/http" - - "github.com/matrix-org/dendrite/common/basecomponent" + "github.com/gorilla/mux" "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/eduserver/input" + "github.com/matrix-org/dendrite/eduserver/inthttp" + "github.com/matrix-org/dendrite/internal/setup" + userapi "github.com/matrix-org/dendrite/userapi/api" ) -// SetupEDUServerComponent sets up and registers HTTP handlers for the -// EDUServer component. Returns instances of the various roomserver APIs, -// allowing other components running in the same process to hit the query the -// APIs directly instead of having to use HTTP. -func SetupEDUServerComponent( - base *basecomponent.BaseDendrite, - eduCache *cache.EDUCache, -) api.EDUServerInputAPI { - inputAPI := &input.EDUServerInputAPI{ - Cache: eduCache, - Producer: base.KafkaProducer, - OutputTypingEventTopic: string(base.Cfg.Kafka.Topics.OutputTypingEvent), - } - - inputAPI.SetupHTTP(http.DefaultServeMux) - return inputAPI +// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions +// on the given input API. +func AddInternalRoutes(internalMux *mux.Router, inputAPI api.EDUServerInputAPI) { + inthttp.AddRoutes(inputAPI, internalMux) +} + +// NewInternalAPI returns a concerete implementation of the internal API. Callers +// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. +func NewInternalAPI( + base *setup.BaseDendrite, + eduCache *cache.EDUCache, + userAPI userapi.UserInternalAPI, +) api.EDUServerInputAPI { + return &input.EDUServerInputAPI{ + Cache: eduCache, + UserAPI: userAPI, + Producer: base.KafkaProducer, + OutputTypingEventTopic: string(base.Cfg.Kafka.Topics.OutputTypingEvent), + OutputSendToDeviceEventTopic: string(base.Cfg.Kafka.Topics.OutputSendToDeviceEvent), + ServerName: base.Cfg.Matrix.ServerName, + } } diff --git a/eduserver/input/input.go b/eduserver/input/input.go index 845909452..e3d2c55e3 100644 --- a/eduserver/input/input.go +++ b/eduserver/input/input.go @@ -1,3 +1,7 @@ +// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -15,15 +19,14 @@ package input import ( "context" "encoding/json" - "net/http" "time" - "github.com/matrix-org/dendrite/common" + "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/cache" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "gopkg.in/Shopify/sarama.v1" + "github.com/sirupsen/logrus" ) // EDUServerInputAPI implements api.EDUServerInputAPI @@ -32,8 +35,14 @@ type EDUServerInputAPI struct { Cache *cache.EDUCache // The kafka topic to output new typing events to. OutputTypingEventTopic string + // The kafka topic to output new send to device events to. + OutputSendToDeviceEventTopic string // kafka producer Producer sarama.SyncProducer + // Internal user query API + UserAPI userapi.UserInternalAPI + // our server name + ServerName gomatrixserverlib.ServerName } // InputTypingEvent implements api.EDUServerInputAPI @@ -53,10 +62,20 @@ func (t *EDUServerInputAPI) InputTypingEvent( t.Cache.RemoveUser(ite.UserID, ite.RoomID) } - return t.sendEvent(ite) + return t.sendTypingEvent(ite) } -func (t *EDUServerInputAPI) sendEvent(ite *api.InputTypingEvent) error { +// InputTypingEvent implements api.EDUServerInputAPI +func (t *EDUServerInputAPI) InputSendToDeviceEvent( + ctx context.Context, + request *api.InputSendToDeviceEventRequest, + response *api.InputSendToDeviceEventResponse, +) error { + ise := &request.InputSendToDeviceEvent + return t.sendToDeviceEvent(ise) +} + +func (t *EDUServerInputAPI) sendTypingEvent(ite *api.InputTypingEvent) error { ev := &api.TypingEvent{ Type: gomatrixserverlib.MTyping, RoomID: ite.RoomID, @@ -78,6 +97,11 @@ func (t *EDUServerInputAPI) sendEvent(ite *api.InputTypingEvent) error { if err != nil { return err } + logrus.WithFields(logrus.Fields{ + "room_id": ite.RoomID, + "user_id": ite.UserID, + "typing": ite.Typing, + }).Infof("Producing to topic '%s'", t.OutputTypingEventTopic) m := &sarama.ProducerMessage{ Topic: string(t.OutputTypingEventTopic), @@ -89,19 +113,63 @@ func (t *EDUServerInputAPI) sendEvent(ite *api.InputTypingEvent) error { return err } -// SetupHTTP adds the EDUServerInputAPI handlers to the http.ServeMux. -func (t *EDUServerInputAPI) SetupHTTP(servMux *http.ServeMux) { - servMux.Handle(api.EDUServerInputTypingEventPath, - common.MakeInternalAPI("inputTypingEvents", func(req *http.Request) util.JSONResponse { - var request api.InputTypingEventRequest - var response api.InputTypingEventResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := t.InputTypingEvent(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) +func (t *EDUServerInputAPI) sendToDeviceEvent(ise *api.InputSendToDeviceEvent) error { + devices := []string{} + _, domain, err := gomatrixserverlib.SplitID('@', ise.UserID) + if err != nil { + return err + } + + // If the event is targeted locally then we want to expand the wildcard + // out into individual device IDs so that we can send them to each respective + // device. If the event isn't targeted locally then we can't expand the + // wildcard as we don't know about the remote devices, so instead we leave it + // as-is, so that the federation sender can send it on with the wildcard intact. + if domain == t.ServerName && ise.DeviceID == "*" { + var res userapi.QueryDevicesResponse + err = t.UserAPI.QueryDevices(context.TODO(), &userapi.QueryDevicesRequest{ + UserID: ise.UserID, + }, &res) + if err != nil { + return err + } + for _, dev := range res.Devices { + devices = append(devices, dev.ID) + } + } else { + devices = append(devices, ise.DeviceID) + } + + logrus.WithFields(logrus.Fields{ + "user_id": ise.UserID, + "num_devices": len(devices), + "type": ise.Type, + }).Infof("Producing to topic '%s'", t.OutputSendToDeviceEventTopic) + for _, device := range devices { + ote := &api.OutputSendToDeviceEvent{ + UserID: ise.UserID, + DeviceID: device, + SendToDeviceEvent: ise.SendToDeviceEvent, + } + + eventJSON, err := json.Marshal(ote) + if err != nil { + logrus.WithError(err).Error("sendToDevice failed json.Marshal") + return err + } + + m := &sarama.ProducerMessage{ + Topic: string(t.OutputSendToDeviceEventTopic), + Key: sarama.StringEncoder(ote.UserID), + Value: sarama.ByteEncoder(eventJSON), + } + + _, _, err = t.Producer.SendMessage(m) + if err != nil { + logrus.WithError(err).Error("sendToDevice failed t.Producer.SendMessage") + return err + } + } + + return nil } diff --git a/eduserver/inthttp/client.go b/eduserver/inthttp/client.go new file mode 100644 index 000000000..7d0bc1603 --- /dev/null +++ b/eduserver/inthttp/client.go @@ -0,0 +1,56 @@ +package inthttp + +import ( + "context" + "errors" + "net/http" + + "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/opentracing/opentracing-go" +) + +// HTTP paths for the internal HTTP APIs +const ( + EDUServerInputTypingEventPath = "/eduserver/input" + EDUServerInputSendToDeviceEventPath = "/eduserver/sendToDevice" +) + +// NewEDUServerClient creates a EDUServerInputAPI implemented by talking to a HTTP POST API. +func NewEDUServerClient(eduServerURL string, httpClient *http.Client) (api.EDUServerInputAPI, error) { + if httpClient == nil { + return nil, errors.New("NewEDUServerClient: httpClient is ") + } + return &httpEDUServerInputAPI{eduServerURL, httpClient}, nil +} + +type httpEDUServerInputAPI struct { + eduServerURL string + httpClient *http.Client +} + +// InputTypingEvent implements EDUServerInputAPI +func (h *httpEDUServerInputAPI) InputTypingEvent( + ctx context.Context, + request *api.InputTypingEventRequest, + response *api.InputTypingEventResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "InputTypingEvent") + defer span.Finish() + + apiURL := h.eduServerURL + EDUServerInputTypingEventPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// InputSendToDeviceEvent implements EDUServerInputAPI +func (h *httpEDUServerInputAPI) InputSendToDeviceEvent( + ctx context.Context, + request *api.InputSendToDeviceEventRequest, + response *api.InputSendToDeviceEventResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "InputSendToDeviceEvent") + defer span.Finish() + + apiURL := h.eduServerURL + EDUServerInputSendToDeviceEventPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/eduserver/inthttp/server.go b/eduserver/inthttp/server.go new file mode 100644 index 000000000..e374513a3 --- /dev/null +++ b/eduserver/inthttp/server.go @@ -0,0 +1,41 @@ +package inthttp + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/util" +) + +// AddRoutes adds the EDUServerInputAPI handlers to the http.ServeMux. +func AddRoutes(t api.EDUServerInputAPI, internalAPIMux *mux.Router) { + internalAPIMux.Handle(EDUServerInputTypingEventPath, + httputil.MakeInternalAPI("inputTypingEvents", func(req *http.Request) util.JSONResponse { + var request api.InputTypingEventRequest + var response api.InputTypingEventResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := t.InputTypingEvent(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(EDUServerInputSendToDeviceEventPath, + httputil.MakeInternalAPI("inputSendToDeviceEvents", func(req *http.Request) util.JSONResponse { + var request api.InputSendToDeviceEventRequest + var response api.InputSendToDeviceEventResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := t.InputSendToDeviceEvent(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) +} diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index ed96322b8..c0c000434 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -15,39 +15,32 @@ package federationapi import ( - appserviceAPI "github.com/matrix-org/dendrite/appservice/api" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" - "github.com/matrix-org/dendrite/common/basecomponent" + "github.com/gorilla/mux" + eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/config" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" - // TODO: Are we really wanting to pull in the producer from clientapi - "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/federationapi/routing" "github.com/matrix-org/gomatrixserverlib" ) -// SetupFederationAPIComponent sets up and registers HTTP handlers for the -// FederationAPI component. -func SetupFederationAPIComponent( - base *basecomponent.BaseDendrite, - accountsDB accounts.Database, - deviceDB devices.Database, +// AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component. +func AddPublicRoutes( + router *mux.Router, + cfg *config.Dendrite, + userAPI userapi.UserInternalAPI, federation *gomatrixserverlib.FederationClient, - keyRing *gomatrixserverlib.KeyRing, - aliasAPI roomserverAPI.RoomserverAliasAPI, - inputAPI roomserverAPI.RoomserverInputAPI, - queryAPI roomserverAPI.RoomserverQueryAPI, - asAPI appserviceAPI.AppServiceQueryAPI, - federationSenderAPI federationSenderAPI.FederationSenderQueryAPI, - eduProducer *producers.EDUServerProducer, + keyRing gomatrixserverlib.JSONVerifier, + rsAPI roomserverAPI.RoomserverInternalAPI, + federationSenderAPI federationSenderAPI.FederationSenderInternalAPI, + eduAPI eduserverAPI.EDUServerInputAPI, ) { - roomserverProducer := producers.NewRoomserverProducer(inputAPI, queryAPI) routing.Setup( - base.APIMux, base.Cfg, queryAPI, aliasAPI, asAPI, - roomserverProducer, eduProducer, federationSenderAPI, *keyRing, - federation, accountsDB, deviceDB, + router, cfg, rsAPI, + eduAPI, federationSenderAPI, keyRing, + federation, userAPI, ) } diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go new file mode 100644 index 000000000..cc85c61bf --- /dev/null +++ b/federationapi/federationapi_test.go @@ -0,0 +1,102 @@ +package federationapi_test + +import ( + "context" + "crypto/ed25519" + "strings" + "testing" + + "github.com/matrix-org/dendrite/federationapi" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/setup" + "github.com/matrix-org/dendrite/internal/test" + "github.com/matrix-org/gomatrix" + "github.com/matrix-org/gomatrixserverlib" +) + +// Tests that event IDs with '/' in them (escaped as %2F) are correctly passed to the right handler and don't 404. +// Relevant for v3 rooms and a cause of flakey sytests as the IDs are randomly generated. +func TestRoomsV3URLEscapeDoNot404(t *testing.T) { + _, privKey, _ := ed25519.GenerateKey(nil) + cfg := &config.Dendrite{} + cfg.Matrix.KeyID = gomatrixserverlib.KeyID("ed25519:auto") + cfg.Matrix.ServerName = gomatrixserverlib.ServerName("localhost") + cfg.Matrix.PrivateKey = privKey + cfg.Kafka.UseNaffka = true + cfg.Database.Naffka = "file::memory:" + cfg.SetDefaults() + base := setup.NewBaseDendrite(cfg, "Test", false) + keyRing := &test.NopJSONVerifier{} + fsAPI := base.FederationSenderHTTPClient() + // TODO: This is pretty fragile, as if anything calls anything on these nils this test will break. + // Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing. + federationapi.AddPublicRoutes(base.PublicAPIMux, cfg, nil, nil, keyRing, nil, fsAPI, nil) + httputil.SetupHTTPAPI( + base.BaseMux, + base.PublicAPIMux, + base.InternalAPIMux, + cfg, + base.UseHTTPAPIs, + ) + baseURL, cancel := test.ListenAndServe(t, base.BaseMux, true) + defer cancel() + serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://")) + + fedCli := gomatrixserverlib.NewFederationClient(serverName, cfg.Matrix.KeyID, cfg.Matrix.PrivateKey) + + testCases := []struct { + roomVer gomatrixserverlib.RoomVersion + eventJSON string + }{ + { + eventJSON: `{"auth_events":[["$Nzfbrhc3oaYVKzGM:localhost",{"sha256":"BCBHOgB4qxLPQkBd6th8ydFSyqjth/LF99VNjYffOQ0"}],["$EZzkD2BH1Gtm5v1D:localhost",{"sha256":"3dLUnDBs8/iC5DMw/ydKtmAqVZtzqqtHpsjsQPk7GJA"}]],"content":{"body":"Test Message"},"depth":11,"event_id":"$mGiPO3oGjQfCkIUw:localhost","hashes":{"sha256":"h+t+4DwIBC9UNyJ3jzyAQAAl4H3yQHVuHrm2S1JZizU"},"origin":"localhost","origin_server_ts":0,"prev_events":[["$tFr64vpiSHdLU0Qr:localhost",{"sha256":"+R07ZrIs4c4tjPFE+tmcYIGUfeLGFI/4e0OITb9uEcM"}]],"room_id":"!roomid:localhost","sender":"@userid:localhost","signatures":{"localhost":{"ed25519:auto":"LYFr/rW9m5/7UKBQMF5qWnG82He4VGsRESUgDmvkn5DrJRyS4TLL/7zl0Lymn3pa3q2yaTO74LQX/CRotqG1BA"}},"type":"m.room.message"}`, + roomVer: gomatrixserverlib.RoomVersionV1, + }, + // single / (handlers which do not UseEncodedPath will fail this test) + // EventID: $0SFh2WJbjBs3OT+E0yl95giDKo/3Zp52HsHUUk4uPyg + { + eventJSON: `{"auth_events":["$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg","$BcEcbZnlFLB5rxSNSZNBn6fO3jU/TKAJ79wfKyCQLiU"],"content":{"body":"Test Message"},"depth":8,"hashes":{"sha256":"dfK0MBn1RZZqCVJqWsn/MGY7QJHjQcwqF0unOonLCTU"},"origin":"localhost","origin_server_ts":0,"prev_events":["$1SwcZ1XY/Y8yKLjP4DzAOHN5WFBcDAZxb5vFDnW2ubA"],"room_id":"!roomid:localhost","sender":"@userid:localhost","signatures":{"localhost":{"ed25519:auto":"INOjuWMg+GmFkUpmzhMB0bqLNs73mSvwldY1ftYIQ/B3lD9soD2OMG3AF+wgZW/I8xqzY4DOHfbnbUeYPf67BA"}},"type":"m.room.message"}`, + roomVer: gomatrixserverlib.RoomVersionV3, + }, + // multiple / + // EventID: $OzENBCuVv/fnRAYCeQudIon/84/V5pxtEjQMTgi3emk + { + eventJSON: `{"auth_events":["$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg","$BcEcbZnlFLB5rxSNSZNBn6fO3jU/TKAJ79wfKyCQLiU"],"content":{"body":"Test Message"},"depth":2,"hashes":{"sha256":"U5+WsiJAhiEM88J8HTjuUjPImVGVzDFD3v/WS+jb2f0"},"origin":"localhost","origin_server_ts":0,"prev_events":["$BcEcbZnlFLB5rxSNSZNBn6fO3jU/TKAJ79wfKyCQLiU"],"room_id":"!roomid:localhost","sender":"@userid:localhost","signatures":{"localhost":{"ed25519:auto":"tKS469e9+wdWPEKB/LbBJWQ8vfOOdKgTWER5IwbSAH1CxmLvkCziUsgVu85zfzDSLoUi5mU5FHLiMTC6P/qICw"}},"type":"m.room.message"}`, + roomVer: gomatrixserverlib.RoomVersionV3, + }, + // two slashes (handlers which clean paths before UseEncodedPath will fail this test) + // EventID: $EmwNBlHoSOVmCZ1cM//yv/OvxB6r4OFEIGSJea7+Amk + { + eventJSON: `{"auth_events":["$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg","$BcEcbZnlFLB5rxSNSZNBn6fO3jU/TKAJ79wfKyCQLiU"],"content":{"body":"Test Message"},"depth":3917,"hashes":{"sha256":"cNAWtlHIegrji0mMA6x1rhpYCccY8W1NsWZqSpJFhjs"},"origin":"localhost","origin_server_ts":0,"prev_events":["$4GDB0bVjkWwS3G4noUZCq5oLWzpBYpwzdMcf7gj24CI"],"room_id":"!roomid:localhost","sender":"@userid:localhost","signatures":{"localhost":{"ed25519:auto":"NKym6Kcy3u9mGUr21Hjfe3h7DfDilDhN5PqztT0QZ4NTZ+8Y7owseLolQVXp+TvNjecvzdDywsXXVvGiuQiWAQ"}},"type":"m.room.message"}`, + roomVer: gomatrixserverlib.RoomVersionV3, + }, + } + + for _, tc := range testCases { + ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(tc.eventJSON), false, tc.roomVer) + if err != nil { + t.Errorf("failed to parse event: %s", err) + } + he := ev.Headered(tc.roomVer) + invReq, err := gomatrixserverlib.NewInviteV2Request(&he, nil) + if err != nil { + t.Errorf("failed to create invite v2 request: %s", err) + continue + } + _, err = fedCli.SendInviteV2(context.Background(), serverName, invReq) + if err == nil { + t.Errorf("expected an error, got none") + continue + } + gerr, ok := err.(gomatrix.HTTPError) + if !ok { + t.Errorf("failed to cast response error as gomatrix.HTTPError") + continue + } + t.Logf("Error: %+v", gerr) + if gerr.Code == 404 { + t.Errorf("invite event resulted in a 404") + } + } +} diff --git a/federationapi/routing/backfill.go b/federationapi/routing/backfill.go index 62471b8a9..f906c73c9 100644 --- a/federationapi/routing/backfill.go +++ b/federationapi/routing/backfill.go @@ -22,7 +22,7 @@ import ( "time" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -33,11 +33,11 @@ import ( func Backfill( httpReq *http.Request, request *gomatrixserverlib.FederationRequest, - query api.RoomserverQueryAPI, + rsAPI api.RoomserverInternalAPI, roomID string, cfg *config.Dendrite, ) util.JSONResponse { - var res api.QueryBackfillResponse + var res api.PerformBackfillResponse var eIDs []string var limit string var exists bool @@ -68,9 +68,15 @@ func Backfill( } // Populate the request. - req := api.QueryBackfillRequest{ - EarliestEventsIDs: eIDs, - ServerName: request.Origin(), + req := api.PerformBackfillRequest{ + RoomID: roomID, + // we don't know who the successors are for these events, which won't + // be a problem because we don't use that information when servicing /backfill requests, + // only when making them. TODO: Think of a better API shape + BackwardsExtremities: map[string][]string{ + "": eIDs, + }, + ServerName: request.Origin(), } if req.Limit, err = strconv.Atoi(limit); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed") @@ -81,8 +87,8 @@ func Backfill( } // Query the roomserver. - if err = query.QueryBackfill(httpReq.Context(), &req, &res); err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("query.QueryBackfill failed") + if err = rsAPI.PerformBackfill(httpReq.Context(), &req, &res); err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Error("query.PerformBackfill failed") return jsonerror.InternalServerError() } @@ -96,11 +102,20 @@ func Backfill( } } - var eventJSONs []json.RawMessage - for _, e := range gomatrixserverlib.ReverseTopologicalOrdering(evs) { + eventJSONs := []json.RawMessage{} + for _, e := range gomatrixserverlib.ReverseTopologicalOrdering( + evs, + gomatrixserverlib.TopologicalOrderByPrevEvents, + ) { eventJSONs = append(eventJSONs, e.JSON()) } + // sytest wants these in reversed order, similar to /messages, so reverse them now. + for i := len(eventJSONs)/2 - 1; i >= 0; i-- { + opp := len(eventJSONs) - 1 - i + eventJSONs[i], eventJSONs[opp] = eventJSONs[opp], eventJSONs[i] + } + txn := gomatrixserverlib.Transaction{ Origin: cfg.Matrix.ServerName, PDUs: eventJSONs, diff --git a/federationapi/routing/devices.go b/federationapi/routing/devices.go index 01647a61e..6369c708c 100644 --- a/federationapi/routing/devices.go +++ b/federationapi/routing/devices.go @@ -15,39 +15,45 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/clientapi/userutil" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) -type userDevicesResponse struct { - Devices []authtypes.Device `json:"devices"` -} - // GetUserDevices for the given user id func GetUserDevices( req *http.Request, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, userID string, ) util.JSONResponse { - localpart, err := userutil.ParseUsernameParam(userID, nil) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("Invalid user ID"), - } + response := gomatrixserverlib.RespUserDevices{ + UserID: userID, + // TODO: we should return an incrementing stream ID each time the device + // list changes for delta changes to be recognised + StreamID: 0, } - devs, err := deviceDB.GetDevicesByLocalpart(req.Context(), localpart) + var res userapi.QueryDevicesResponse + err := userAPI.QueryDevices(req.Context(), &userapi.QueryDevicesRequest{ + UserID: userID, + }, &res) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDevicesByLocalPart failed") + util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryDevices failed") return jsonerror.InternalServerError() } + for _, dev := range res.Devices { + device := gomatrixserverlib.RespUserDevice{ + DeviceID: dev.ID, + DisplayName: dev.DisplayName, + Keys: []gomatrixserverlib.RespUserDeviceKeys{}, + } + response.Devices = append(response.Devices, device) + } + return util.JSONResponse{ Code: 200, - JSON: userDevicesResponse{devs}, + JSON: response, } } diff --git a/federationapi/routing/eventauth.go b/federationapi/routing/eventauth.go index 003165c85..34eaad1c5 100644 --- a/federationapi/routing/eventauth.go +++ b/federationapi/routing/eventauth.go @@ -25,13 +25,13 @@ import ( func GetEventAuth( ctx context.Context, request *gomatrixserverlib.FederationRequest, - query api.RoomserverQueryAPI, + rsAPI api.RoomserverInternalAPI, roomID string, eventID string, ) util.JSONResponse { // TODO: Optimisation: we shouldn't be querying all the room state // that is in state.StateEvents - we just ignore it. - state, err := getState(ctx, request, query, roomID, eventID) + state, err := getState(ctx, request, rsAPI, roomID, eventID) if err != nil { return *err } diff --git a/federationapi/routing/events.go b/federationapi/routing/events.go index a91528b3d..6fa28f69d 100644 --- a/federationapi/routing/events.go +++ b/federationapi/routing/events.go @@ -16,7 +16,9 @@ package routing import ( "context" + "encoding/json" "net/http" + "time" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" @@ -27,46 +29,62 @@ import ( func GetEvent( ctx context.Context, request *gomatrixserverlib.FederationRequest, - query api.RoomserverQueryAPI, + rsAPI api.RoomserverInternalAPI, eventID string, + origin gomatrixserverlib.ServerName, ) util.JSONResponse { - event, err := getEvent(ctx, request, query, eventID) + err := allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID) + if err != nil { + return *err + } + event, err := fetchEvent(ctx, rsAPI, eventID) if err != nil { return *err } - return util.JSONResponse{Code: http.StatusOK, JSON: event} + return util.JSONResponse{Code: http.StatusOK, JSON: gomatrixserverlib.Transaction{ + Origin: origin, + OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), + PDUs: []json.RawMessage{ + event.JSON(), + }, + }} } -// getEvent returns the requested event, +// allowedToSeeEvent returns no error if the server is allowed to see this event, // otherwise it returns an error response which can be sent to the client. -func getEvent( +func allowedToSeeEvent( ctx context.Context, - request *gomatrixserverlib.FederationRequest, - query api.RoomserverQueryAPI, + origin gomatrixserverlib.ServerName, + rsAPI api.RoomserverInternalAPI, eventID string, -) (*gomatrixserverlib.Event, *util.JSONResponse) { +) *util.JSONResponse { var authResponse api.QueryServerAllowedToSeeEventResponse - err := query.QueryServerAllowedToSeeEvent( + err := rsAPI.QueryServerAllowedToSeeEvent( ctx, &api.QueryServerAllowedToSeeEventRequest{ EventID: eventID, - ServerName: request.Origin(), + ServerName: origin, }, &authResponse, ) if err != nil { resErr := util.ErrorResponse(err) - return nil, &resErr + return &resErr } if !authResponse.AllowedToSeeEvent { resErr := util.MessageResponse(http.StatusForbidden, "server not allowed to see event") - return nil, &resErr + return &resErr } + return nil +} + +// fetchEvent fetches the event without auth checks. Returns an error if the event cannot be found. +func fetchEvent(ctx context.Context, rsAPI api.RoomserverInternalAPI, eventID string) (*gomatrixserverlib.Event, *util.JSONResponse) { var eventsResponse api.QueryEventsByIDResponse - err = query.QueryEventsByID( + err := rsAPI.QueryEventsByID( ctx, &api.QueryEventsByIDRequest{EventIDs: []string{eventID}}, &eventsResponse, diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 4b367e004..b1d84f254 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -16,11 +16,13 @@ package routing import ( "encoding/json" + "fmt" "net/http" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/clientapi/producers" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/roomserver/api" + roomserverVersion "github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -32,8 +34,8 @@ func Invite( roomID string, eventID string, cfg *config.Dendrite, - producer *producers.RoomserverProducer, - keys gomatrixserverlib.KeyRing, + rsAPI api.RoomserverInternalAPI, + keys gomatrixserverlib.JSONVerifier, ) util.JSONResponse { inviteReq := gomatrixserverlib.InviteV2Request{} if err := json.Unmarshal(request.Content(), &inviteReq); err != nil { @@ -44,6 +46,16 @@ func Invite( } event := inviteReq.Event() + // Check that we can accept invites for this room version. + if _, err := roomserverVersion.SupportedRoomVersion(inviteReq.RoomVersion()); err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.UnsupportedRoomVersion( + fmt.Sprintf("Room version %q is not supported by this server.", inviteReq.RoomVersion()), + ), + } + } + // Check that the room ID is correct. if event.RoomID() != roomID { return util.JSONResponse{ @@ -86,19 +98,21 @@ func Invite( ) // Add the invite event to the roomserver. - if err = producer.SendInvite( - httpReq.Context(), + if perr := api.SendInvite( + httpReq.Context(), rsAPI, signedEvent.Headered(inviteReq.RoomVersion()), inviteReq.InviteRoomState(), - ); err != nil { + event.Origin(), + nil, + ); perr != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("producer.SendInvite failed") - return jsonerror.InternalServerError() + return perr.JSONResponse() } // Return the signed event to the originating server, it should then tell // the other servers in the room that we have been invited. return util.JSONResponse{ Code: http.StatusOK, - JSON: gomatrixserverlib.RespInvite{Event: signedEvent}, + JSON: gomatrixserverlib.RespInviteV2{Event: signedEvent}, } } diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index e06785954..8dcd15333 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -17,12 +17,12 @@ package routing import ( "fmt" "net/http" + "sort" "time" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/clientapi/producers" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -33,13 +33,13 @@ func MakeJoin( httpReq *http.Request, request *gomatrixserverlib.FederationRequest, cfg *config.Dendrite, - query api.RoomserverQueryAPI, + rsAPI api.RoomserverInternalAPI, roomID, userID string, remoteVersions []gomatrixserverlib.RoomVersion, ) util.JSONResponse { verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} verRes := api.QueryRoomVersionForRoomResponse{} - if err := query.QueryRoomVersionForRoom(httpReq.Context(), &verReq, &verRes); err != nil { + if err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), &verReq, &verRes); err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError(), @@ -61,9 +61,7 @@ func MakeJoin( if !remoteSupportsVersion { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UnsupportedRoomVersion( - fmt.Sprintf("Joining server does not support room version %s", verRes.RoomVersion), - ), + JSON: jsonerror.IncompatibleRoomVersion(verRes.RoomVersion), } } @@ -97,14 +95,19 @@ func MakeJoin( queryRes := api.QueryLatestEventsAndStateResponse{ RoomVersion: verRes.RoomVersion, } - event, err := common.BuildEvent(httpReq.Context(), &builder, cfg, time.Now(), query, &queryRes) - if err == common.ErrRoomNoExists { + event, err := eventutil.BuildEvent(httpReq.Context(), &builder, cfg, time.Now(), rsAPI, &queryRes) + if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound("Room does not exist"), } + } else if e, ok := err.(gomatrixserverlib.BadJSONError); ok { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(e.Error()), + } } else if err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("common.BuildEvent failed") + util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") return jsonerror.InternalServerError() } @@ -132,19 +135,21 @@ func MakeJoin( } // SendJoin implements the /send_join API +// The make-join send-join dance makes much more sense as a single +// flow so the cyclomatic complexity is high: +// nolint:gocyclo func SendJoin( httpReq *http.Request, request *gomatrixserverlib.FederationRequest, cfg *config.Dendrite, - query api.RoomserverQueryAPI, - producer *producers.RoomserverProducer, - keys gomatrixserverlib.KeyRing, + rsAPI api.RoomserverInternalAPI, + keys gomatrixserverlib.JSONVerifier, roomID, eventID string, ) util.JSONResponse { verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} verRes := api.QueryRoomVersionForRoomResponse{} - if err := query.QueryRoomVersionForRoom(httpReq.Context(), &verReq, &verRes); err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("query.QueryRoomVersionForRoom failed") + if err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), &verReq, &verRes); err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryRoomVersionForRoom failed") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError(), @@ -155,7 +160,17 @@ func SendJoin( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), + JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON: " + err.Error()), + } + } + + // Check that a state key is provided. + if event.StateKey() == nil || (event.StateKey() != nil && *event.StateKey() == "") { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON( + fmt.Sprintf("No state key was provided in the join event."), + ), } } @@ -216,14 +231,14 @@ func SendJoin( // Fetch the state and auth chain. We do this before we send the events // on, in case this fails. var stateAndAuthChainResponse api.QueryStateAndAuthChainResponse - err = query.QueryStateAndAuthChain(httpReq.Context(), &api.QueryStateAndAuthChainRequest{ + err = rsAPI.QueryStateAndAuthChain(httpReq.Context(), &api.QueryStateAndAuthChainRequest{ PrevEventIDs: event.PrevEventIDs(), AuthEventIDs: event.AuthEventIDs(), RoomID: roomID, ResolveState: true, }, &stateAndAuthChainResponse) if err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("query.QueryStateAndAuthChain failed") + util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryStateAndAuthChain failed") return jsonerror.InternalServerError() } @@ -234,27 +249,61 @@ func SendJoin( } } + // Check if the user is already in the room. If they're already in then + // there isn't much point in sending another join event into the room. + alreadyJoined := false + for _, se := range stateAndAuthChainResponse.StateEvents { + if membership, merr := se.Membership(); merr == nil { + if se.StateKey() != nil && *se.StateKey() == *event.StateKey() { + alreadyJoined = (membership == "join") + break + } + } + } + // Send the events to the room server. // We are responsible for notifying other servers that the user has joined // the room, so set SendAsServer to cfg.Matrix.ServerName - _, err = producer.SendEvents( - httpReq.Context(), - []gomatrixserverlib.HeaderedEvent{ - event.Headered(stateAndAuthChainResponse.RoomVersion), - }, - cfg.Matrix.ServerName, - nil, - ) - if err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("producer.SendEvents failed") - return jsonerror.InternalServerError() + if !alreadyJoined { + _, err = api.SendEvents( + httpReq.Context(), rsAPI, + []gomatrixserverlib.HeaderedEvent{ + event.Headered(stateAndAuthChainResponse.RoomVersion), + }, + cfg.Matrix.ServerName, + nil, + ) + if err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Error("SendEvents failed") + return jsonerror.InternalServerError() + } } + // sort events deterministically by depth (lower is earlier) + // We also do this because sytest's basic federation server isn't good at using the correct + // state if these lists are randomised, resulting in flakey tests. :( + sort.Sort(eventsByDepth(stateAndAuthChainResponse.StateEvents)) + sort.Sort(eventsByDepth(stateAndAuthChainResponse.AuthChainEvents)) + + // https://matrix.org/docs/spec/server_server/latest#put-matrix-federation-v1-send-join-roomid-eventid return util.JSONResponse{ Code: http.StatusOK, - JSON: map[string]interface{}{ - "state": gomatrixserverlib.UnwrapEventHeaders(stateAndAuthChainResponse.StateEvents), - "auth_chain": gomatrixserverlib.UnwrapEventHeaders(stateAndAuthChainResponse.AuthChainEvents), + JSON: gomatrixserverlib.RespSendJoin{ + StateEvents: gomatrixserverlib.UnwrapEventHeaders(stateAndAuthChainResponse.StateEvents), + AuthEvents: gomatrixserverlib.UnwrapEventHeaders(stateAndAuthChainResponse.AuthChainEvents), + Origin: cfg.Matrix.ServerName, }, } } + +type eventsByDepth []gomatrixserverlib.HeaderedEvent + +func (e eventsByDepth) Len() int { + return len(e) +} +func (e eventsByDepth) Swap(i, j int) { + e[i], e[j] = e[j], e[i] +} +func (e eventsByDepth) Less(i, j int) bool { + return e[i].Depth() < e[j].Depth() +} diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index 3eb88567d..a1dd0fd09 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -19,7 +19,7 @@ import ( "net/http" "time" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "golang.org/x/crypto/ed25519" @@ -44,7 +44,7 @@ func localKeys(cfg *config.Dendrite, validUntil time.Time) (*gomatrixserverlib.S keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{ cfg.Matrix.KeyID: { - Key: gomatrixserverlib.Base64String(publicKey), + Key: gomatrixserverlib.Base64Bytes(publicKey), }, } diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index 6fc3b12ed..108fc50ae 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -17,9 +17,8 @@ import ( "time" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/clientapi/producers" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -30,9 +29,18 @@ func MakeLeave( httpReq *http.Request, request *gomatrixserverlib.FederationRequest, cfg *config.Dendrite, - query api.RoomserverQueryAPI, + rsAPI api.RoomserverInternalAPI, roomID, userID string, ) util.JSONResponse { + verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} + verRes := api.QueryRoomVersionForRoomResponse{} + if err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), &verReq, &verRes); err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.InternalServerError(), + } + } + _, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { return util.JSONResponse{ @@ -61,14 +69,19 @@ func MakeLeave( } var queryRes api.QueryLatestEventsAndStateResponse - event, err := common.BuildEvent(httpReq.Context(), &builder, cfg, time.Now(), query, &queryRes) - if err == common.ErrRoomNoExists { + event, err := eventutil.BuildEvent(httpReq.Context(), &builder, cfg, time.Now(), rsAPI, &queryRes) + if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound("Room does not exist"), } + } else if e, ok := err.(gomatrixserverlib.BadJSONError); ok { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(e.Error()), + } } else if err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("common.BuildEvent failed") + util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") return jsonerror.InternalServerError() } @@ -87,7 +100,10 @@ func MakeLeave( return util.JSONResponse{ Code: http.StatusOK, - JSON: map[string]interface{}{"event": builder}, + JSON: map[string]interface{}{ + "room_version": verRes.RoomVersion, + "event": builder, + }, } } @@ -96,13 +112,13 @@ func SendLeave( httpReq *http.Request, request *gomatrixserverlib.FederationRequest, cfg *config.Dendrite, - producer *producers.RoomserverProducer, - keys gomatrixserverlib.KeyRing, + rsAPI api.RoomserverInternalAPI, + keys gomatrixserverlib.JSONVerifier, roomID, eventID string, ) util.JSONResponse { verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} verRes := api.QueryRoomVersionForRoomResponse{} - if err := producer.QueryAPI.QueryRoomVersionForRoom(httpReq.Context(), &verReq, &verRes); err != nil { + if err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), &verReq, &verRes); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.UnsupportedRoomVersion(err.Error()), @@ -177,8 +193,8 @@ func SendLeave( // Send the events to the room server. // We are responsible for notifying other servers that the user has left // the room, so set SendAsServer to cfg.Matrix.ServerName - _, err = producer.SendEvents( - httpReq.Context(), + _, err = api.SendEvents( + httpReq.Context(), rsAPI, []gomatrixserverlib.HeaderedEvent{ event.Headered(verRes.RoomVersion), }, diff --git a/federationapi/routing/missingevents.go b/federationapi/routing/missingevents.go index 069bff3dd..f93e0eb41 100644 --- a/federationapi/routing/missingevents.go +++ b/federationapi/routing/missingevents.go @@ -34,7 +34,7 @@ type getMissingEventRequest struct { func GetMissingEvents( httpReq *http.Request, request *gomatrixserverlib.FederationRequest, - query api.RoomserverQueryAPI, + rsAPI api.RoomserverInternalAPI, roomID string, ) util.JSONResponse { var gme getMissingEventRequest @@ -46,7 +46,7 @@ func GetMissingEvents( } var eventsResponse api.QueryMissingEventsResponse - if err := query.QueryMissingEvents( + if err := rsAPI.QueryMissingEvents( httpReq.Context(), &api.QueryMissingEventsRequest{ EarliestEvents: gme.EarliestEvents, LatestEvents: gme.LatestEvents, @@ -60,9 +60,14 @@ func GetMissingEvents( } eventsResponse.Events = filterEvents(eventsResponse.Events, gme.MinDepth, roomID) + + resp := gomatrixserverlib.RespMissingEvents{ + Events: gomatrixserverlib.UnwrapEventHeaders(eventsResponse.Events), + } + return util.JSONResponse{ Code: http.StatusOK, - JSON: eventsResponse, + JSON: resp, } } diff --git a/federationapi/routing/profile.go b/federationapi/routing/profile.go index 9a81c1b33..a6180ae6d 100644 --- a/federationapi/routing/profile.go +++ b/federationapi/routing/profile.go @@ -18,11 +18,10 @@ import ( "fmt" "net/http" - appserviceAPI "github.com/matrix-org/dendrite/appservice/api" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/eventutil" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -30,9 +29,8 @@ import ( // GetProfile implements GET /_matrix/federation/v1/query/profile func GetProfile( httpReq *http.Request, - accountDB accounts.Database, + userAPI userapi.UserInternalAPI, cfg *config.Dendrite, - asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { userID, field := httpReq.FormValue("user_id"), httpReq.FormValue("field") @@ -60,9 +58,12 @@ func GetProfile( } } - profile, err := appserviceAPI.RetrieveUserProfile(httpReq.Context(), userID, asAPI, accountDB) + var profileRes userapi.QueryProfileResponse + err = userAPI.QueryProfile(httpReq.Context(), &userapi.QueryProfileRequest{ + UserID: userID, + }, &profileRes) if err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("appserviceAPI.RetrieveUserProfile failed") + util.GetLogger(httpReq.Context()).WithError(err).Error("userAPI.QueryProfile failed") return jsonerror.InternalServerError() } @@ -72,21 +73,21 @@ func GetProfile( if field != "" { switch field { case "displayname": - res = common.DisplayName{ - DisplayName: profile.DisplayName, + res = eventutil.DisplayName{ + DisplayName: profileRes.DisplayName, } case "avatar_url": - res = common.AvatarURL{ - AvatarURL: profile.AvatarURL, + res = eventutil.AvatarURL{ + AvatarURL: profileRes.AvatarURL, } default: code = http.StatusBadRequest res = jsonerror.InvalidArgumentValue("The request body did not contain an allowed value of argument 'field'. Allowed values are either: 'avatar_url', 'displayname'.") } } else { - res = common.ProfileResponse{ - AvatarURL: profile.AvatarURL, - DisplayName: profile.DisplayName, + res = eventutil.ProfileResponse{ + AvatarURL: profileRes.AvatarURL, + DisplayName: profileRes.DisplayName, } } diff --git a/federationapi/routing/query.go b/federationapi/routing/query.go index 7cb50e525..39fd6d2ee 100644 --- a/federationapi/routing/query.go +++ b/federationapi/routing/query.go @@ -19,8 +19,8 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/common/config" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/config" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" @@ -32,8 +32,8 @@ func RoomAliasToID( httpReq *http.Request, federation *gomatrixserverlib.FederationClient, cfg *config.Dendrite, - aliasAPI roomserverAPI.RoomserverAliasAPI, - senderAPI federationSenderAPI.FederationSenderQueryAPI, + rsAPI roomserverAPI.RoomserverInternalAPI, + senderAPI federationSenderAPI.FederationSenderInternalAPI, ) util.JSONResponse { roomAlias := httpReq.FormValue("room_alias") if roomAlias == "" { @@ -55,7 +55,7 @@ func RoomAliasToID( if domain == cfg.Matrix.ServerName { queryReq := roomserverAPI.GetRoomIDForAliasRequest{Alias: roomAlias} var queryRes roomserverAPI.GetRoomIDForAliasResponse - if err = aliasAPI.GetRoomIDForAlias(httpReq.Context(), &queryReq, &queryRes); err != nil { + if err = rsAPI.GetRoomIDForAlias(httpReq.Context(), &queryReq, &queryRes); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed") return jsonerror.InternalServerError() } diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 83bac5550..645f397de 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -18,48 +18,50 @@ import ( "net/http" "github.com/gorilla/mux" - appserviceAPI "github.com/matrix-org/dendrite/appservice/api" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" - "github.com/matrix-org/dendrite/clientapi/producers" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/httputil" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) const ( - pathPrefixV2Keys = "/_matrix/key/v2" - pathPrefixV1Federation = "/_matrix/federation/v1" - pathPrefixV2Federation = "/_matrix/federation/v2" + pathPrefixV2Keys = "/key/v2" + pathPrefixV1Federation = "/federation/v1" + pathPrefixV2Federation = "/federation/v2" ) // Setup registers HTTP handlers with the given ServeMux. +// The provided publicAPIMux MUST have `UseEncodedPath()` enabled or else routes will incorrectly +// path unescape twice (once from the router, once from MakeFedAPI). We need to have this enabled +// so we can decode paths like foo/bar%2Fbaz as [foo, bar/baz] - by default it will decode to [foo, bar, baz] // // Due to Setup being used to call many other functions, a gocyclo nolint is // applied: // nolint: gocyclo func Setup( - apiMux *mux.Router, + publicAPIMux *mux.Router, cfg *config.Dendrite, - query roomserverAPI.RoomserverQueryAPI, - aliasAPI roomserverAPI.RoomserverAliasAPI, - asAPI appserviceAPI.AppServiceQueryAPI, - producer *producers.RoomserverProducer, - eduProducer *producers.EDUServerProducer, - federationSenderAPI federationSenderAPI.FederationSenderQueryAPI, - keys gomatrixserverlib.KeyRing, + rsAPI roomserverAPI.RoomserverInternalAPI, + eduAPI eduserverAPI.EDUServerInputAPI, + fsAPI federationSenderAPI.FederationSenderInternalAPI, + keys gomatrixserverlib.JSONVerifier, federation *gomatrixserverlib.FederationClient, - accountDB accounts.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, ) { - v2keysmux := apiMux.PathPrefix(pathPrefixV2Keys).Subrouter() - v1fedmux := apiMux.PathPrefix(pathPrefixV1Federation).Subrouter() - v2fedmux := apiMux.PathPrefix(pathPrefixV2Federation).Subrouter() + v2keysmux := publicAPIMux.PathPrefix(pathPrefixV2Keys).Subrouter() + v1fedmux := publicAPIMux.PathPrefix(pathPrefixV1Federation).Subrouter() + v2fedmux := publicAPIMux.PathPrefix(pathPrefixV2Federation).Subrouter() - localKeys := common.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse { + wakeup := &httputil.FederationWakeups{ + FsAPI: fsAPI, + } + + localKeys := httputil.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse { return LocalKeys(cfg) }) @@ -71,140 +73,107 @@ func Setup( v2keysmux.Handle("/server/", localKeys).Methods(http.MethodGet) v2keysmux.Handle("/server", localKeys).Methods(http.MethodGet) - v1fedmux.Handle("/send/{txnID}", common.MakeFedAPI( - "federation_send", cfg.Matrix.ServerName, keys, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(httpReq)) - if err != nil { - return util.ErrorResponse(err) - } + v1fedmux.Handle("/send/{txnID}", httputil.MakeFedAPI( + "federation_send", cfg.Matrix.ServerName, keys, wakeup, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return Send( httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]), - cfg, query, producer, eduProducer, keys, federation, + cfg, rsAPI, eduAPI, keys, federation, ) }, )).Methods(http.MethodPut, http.MethodOptions) - v2fedmux.Handle("/invite/{roomID}/{eventID}", common.MakeFedAPI( - "federation_invite", cfg.Matrix.ServerName, keys, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(httpReq)) - if err != nil { - return util.ErrorResponse(err) - } + v2fedmux.Handle("/invite/{roomID}/{eventID}", httputil.MakeFedAPI( + "federation_invite", cfg.Matrix.ServerName, keys, wakeup, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return Invite( httpReq, request, vars["roomID"], vars["eventID"], - cfg, producer, keys, + cfg, rsAPI, keys, ) }, )).Methods(http.MethodPut, http.MethodOptions) - v1fedmux.Handle("/3pid/onbind", common.MakeExternalAPI("3pid_onbind", + v1fedmux.Handle("/3pid/onbind", httputil.MakeExternalAPI("3pid_onbind", func(req *http.Request) util.JSONResponse { - return CreateInvitesFrom3PIDInvites(req, query, asAPI, cfg, producer, federation, accountDB) + return CreateInvitesFrom3PIDInvites(req, rsAPI, cfg, federation, userAPI) }, )).Methods(http.MethodPost, http.MethodOptions) - v1fedmux.Handle("/exchange_third_party_invite/{roomID}", common.MakeFedAPI( - "exchange_third_party_invite", cfg.Matrix.ServerName, keys, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(httpReq)) - if err != nil { - return util.ErrorResponse(err) - } + v1fedmux.Handle("/exchange_third_party_invite/{roomID}", httputil.MakeFedAPI( + "exchange_third_party_invite", cfg.Matrix.ServerName, keys, wakeup, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return ExchangeThirdPartyInvite( - httpReq, request, vars["roomID"], query, cfg, federation, producer, + httpReq, request, vars["roomID"], rsAPI, cfg, federation, ) }, )).Methods(http.MethodPut, http.MethodOptions) - v1fedmux.Handle("/event/{eventID}", common.MakeFedAPI( - "federation_get_event", cfg.Matrix.ServerName, keys, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(httpReq)) - if err != nil { - return util.ErrorResponse(err) - } + v1fedmux.Handle("/event/{eventID}", httputil.MakeFedAPI( + "federation_get_event", cfg.Matrix.ServerName, keys, wakeup, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetEvent( - httpReq.Context(), request, query, vars["eventID"], + httpReq.Context(), request, rsAPI, vars["eventID"], cfg.Matrix.ServerName, ) }, )).Methods(http.MethodGet) - v1fedmux.Handle("/state/{roomID}", common.MakeFedAPI( - "federation_get_state", cfg.Matrix.ServerName, keys, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(httpReq)) - if err != nil { - return util.ErrorResponse(err) - } + v1fedmux.Handle("/state/{roomID}", httputil.MakeFedAPI( + "federation_get_state", cfg.Matrix.ServerName, keys, wakeup, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetState( - httpReq.Context(), request, query, vars["roomID"], + httpReq.Context(), request, rsAPI, vars["roomID"], ) }, )).Methods(http.MethodGet) - v1fedmux.Handle("/state_ids/{roomID}", common.MakeFedAPI( - "federation_get_state_ids", cfg.Matrix.ServerName, keys, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(httpReq)) - if err != nil { - return util.ErrorResponse(err) - } + v1fedmux.Handle("/state_ids/{roomID}", httputil.MakeFedAPI( + "federation_get_state_ids", cfg.Matrix.ServerName, keys, wakeup, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetStateIDs( - httpReq.Context(), request, query, vars["roomID"], + httpReq.Context(), request, rsAPI, vars["roomID"], ) }, )).Methods(http.MethodGet) - v1fedmux.Handle("/event_auth/{roomID}/{eventID}", common.MakeFedAPI( - "federation_get_event_auth", cfg.Matrix.ServerName, keys, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { - vars := mux.Vars(httpReq) + v1fedmux.Handle("/event_auth/{roomID}/{eventID}", httputil.MakeFedAPI( + "federation_get_event_auth", cfg.Matrix.ServerName, keys, wakeup, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetEventAuth( - httpReq.Context(), request, query, vars["roomID"], vars["eventID"], + httpReq.Context(), request, rsAPI, vars["roomID"], vars["eventID"], ) }, )).Methods(http.MethodGet) - v1fedmux.Handle("/query/directory", common.MakeFedAPI( - "federation_query_room_alias", cfg.Matrix.ServerName, keys, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { + v1fedmux.Handle("/query/directory", httputil.MakeFedAPI( + "federation_query_room_alias", cfg.Matrix.ServerName, keys, wakeup, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return RoomAliasToID( - httpReq, federation, cfg, aliasAPI, federationSenderAPI, + httpReq, federation, cfg, rsAPI, fsAPI, ) }, )).Methods(http.MethodGet) - v1fedmux.Handle("/query/profile", common.MakeFedAPI( - "federation_query_profile", cfg.Matrix.ServerName, keys, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { + v1fedmux.Handle("/query/profile", httputil.MakeFedAPI( + "federation_query_profile", cfg.Matrix.ServerName, keys, wakeup, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetProfile( - httpReq, accountDB, cfg, asAPI, + httpReq, userAPI, cfg, ) }, )).Methods(http.MethodGet) - v1fedmux.Handle("/user/devices/{userID}", common.MakeFedAPI( - "federation_user_devices", cfg.Matrix.ServerName, keys, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(httpReq)) - if err != nil { - return util.ErrorResponse(err) - } + v1fedmux.Handle("/user/devices/{userID}", httputil.MakeFedAPI( + "federation_user_devices", cfg.Matrix.ServerName, keys, wakeup, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetUserDevices( - httpReq, deviceDB, vars["userID"], + httpReq, userAPI, vars["userID"], ) }, )).Methods(http.MethodGet) - v1fedmux.Handle("/make_join/{roomID}/{eventID}", common.MakeFedAPI( - "federation_make_join", cfg.Matrix.ServerName, keys, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(httpReq)) - if err != nil { - return util.ErrorResponse(err) - } + v1fedmux.Handle("/make_join/{roomID}/{eventID}", httputil.MakeFedAPI( + "federation_make_join", cfg.Matrix.ServerName, keys, wakeup, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { roomID := vars["roomID"] eventID := vars["eventID"] queryVars := httpReq.URL.Query() @@ -222,82 +191,88 @@ func Setup( remoteVersions = append(remoteVersions, gomatrixserverlib.RoomVersionV1) } return MakeJoin( - httpReq, request, cfg, query, roomID, eventID, remoteVersions, + httpReq, request, cfg, rsAPI, roomID, eventID, remoteVersions, ) }, )).Methods(http.MethodGet) - v2fedmux.Handle("/send_join/{roomID}/{eventID}", common.MakeFedAPI( - "federation_send_join", cfg.Matrix.ServerName, keys, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(httpReq)) - if err != nil { - return util.ErrorResponse(err) + v1fedmux.Handle("/send_join/{roomID}/{eventID}", httputil.MakeFedAPI( + "federation_send_join", cfg.Matrix.ServerName, keys, wakeup, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + roomID := vars["roomID"] + eventID := vars["eventID"] + res := SendJoin( + httpReq, request, cfg, rsAPI, keys, roomID, eventID, + ) + // not all responses get wrapped in [code, body] + var body interface{} + body = []interface{}{ + res.Code, res.JSON, } + jerr, ok := res.JSON.(*jsonerror.MatrixError) + if ok { + body = jerr + } + + return util.JSONResponse{ + Headers: res.Headers, + Code: res.Code, + JSON: body, + } + }, + )).Methods(http.MethodPut) + + v2fedmux.Handle("/send_join/{roomID}/{eventID}", httputil.MakeFedAPI( + "federation_send_join", cfg.Matrix.ServerName, keys, wakeup, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { roomID := vars["roomID"] eventID := vars["eventID"] return SendJoin( - httpReq, request, cfg, query, producer, keys, roomID, eventID, + httpReq, request, cfg, rsAPI, keys, roomID, eventID, ) }, )).Methods(http.MethodPut) - v1fedmux.Handle("/make_leave/{roomID}/{eventID}", common.MakeFedAPI( - "federation_make_leave", cfg.Matrix.ServerName, keys, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(httpReq)) - if err != nil { - return util.ErrorResponse(err) - } + v1fedmux.Handle("/make_leave/{roomID}/{eventID}", httputil.MakeFedAPI( + "federation_make_leave", cfg.Matrix.ServerName, keys, wakeup, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { roomID := vars["roomID"] eventID := vars["eventID"] return MakeLeave( - httpReq, request, cfg, query, roomID, eventID, + httpReq, request, cfg, rsAPI, roomID, eventID, ) }, )).Methods(http.MethodGet) - v2fedmux.Handle("/send_leave/{roomID}/{eventID}", common.MakeFedAPI( - "federation_send_leave", cfg.Matrix.ServerName, keys, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(httpReq)) - if err != nil { - return util.ErrorResponse(err) - } + v2fedmux.Handle("/send_leave/{roomID}/{eventID}", httputil.MakeFedAPI( + "federation_send_leave", cfg.Matrix.ServerName, keys, wakeup, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { roomID := vars["roomID"] eventID := vars["eventID"] return SendLeave( - httpReq, request, cfg, producer, keys, roomID, eventID, + httpReq, request, cfg, rsAPI, keys, roomID, eventID, ) }, )).Methods(http.MethodPut) - v1fedmux.Handle("/version", common.MakeExternalAPI( + v1fedmux.Handle("/version", httputil.MakeExternalAPI( "federation_version", func(httpReq *http.Request) util.JSONResponse { return Version() }, )).Methods(http.MethodGet) - v1fedmux.Handle("/get_missing_events/{roomID}", common.MakeFedAPI( - "federation_get_missing_events", cfg.Matrix.ServerName, keys, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(httpReq)) - if err != nil { - return util.ErrorResponse(err) - } - return GetMissingEvents(httpReq, request, query, vars["roomID"]) + v1fedmux.Handle("/get_missing_events/{roomID}", httputil.MakeFedAPI( + "federation_get_missing_events", cfg.Matrix.ServerName, keys, wakeup, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + return GetMissingEvents(httpReq, request, rsAPI, vars["roomID"]) }, )).Methods(http.MethodPost) - v1fedmux.Handle("/backfill/{roomID}", common.MakeFedAPI( - "federation_backfill", cfg.Matrix.ServerName, keys, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(httpReq)) - if err != nil { - return util.ErrorResponse(err) - } - return Backfill(httpReq, request, query, vars["roomID"], cfg) + v1fedmux.Handle("/backfill/{roomID}", httputil.MakeFedAPI( + "federation_backfill", cfg.Matrix.ServerName, keys, wakeup, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + return Backfill(httpReq, request, rsAPI, vars["roomID"], cfg) }, )).Methods(http.MethodGet) } diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 5a9766f81..53f951951 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -21,11 +21,12 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/clientapi/producers" - "github.com/matrix-org/dendrite/common/config" + eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" ) // Send implements /_matrix/federation/v1/send/{txnID} @@ -34,19 +35,19 @@ func Send( request *gomatrixserverlib.FederationRequest, txnID gomatrixserverlib.TransactionID, cfg *config.Dendrite, - query api.RoomserverQueryAPI, - producer *producers.RoomserverProducer, - eduProducer *producers.EDUServerProducer, - keys gomatrixserverlib.KeyRing, + rsAPI api.RoomserverInternalAPI, + eduAPI eduserverAPI.EDUServerInputAPI, + keys gomatrixserverlib.JSONVerifier, federation *gomatrixserverlib.FederationClient, ) util.JSONResponse { t := txnReq{ - context: httpReq.Context(), - query: query, - producer: producer, - eduProducer: eduProducer, - keys: keys, - federation: federation, + context: httpReq.Context(), + rsAPI: rsAPI, + eduAPI: eduAPI, + keys: keys, + federation: federation, + haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent), + newEvents: make(map[string]bool), } var txnEvents struct { @@ -60,6 +61,14 @@ func Send( JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), } } + // Transactions are limited in size; they can have at most 50 PDUs and 100 EDUs. + // https://matrix.org/docs/spec/server_server/latest#transactions + if len(txnEvents.PDUs) > 50 || len(txnEvents.EDUs) > 100 { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("max 50 pdus / 100 edus"), + } + } // TODO: Really we should have a function to convert FederationRequest to txnReq t.PDUs = txnEvents.PDUs @@ -70,76 +79,100 @@ func Send( util.GetLogger(httpReq.Context()).Infof("Received transaction %q containing %d PDUs, %d EDUs", txnID, len(t.PDUs), len(t.EDUs)) - resp, err := t.processTransaction() - switch err.(type) { - // No error? Great! Send back a 200. - case nil: - return util.JSONResponse{ - Code: http.StatusOK, - JSON: resp, - } - // Handle known error cases as we will return a 400 error for these. - case roomNotFoundError: - case unmarshalError: - case verifySigError: - // Handle unknown error cases. Sending 500 errors back should be a last - // resort as this can make other homeservers back off sending federation - // events. - default: - util.GetLogger(httpReq.Context()).WithError(err).Error("t.processTransaction failed") - return jsonerror.InternalServerError() + resp, jsonErr := t.processTransaction() + if jsonErr != nil { + util.GetLogger(httpReq.Context()).WithField("jsonErr", jsonErr).Error("t.processTransaction failed") + return *jsonErr } - // Return a 400 error for bad requests as fallen through from above. + + // https://matrix.org/docs/spec/server_server/r0.1.3#put-matrix-federation-v1-send-txnid + // Status code 200: + // The result of processing the transaction. The server is to use this response + // even in the event of one or more PDUs failing to be processed. return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(err.Error()), + Code: http.StatusOK, + JSON: resp, } } type txnReq struct { gomatrixserverlib.Transaction - context context.Context - query api.RoomserverQueryAPI - producer *producers.RoomserverProducer - eduProducer *producers.EDUServerProducer - keys gomatrixserverlib.KeyRing - federation *gomatrixserverlib.FederationClient + context context.Context + rsAPI api.RoomserverInternalAPI + eduAPI eduserverAPI.EDUServerInputAPI + keys gomatrixserverlib.JSONVerifier + federation txnFederationClient + // local cache of events for auth checks, etc - this may include events + // which the roomserver is unaware of. + haveEvents map[string]*gomatrixserverlib.HeaderedEvent + // new events which the roomserver does not know about + newEvents map[string]bool } -func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, error) { +// A subset of FederationClient functionality that txn requires. Useful for testing. +type txnFederationClient interface { + LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( + res gomatrixserverlib.RespState, err error, + ) + LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) + GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) + LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, + roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) +} + +func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONResponse) { results := make(map[string]gomatrixserverlib.PDUResult) - var pdus []gomatrixserverlib.HeaderedEvent + pdus := []gomatrixserverlib.HeaderedEvent{} for _, pdu := range t.PDUs { var header struct { RoomID string `json:"room_id"` } if err := json.Unmarshal(pdu, &header); err != nil { util.GetLogger(t.context).WithError(err).Warn("Transaction: Failed to extract room ID from event") - return nil, unmarshalError{err} + // We don't know the event ID at this point so we can't return the + // failure in the PDU results + continue } verReq := api.QueryRoomVersionForRoomRequest{RoomID: header.RoomID} verRes := api.QueryRoomVersionForRoomResponse{} - if err := t.query.QueryRoomVersionForRoom(t.context, &verReq, &verRes); err != nil { + if err := t.rsAPI.QueryRoomVersionForRoom(t.context, &verReq, &verRes); err != nil { util.GetLogger(t.context).WithError(err).Warn("Transaction: Failed to query room version for room", verReq.RoomID) - return nil, roomNotFoundError{verReq.RoomID} + // We don't know the event ID at this point so we can't return the + // failure in the PDU results + continue } event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, verRes.RoomVersion) if err != nil { - util.GetLogger(t.context).WithError(err).Warnf("Transaction: Failed to parse event JSON of event %q", event.EventID()) - return nil, unmarshalError{err} + if _, ok := err.(gomatrixserverlib.BadJSONError); ok { + // Room version 6 states that homeservers should strictly enforce canonical JSON + // on PDUs. + // + // This enforces that the entire transaction is rejected if a single bad PDU is + // sent. It is unclear if this is the correct behaviour or not. + // + // See https://github.com/matrix-org/synapse/issues/7543 + return nil, &util.JSONResponse{ + Code: 400, + JSON: jsonerror.BadJSON("PDU contains bad JSON"), + } + } + util.GetLogger(t.context).WithError(err).Warnf("Transaction: Failed to parse event JSON of event %s", string(pdu)) + continue } - if err := gomatrixserverlib.VerifyAllEventSignatures(t.context, []gomatrixserverlib.Event{event}, t.keys); err != nil { + if err = gomatrixserverlib.VerifyAllEventSignatures(t.context, []gomatrixserverlib.Event{event}, t.keys); err != nil { util.GetLogger(t.context).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) - return nil, verifySigError{event.EventID(), err} + results[event.EventID()] = gomatrixserverlib.PDUResult{ + Error: err.Error(), + } + continue } pdus = append(pdus, event.Headered(verRes.RoomVersion)) } // Process the events. for _, e := range pdus { - err := t.processEvent(e.Unwrap()) - if err != nil { + if err := t.processEvent(e.Unwrap(), true); err != nil { // If the error is due to the event itself being bad then we skip // it and move onto the next event. We report an error so that the // sender knows that we have skipped processing it. @@ -155,18 +188,26 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, error) { // receive another event referencing it. // If we bail and stop processing then we risk wedging incoming // transactions from that server forever. - switch err.(type) { - case roomNotFoundError: - case *gomatrixserverlib.NotAllowed: - default: + if isProcessingErrorFatal(err) { // Any other error should be the result of a temporary error in // our server so we should bail processing the transaction entirely. - return nil, err + util.GetLogger(t.context).Warnf("Processing %s failed fatally: %s", e.EventID(), err) + jsonErr := util.ErrorResponse(err) + return nil, &jsonErr + } else { + // Auth errors mean the event is 'rejected' which have to be silent to appease sytest + _, rejected := err.(*gomatrixserverlib.NotAllowed) + errMsg := err.Error() + if rejected { + errMsg = "" + } + util.GetLogger(t.context).WithError(err).WithField("event_id", e.EventID()).WithField("rejected", rejected).Warn( + "Failed to process incoming federation event, skipping", + ) + results[e.EventID()] = gomatrixserverlib.PDUResult{ + Error: errMsg, + } } - results[e.EventID()] = gomatrixserverlib.PDUResult{ - Error: err.Error(), - } - util.GetLogger(t.context).WithError(err).WithField("event_id", e.EventID()).Warn("Failed to process incoming federation event, skipping it.") } else { results[e.EventID()] = gomatrixserverlib.PDUResult{} } @@ -177,6 +218,25 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, error) { return &gomatrixserverlib.RespSend{PDUs: results}, nil } +// isProcessingErrorFatal returns true if the error is really bad and +// we should stop processing the transaction, and returns false if it +// is just some less serious error about a specific event. +func isProcessingErrorFatal(err error) bool { + switch err.(type) { + case roomNotFoundError: + case *gomatrixserverlib.NotAllowed: + case missingPrevEventsError: + default: + switch err { + case context.Canceled: + case context.DeadlineExceeded: + default: + return true + } + } + return false +} + type roomNotFoundError struct { roomID string } @@ -187,12 +247,30 @@ type verifySigError struct { eventID string err error } +type missingPrevEventsError struct { + eventID string + err error +} func (e roomNotFoundError) Error() string { return fmt.Sprintf("room %q not found", e.roomID) } func (e unmarshalError) Error() string { return fmt.Sprintf("unable to parse event: %s", e.err) } func (e verifySigError) Error() string { return fmt.Sprintf("unable to verify signature of event %q: %s", e.eventID, e.err) } +func (e missingPrevEventsError) Error() string { + return fmt.Sprintf("unable to get prev_events for event %q: %s", e.eventID, e.err) +} + +func (t *txnReq) haveEventIDs() map[string]bool { + result := make(map[string]bool, len(t.haveEvents)) + for eventID := range t.haveEvents { + if t.newEvents[eventID] { + continue + } + result[eventID] = true + } + return result +} func (t *txnReq) processEDUs(edus []gomatrixserverlib.EDU) { for _, e := range edus { @@ -208,16 +286,35 @@ func (t *txnReq) processEDUs(edus []gomatrixserverlib.EDU) { util.GetLogger(t.context).WithError(err).Error("Failed to unmarshal typing event") continue } - if err := t.eduProducer.SendTyping(t.context, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { + if err := eduserverAPI.SendTyping(t.context, t.eduAPI, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { util.GetLogger(t.context).WithError(err).Error("Failed to send typing event to edu server") } + case gomatrixserverlib.MDirectToDevice: + // https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema + var directPayload gomatrixserverlib.ToDeviceMessage + if err := json.Unmarshal(e.Content, &directPayload); err != nil { + util.GetLogger(t.context).WithError(err).Error("Failed to unmarshal send-to-device events") + continue + } + for userID, byUser := range directPayload.Messages { + for deviceID, message := range byUser { + // TODO: check that the user and the device actually exist here + if err := eduserverAPI.SendToDevice(t.context, t.eduAPI, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil { + util.GetLogger(t.context).WithError(err).WithFields(logrus.Fields{ + "sender": directPayload.Sender, + "user_id": userID, + "device_id": deviceID, + }).Error("Failed to send send-to-device event to edu server") + } + } + } default: util.GetLogger(t.context).WithField("type", e.Type).Warn("unhandled edu") } } } -func (t *txnReq) processEvent(e gomatrixserverlib.Event) error { +func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) error { prevEventIDs := e.PrevEventIDs() // Fetch the state needed to authenticate the event. @@ -228,7 +325,7 @@ func (t *txnReq) processEvent(e gomatrixserverlib.Event) error { StateToFetch: needed.Tuples(), } var stateResp api.QueryStateAfterEventsResponse - if err := t.query.QueryStateAfterEvents(t.context, &stateReq, &stateResp); err != nil { + if err := t.rsAPI.QueryStateAfterEvents(t.context, &stateReq, &stateResp); err != nil { return err } @@ -243,24 +340,17 @@ func (t *txnReq) processEvent(e gomatrixserverlib.Event) error { } if !stateResp.PrevEventsExist { - return t.processEventWithMissingState(e, stateResp.RoomVersion) + return t.processEventWithMissingState(e, stateResp.RoomVersion, isInboundTxn) } // Check that the event is allowed by the state at the event. - var events []gomatrixserverlib.Event - for _, headeredEvent := range stateResp.StateEvents { - events = append(events, headeredEvent.Unwrap()) - } - if err := checkAllowedByState(e, events); err != nil { + if err := checkAllowedByState(e, gomatrixserverlib.UnwrapEventHeaders(stateResp.StateEvents)); err != nil { return err } - // TODO: Check that the roomserver has a copy of all of the auth_events. - // TODO: Check that the event is allowed by its auth_events. - // pass the event to the roomserver - _, err := t.producer.SendEvents( - t.context, + _, err := api.SendEvents( + t.context, t.rsAPI, []gomatrixserverlib.HeaderedEvent{ e.Headered(stateResp.RoomVersion), }, @@ -281,7 +371,7 @@ func checkAllowedByState(e gomatrixserverlib.Event, stateEvents []gomatrixserver return gomatrixserverlib.Allowed(e, &authUsingState) } -func (t *txnReq) processEventWithMissingState(e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) error { +func (t *txnReq) processEventWithMissingState(e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, isInboundTxn bool) error { // We are missing the previous events for this events. // This means that there is a gap in our view of the history of the // room. There two ways that we can handle such a gap: @@ -296,38 +386,440 @@ func (t *txnReq) processEventWithMissingState(e gomatrixserverlib.Event, roomVer // event ids and then use /event to fetch the individual events. // However not all version of synapse support /state_ids so you may // need to fallback to /state. - // TODO: Attempt to fill in the gap using /get_missing_events - // TODO: Attempt to fetch the state using /state_ids and /events - state, err := t.federation.LookupState(t.context, t.Origin, e.RoomID(), e.EventID(), roomVersion) + + // Attempt to fill in the gap using /get_missing_events + // This will either: + // - fill in the gap completely then process event `e` returning no backwards extremity + // - fail to fill in the gap and tell us to terminate the transaction err=not nil + // - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction + backwardsExtremity, err := t.getMissingEvents(e, roomVersion, isInboundTxn) if err != nil { return err } - // Check that the returned state is valid. - if err := state.Check(t.context, t.keys); err != nil { + if backwardsExtremity == nil { + // we filled in the gap! + return nil + } + + // at this point we know we're going to have a gap: we need to work out the room state at the new backwards extremity. + // security: we have to do state resolution on the new backwards extremity (TODO: WHY) + // Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query + // the state AFTER all the prev_events for this event, then mix in our current room state and apply state resolution + // to that to get the state before the event. + var states []*gomatrixserverlib.RespState + needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{*backwardsExtremity}).Tuples() + for _, prevEventID := range backwardsExtremity.PrevEventIDs() { + var prevState *gomatrixserverlib.RespState + prevState, err = t.lookupStateAfterEvent(roomVersion, backwardsExtremity.RoomID(), prevEventID, needed) + if err != nil { + util.GetLogger(t.context).WithError(err).Errorf("Failed to lookup state after prev_event: %s", prevEventID) + return err + } + states = append(states, prevState) + } + // mix in the current room state + currState, err := t.lookupCurrentState(backwardsExtremity) + if err != nil { + util.GetLogger(t.context).WithError(err).Errorf("Failed to lookup current room state") return err } - // Check that the event is allowed by the state. -retryAllowedState: - if err := checkAllowedByState(e, state.StateEvents); err != nil { - switch missing := err.(type) { - case gomatrixserverlib.MissingAuthEventError: - // An auth event was missing so let's look up that event over federation - for _, s := range state.StateEvents { - if s.EventID() != missing.AuthEventID { - continue - } - err = t.processEventWithMissingState(s, roomVersion) - // If there was no error retrieving the event from federation then - // we assume that it succeeded, so retry the original state check - if err == nil { - goto retryAllowedState - } - } - default: - } + states = append(states, currState) + resolvedState, err := t.resolveStatesAndCheck(roomVersion, states, backwardsExtremity) + if err != nil { + util.GetLogger(t.context).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID()) return err } - // pass the event along with the state to the roomserver - return t.producer.SendEventWithState(t.context, state, e.Headered(roomVersion)) + // pass the event along with the state to the roomserver using a background context so we don't + // needlessly expire + return api.SendEventWithState(context.Background(), t.rsAPI, resolvedState, e.Headered(roomVersion), t.haveEventIDs()) +} + +// lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) +// added into the mix. +func (t *txnReq) lookupStateAfterEvent(roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, needed []gomatrixserverlib.StateKeyTuple) (*gomatrixserverlib.RespState, error) { + // try doing all this locally before we resort to querying federation + respState := t.lookupStateAfterEventLocally(roomID, eventID, needed) + if respState != nil { + return respState, nil + } + + respState, err := t.lookupStateBeforeEvent(roomVersion, roomID, eventID) + if err != nil { + return nil, err + } + + // fetch the event we're missing and add it to the pile + h, err := t.lookupEvent(roomVersion, eventID, false) + if err != nil { + return nil, err + } + t.haveEvents[h.EventID()] = h + if h.StateKey() != nil { + addedToState := false + for i := range respState.StateEvents { + se := respState.StateEvents[i] + if se.Type() == h.Type() && se.StateKeyEquals(*h.StateKey()) { + respState.StateEvents[i] = h.Unwrap() + addedToState = true + break + } + } + if !addedToState { + respState.StateEvents = append(respState.StateEvents, h.Unwrap()) + } + } + + return respState, nil +} + +func (t *txnReq) lookupStateAfterEventLocally(roomID, eventID string, needed []gomatrixserverlib.StateKeyTuple) *gomatrixserverlib.RespState { + var res api.QueryStateAfterEventsResponse + err := t.rsAPI.QueryStateAfterEvents(t.context, &api.QueryStateAfterEventsRequest{ + RoomID: roomID, + PrevEventIDs: []string{eventID}, + StateToFetch: needed, + }, &res) + if err != nil || !res.PrevEventsExist { + util.GetLogger(t.context).WithError(err).Warnf("failed to query state after %s locally", eventID) + return nil + } + for i, ev := range res.StateEvents { + t.haveEvents[ev.EventID()] = &res.StateEvents[i] + } + var authEvents []gomatrixserverlib.Event + missingAuthEvents := make(map[string]bool) + for _, ev := range res.StateEvents { + for _, ae := range ev.AuthEventIDs() { + aev, ok := t.haveEvents[ae] + if ok { + authEvents = append(authEvents, aev.Unwrap()) + } else { + missingAuthEvents[ae] = true + } + } + } + // QueryStateAfterEvents does not return the auth events, so fetch them now. We know the roomserver has them else it wouldn't + // have stored the event. + var missingEventList []string + for evID := range missingAuthEvents { + missingEventList = append(missingEventList, evID) + } + queryReq := api.QueryEventsByIDRequest{ + EventIDs: missingEventList, + } + util.GetLogger(t.context).Infof("Fetching missing auth events: %v", missingEventList) + var queryRes api.QueryEventsByIDResponse + if err = t.rsAPI.QueryEventsByID(t.context, &queryReq, &queryRes); err != nil { + return nil + } + for i := range queryRes.Events { + evID := queryRes.Events[i].EventID() + t.haveEvents[evID] = &queryRes.Events[i] + authEvents = append(authEvents, queryRes.Events[i].Unwrap()) + } + + evs := gomatrixserverlib.UnwrapEventHeaders(res.StateEvents) + return &gomatrixserverlib.RespState{ + StateEvents: evs, + AuthEvents: authEvents, + } +} + +func (t *txnReq) lookupCurrentState(newEvent *gomatrixserverlib.Event) (*gomatrixserverlib.RespState, error) { + // Ask the roomserver for information about this room + queryReq := api.QueryLatestEventsAndStateRequest{ + RoomID: newEvent.RoomID(), + StateToFetch: gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{*newEvent}).Tuples(), + } + var queryRes api.QueryLatestEventsAndStateResponse + if err := t.rsAPI.QueryLatestEventsAndState(t.context, &queryReq, &queryRes); err != nil { + return nil, fmt.Errorf("lookupCurrentState rsAPI.QueryLatestEventsAndState: %w", err) + } + evs := gomatrixserverlib.UnwrapEventHeaders(queryRes.StateEvents) + return &gomatrixserverlib.RespState{ + StateEvents: evs, + AuthEvents: evs, + }, nil +} + +// lookuptStateBeforeEvent returns the room state before the event e, which is just /state_ids and/or /state depending on what +// the server supports. +func (t *txnReq) lookupStateBeforeEvent(roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) ( + respState *gomatrixserverlib.RespState, err error) { + + util.GetLogger(t.context).Infof("lookupStateBeforeEvent %s", eventID) + + // Attempt to fetch the missing state using /state_ids and /events + respState, err = t.lookupMissingStateViaStateIDs(roomID, eventID, roomVersion) + if err != nil { + // Fallback to /state + util.GetLogger(t.context).WithError(err).Warn("lookupStateBeforeEvent failed to /state_ids, falling back to /state") + respState, err = t.lookupMissingStateViaState(roomID, eventID, roomVersion) + } + return +} + +func (t *txnReq) resolveStatesAndCheck(roomVersion gomatrixserverlib.RoomVersion, states []*gomatrixserverlib.RespState, backwardsExtremity *gomatrixserverlib.Event) (*gomatrixserverlib.RespState, error) { + var authEventList []gomatrixserverlib.Event + var stateEventList []gomatrixserverlib.Event + for _, state := range states { + authEventList = append(authEventList, state.AuthEvents...) + stateEventList = append(stateEventList, state.StateEvents...) + } + resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(roomVersion, stateEventList, authEventList) + if err != nil { + return nil, err + } + // apply the current event +retryAllowedState: + if err = checkAllowedByState(*backwardsExtremity, resolvedStateEvents); err != nil { + switch missing := err.(type) { + case gomatrixserverlib.MissingAuthEventError: + h, err2 := t.lookupEvent(roomVersion, missing.AuthEventID, true) + if err2 != nil { + return nil, fmt.Errorf("missing auth event %s and failed to look it up: %w", missing.AuthEventID, err2) + } + util.GetLogger(t.context).Infof("fetched event %s", missing.AuthEventID) + resolvedStateEvents = append(resolvedStateEvents, h.Unwrap()) + goto retryAllowedState + default: + } + return nil, err + } + return &gomatrixserverlib.RespState{ + AuthEvents: authEventList, + StateEvents: resolvedStateEvents, + }, nil +} + +// getMissingEvents returns a nil backwardsExtremity if missing events were fetched and handled, else returns the new backwards extremity which we should +// begin from. Returns an error only if we should terminate the transaction which initiated /get_missing_events +// This function recursively calls txnReq.processEvent with the missing events, which will be processed before this function returns. +// This means that we may recursively call this function, as we spider back up prev_events to the min depth. +func (t *txnReq) getMissingEvents(e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, isInboundTxn bool) (backwardsExtremity *gomatrixserverlib.Event, err error) { + if !isInboundTxn { + // we've recursed here, so just take a state snapshot please! + return &e, nil + } + logger := util.GetLogger(t.context).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) + needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{e}) + // query latest events (our trusted forward extremities) + req := api.QueryLatestEventsAndStateRequest{ + RoomID: e.RoomID(), + StateToFetch: needed.Tuples(), + } + var res api.QueryLatestEventsAndStateResponse + if err = t.rsAPI.QueryLatestEventsAndState(t.context, &req, &res); err != nil { + logger.WithError(err).Warn("Failed to query latest events") + return &e, nil + } + latestEvents := make([]string, len(res.LatestEvents)) + for i := range res.LatestEvents { + latestEvents[i] = res.LatestEvents[i].EventID + } + // this server just sent us an event for which we do not know its prev_events - ask that server for those prev_events. + minDepth := int(res.Depth) - 20 + if minDepth < 0 { + minDepth = 0 + } + missingResp, err := t.federation.LookupMissingEvents(t.context, t.Origin, e.RoomID(), gomatrixserverlib.MissingEvents{ + Limit: 20, + // synapse uses the min depth they've ever seen in that room + MinDepth: minDepth, + // The latest event IDs that the sender already has. These are skipped when retrieving the previous events of latest_events. + EarliestEvents: latestEvents, + // The event IDs to retrieve the previous events for. + LatestEvents: []string{e.EventID()}, + }, roomVersion) + + // security: how we handle failures depends on whether or not this event will become the new forward extremity for the room. + // There's 2 scenarios to consider: + // - Case A: We got pushed an event and are now fetching missing prev_events. (isInboundTxn=true) + // - Case B: We are fetching missing prev_events already and now fetching some more (isInboundTxn=false) + // In Case B, we know for sure that the event we are currently processing will not become the new forward extremity for the room, + // as it was called in response to an inbound txn which had it as a prev_event. + // In Case A, the event is a forward extremity, and could eventually become the _only_ forward extremity in the room. This is bad + // because it means we would trust the state at that event to be the state for the entire room, and allows rooms to be hijacked. + // https://github.com/matrix-org/synapse/pull/3456 + // https://github.com/matrix-org/synapse/blob/229eb81498b0fe1da81e9b5b333a0285acde9446/synapse/handlers/federation.py#L335 + // For now, we do not allow Case B, so reject the event. + if err != nil { + logger.WithError(err).Errorf( + "%s pushed us an event but couldn't give us details about prev_events via /get_missing_events - dropping this event until it can", + t.Origin, + ) + return nil, missingPrevEventsError{ + eventID: e.EventID(), + err: err, + } + } + logger.Infof("get_missing_events returned %d events", len(missingResp.Events)) + + // topologically sort and sanity check that we are making forward progress + newEvents := gomatrixserverlib.ReverseTopologicalOrdering(missingResp.Events, gomatrixserverlib.TopologicalOrderByPrevEvents) + shouldHaveSomeEventIDs := e.PrevEventIDs() + hasPrevEvent := false +Event: + for _, pe := range shouldHaveSomeEventIDs { + for _, ev := range newEvents { + if ev.EventID() == pe { + hasPrevEvent = true + break Event + } + } + } + if !hasPrevEvent { + err = fmt.Errorf("called /get_missing_events but server %s didn't return any prev_events with IDs %v", t.Origin, shouldHaveSomeEventIDs) + logger.WithError(err).Errorf( + "%s pushed us an event but couldn't give us details about prev_events via /get_missing_events - dropping this event until it can", + t.Origin, + ) + return nil, missingPrevEventsError{ + eventID: e.EventID(), + err: err, + } + } + // process the missing events then the event which started this whole thing + for _, ev := range append(newEvents, e) { + err := t.processEvent(ev, false) + if err != nil { + return nil, err + } + } + + // we processed everything! + return nil, nil +} + +func (t *txnReq) lookupMissingStateViaState(roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( + respState *gomatrixserverlib.RespState, err error) { + state, err := t.federation.LookupState(t.context, t.Origin, roomID, eventID, roomVersion) + if err != nil { + return nil, err + } + // Check that the returned state is valid. + if err := state.Check(t.context, t.keys, nil); err != nil { + return nil, err + } + return &state, nil +} + +func (t *txnReq) lookupMissingStateViaStateIDs(roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( + *gomatrixserverlib.RespState, error) { + util.GetLogger(t.context).Infof("lookupMissingStateViaStateIDs %s", eventID) + // fetch the state event IDs at the time of the event + stateIDs, err := t.federation.LookupStateIDs(t.context, t.Origin, roomID, eventID) + if err != nil { + return nil, err + } + // work out which auth/state IDs are missing + wantIDs := append(stateIDs.StateEventIDs, stateIDs.AuthEventIDs...) + missing := make(map[string]bool) + var missingEventList []string + for _, sid := range wantIDs { + if _, ok := t.haveEvents[sid]; !ok { + if !missing[sid] { + missing[sid] = true + missingEventList = append(missingEventList, sid) + } + } + } + + // fetch as many as we can from the roomserver + queryReq := api.QueryEventsByIDRequest{ + EventIDs: missingEventList, + } + var queryRes api.QueryEventsByIDResponse + if err = t.rsAPI.QueryEventsByID(t.context, &queryReq, &queryRes); err != nil { + return nil, err + } + for i := range queryRes.Events { + evID := queryRes.Events[i].EventID() + t.haveEvents[evID] = &queryRes.Events[i] + if missing[evID] { + delete(missing, evID) + } + } + + util.GetLogger(t.context).WithFields(logrus.Fields{ + "missing": len(missing), + "event_id": eventID, + "room_id": roomID, + "total_state": len(stateIDs.StateEventIDs), + "total_auth_events": len(stateIDs.AuthEventIDs), + }).Info("Fetching missing state at event") + + for missingEventID := range missing { + var h *gomatrixserverlib.HeaderedEvent + h, err = t.lookupEvent(roomVersion, missingEventID, false) + if err != nil { + return nil, err + } + t.haveEvents[h.EventID()] = h + } + resp, err := t.createRespStateFromStateIDs(stateIDs) + return resp, err +} + +func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStateIDs) ( + *gomatrixserverlib.RespState, error) { + // create a RespState response using the response to /state_ids as a guide + respState := gomatrixserverlib.RespState{ + AuthEvents: make([]gomatrixserverlib.Event, len(stateIDs.AuthEventIDs)), + StateEvents: make([]gomatrixserverlib.Event, len(stateIDs.StateEventIDs)), + } + + for i := range stateIDs.StateEventIDs { + ev, ok := t.haveEvents[stateIDs.StateEventIDs[i]] + if !ok { + return nil, fmt.Errorf("missing state event %s", stateIDs.StateEventIDs[i]) + } + respState.StateEvents[i] = ev.Unwrap() + } + for i := range stateIDs.AuthEventIDs { + ev, ok := t.haveEvents[stateIDs.AuthEventIDs[i]] + if !ok { + return nil, fmt.Errorf("missing auth event %s", stateIDs.AuthEventIDs[i]) + } + respState.AuthEvents[i] = ev.Unwrap() + } + // We purposefully do not do auth checks on the returned events, as they will still + // be processed in the exact same way, just as a 'rejected' event + // TODO: Add a field to HeaderedEvent to indicate if the event is rejected. + return &respState, nil +} + +func (t *txnReq) lookupEvent(roomVersion gomatrixserverlib.RoomVersion, missingEventID string, localFirst bool) (*gomatrixserverlib.HeaderedEvent, error) { + if localFirst { + // fetch from the roomserver + queryReq := api.QueryEventsByIDRequest{ + EventIDs: []string{missingEventID}, + } + var queryRes api.QueryEventsByIDResponse + if err := t.rsAPI.QueryEventsByID(t.context, &queryReq, &queryRes); err != nil { + util.GetLogger(t.context).Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err) + } else if len(queryRes.Events) == 1 { + return &queryRes.Events[0], nil + } + } + txn, err := t.federation.GetEvent(t.context, t.Origin, missingEventID) + if err != nil || len(txn.PDUs) == 0 { + util.GetLogger(t.context).WithError(err).WithField("event_id", missingEventID).Warn("failed to get missing /event for event ID") + return nil, err + } + pdu := txn.PDUs[0] + var event gomatrixserverlib.Event + event, err = gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion) + if err != nil { + util.GetLogger(t.context).WithError(err).Warnf("Transaction: Failed to parse event JSON of event %q", event.EventID()) + return nil, unmarshalError{err} + } + if err = gomatrixserverlib.VerifyAllEventSignatures(t.context, []gomatrixserverlib.Event{event}, t.keys); err != nil { + util.GetLogger(t.context).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) + return nil, verifySigError{event.EventID(), err} + } + h := event.Headered(roomVersion) + t.newEvents[h.EventID()] = true + return &h, nil } diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go new file mode 100644 index 000000000..3f5d5f4e0 --- /dev/null +++ b/federationapi/routing/send_test.go @@ -0,0 +1,665 @@ +package routing + +import ( + "context" + "encoding/json" + "fmt" + "reflect" + "testing" + "time" + + eduAPI "github.com/matrix-org/dendrite/eduserver/api" + fsAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/test" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrixserverlib" +) + +const ( + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") + testDestination = gomatrixserverlib.ServerName("white.orchard") +) + +var ( + testRoomVersion = gomatrixserverlib.RoomVersionV1 + testData = []json.RawMessage{ + []byte(`{"auth_events":[],"content":{"creator":"@userid:kaer.morhen"},"depth":0,"event_id":"$0ok8ynDp7kjc95e3:kaer.morhen","hashes":{"sha256":"17kPoH+h0Dk4Omn7Sus0qMb6+oGcf+CZFEgDhv7UKWs"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"jP4a04f5/F10Pw95FPpdCyKAO44JOwUQ/MZOOeA/RTU1Dn+AHPMzGSaZnuGjRr/xQuADt+I3ctb5ZQfLKNzHDw"}},"state_key":"","type":"m.room.create"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"content":{"membership":"join"},"depth":1,"event_id":"$LEwEu0kxrtu5fOiS:kaer.morhen","hashes":{"sha256":"B7M88PhXf3vd1LaFtjQutFu4x/w7fHD28XKZ4sAsJTo"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"p2vqmuJn7ZBRImctSaKbXCAxCcBlIjPH9JHte1ouIUGy84gpu4eLipOvSBCLL26hXfC0Zrm4WUto6Hr+ohdrCg"}},"state_key":"@userid:kaer.morhen","type":"m.room.member"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"join_rule":"public"},"depth":2,"event_id":"$SMHlqUrNhhBBRLeN:kaer.morhen","hashes":{"sha256":"vIuJQvmMjrGxshAkj1SXe0C4RqvMbv4ZADDw9pFCWqQ"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"hBMsb3Qppo3RaqqAl4JyTgaiWEbW5hlckATky6PrHun+F3YM203TzG7w9clwuQU5F5pZoB1a6nw+to0hN90FAw"}},"state_key":"","type":"m.room.join_rules"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"history_visibility":"shared"},"depth":3,"event_id":"$6F1yGIbO0J7TM93h:kaer.morhen","hashes":{"sha256":"Mr23GKSlZW7UCCYLgOWawI2Sg6KIoMjUWO2TDenuOgw"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$SMHlqUrNhhBBRLeN:kaer.morhen",{"sha256":"SylzE8U02I+6eyEHgL+FlU0L5YdqrVp8OOlxKS9VQW0"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sHLKrFI3hKGrEJfpMVZSDS3LvLasQsy50CTsOwru9XTVxgRsPo6wozNtRVjxo1J3Rk18RC9JppovmQ5VR5EcDw"}},"state_key":"","type":"m.room.history_visibility"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"ban":50,"events":null,"events_default":0,"invite":0,"kick":50,"redact":50,"state_default":50,"users":null,"users_default":0},"depth":4,"event_id":"$UKNe10XzYzG0TeA9:kaer.morhen","hashes":{"sha256":"ngbP3yja9U5dlckKerUs/fSOhtKxZMCVvsfhPURSS28"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$6F1yGIbO0J7TM93h:kaer.morhen",{"sha256":"A4CucrKSoWX4IaJXhq02mBg1sxIyZEftbC+5p3fZAvk"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"zOmwlP01QL3yFchzuR9WHvogOoBZA3oVtNIF3lM0ZfDnqlSYZB9sns27G/4HVq0k7alaK7ZE3oGoCrVnMkPNCw"}},"state_key":"","type":"m.room.power_levels"}`), + // messages + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":6,"event_id":"$MYSbs8m4rEbsCWXD:kaer.morhen","hashes":{"sha256":"kgbYM7v4Ud2YaBsjBTolM4ySg6rHcJNYI6nWhMSdFUA"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$gl2T9l3qm0kUbiIJ:kaer.morhen",{"sha256":"C/rD04h9wGxRdN2G/IBfrgoE1UovzLZ+uskwaKZ37/Q"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"x0UoKh968jj/F5l1/R7Ew0T6CTKuew3PLNHASNxqck/bkNe8yYQiDHXRr+kZxObeqPZZTpaF1+EI+bLU9W8GDQ"}},"type":"m.room.message"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":7,"event_id":"$N5x9WJkl9ClPrAEg:kaer.morhen","hashes":{"sha256":"FWM8oz4yquTunRZ67qlW2gzPDzdWfBP6RPHXhK1I/x8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$MYSbs8m4rEbsCWXD:kaer.morhen",{"sha256":"fatqgW+SE8mb2wFn3UN+drmluoD4UJ/EcSrL6Ur9q1M"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"Y+LX/xcyufoXMOIoqQBNOzy6lZfUGB1ffgXIrSugk6obMiyAsiRejHQN/pciZXsHKxMJLYRFAz4zSJoS/LGPAA"}},"type":"m.room.message"}`), + } + testEvents = []gomatrixserverlib.HeaderedEvent{} + testStateEvents = make(map[gomatrixserverlib.StateKeyTuple]gomatrixserverlib.HeaderedEvent) +) + +func init() { + for _, j := range testData { + e, err := gomatrixserverlib.NewEventFromTrustedJSON(j, false, testRoomVersion) + if err != nil { + panic("cannot load test data: " + err.Error()) + } + h := e.Headered(testRoomVersion) + testEvents = append(testEvents, h) + if e.StateKey() != nil { + testStateEvents[gomatrixserverlib.StateKeyTuple{ + EventType: e.Type(), + StateKey: *e.StateKey(), + }] = h + } + } +} + +type testEDUProducer struct { + // this producer keeps track of calls to InputTypingEvent + invocations []eduAPI.InputTypingEventRequest +} + +func (p *testEDUProducer) InputTypingEvent( + ctx context.Context, + request *eduAPI.InputTypingEventRequest, + response *eduAPI.InputTypingEventResponse, +) error { + p.invocations = append(p.invocations, *request) + return nil +} + +func (p *testEDUProducer) InputSendToDeviceEvent( + ctx context.Context, + request *eduAPI.InputSendToDeviceEventRequest, + response *eduAPI.InputSendToDeviceEventResponse, +) error { + return nil +} + +type testRoomserverAPI struct { + inputRoomEvents []api.InputRoomEvent + queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse + queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse + queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse +} + +func (t *testRoomserverAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) {} + +func (t *testRoomserverAPI) InputRoomEvents( + ctx context.Context, + request *api.InputRoomEventsRequest, + response *api.InputRoomEventsResponse, +) error { + t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...) + for _, ire := range request.InputRoomEvents { + fmt.Println("InputRoomEvents: ", ire.Event.EventID()) + } + return nil +} + +func (t *testRoomserverAPI) PerformInvite( + ctx context.Context, + req *api.PerformInviteRequest, + res *api.PerformInviteResponse, +) { +} + +func (t *testRoomserverAPI) PerformJoin( + ctx context.Context, + req *api.PerformJoinRequest, + res *api.PerformJoinResponse, +) { +} + +func (t *testRoomserverAPI) PerformLeave( + ctx context.Context, + req *api.PerformLeaveRequest, + res *api.PerformLeaveResponse, +) error { + return nil +} + +// Query the latest events and state for a room from the room server. +func (t *testRoomserverAPI) QueryLatestEventsAndState( + ctx context.Context, + request *api.QueryLatestEventsAndStateRequest, + response *api.QueryLatestEventsAndStateResponse, +) error { + r := t.queryLatestEventsAndState(request) + response.RoomExists = r.RoomExists + response.RoomVersion = testRoomVersion + response.LatestEvents = r.LatestEvents + response.StateEvents = r.StateEvents + response.Depth = r.Depth + return nil +} + +// Query the state after a list of events in a room from the room server. +func (t *testRoomserverAPI) QueryStateAfterEvents( + ctx context.Context, + request *api.QueryStateAfterEventsRequest, + response *api.QueryStateAfterEventsResponse, +) error { + response.RoomVersion = testRoomVersion + res := t.queryStateAfterEvents(request) + response.PrevEventsExist = res.PrevEventsExist + response.RoomExists = res.RoomExists + response.StateEvents = res.StateEvents + return nil +} + +// Query a list of events by event ID. +func (t *testRoomserverAPI) QueryEventsByID( + ctx context.Context, + request *api.QueryEventsByIDRequest, + response *api.QueryEventsByIDResponse, +) error { + res := t.queryEventsByID(request) + response.Events = res.Events + return nil +} + +// Query the membership event for an user for a room. +func (t *testRoomserverAPI) QueryMembershipForUser( + ctx context.Context, + request *api.QueryMembershipForUserRequest, + response *api.QueryMembershipForUserResponse, +) error { + return fmt.Errorf("not implemented") +} + +// Query a list of membership events for a room +func (t *testRoomserverAPI) QueryMembershipsForRoom( + ctx context.Context, + request *api.QueryMembershipsForRoomRequest, + response *api.QueryMembershipsForRoomResponse, +) error { + return fmt.Errorf("not implemented") +} + +// Query whether a server is allowed to see an event +func (t *testRoomserverAPI) QueryServerAllowedToSeeEvent( + ctx context.Context, + request *api.QueryServerAllowedToSeeEventRequest, + response *api.QueryServerAllowedToSeeEventResponse, +) error { + return fmt.Errorf("not implemented") +} + +// Query missing events for a room from roomserver +func (t *testRoomserverAPI) QueryMissingEvents( + ctx context.Context, + request *api.QueryMissingEventsRequest, + response *api.QueryMissingEventsResponse, +) error { + return fmt.Errorf("not implemented") +} + +// Query to get state and auth chain for a (potentially hypothetical) event. +// Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate +// the state and auth chain to return. +func (t *testRoomserverAPI) QueryStateAndAuthChain( + ctx context.Context, + request *api.QueryStateAndAuthChainRequest, + response *api.QueryStateAndAuthChainResponse, +) error { + return fmt.Errorf("not implemented") +} + +// Query a given amount (or less) of events prior to a given set of events. +func (t *testRoomserverAPI) PerformBackfill( + ctx context.Context, + request *api.PerformBackfillRequest, + response *api.PerformBackfillResponse, +) error { + return fmt.Errorf("not implemented") +} + +// Asks for the default room version as preferred by the server. +func (t *testRoomserverAPI) QueryRoomVersionCapabilities( + ctx context.Context, + request *api.QueryRoomVersionCapabilitiesRequest, + response *api.QueryRoomVersionCapabilitiesResponse, +) error { + return fmt.Errorf("not implemented") +} + +// Asks for the room version for a given room. +func (t *testRoomserverAPI) QueryRoomVersionForRoom( + ctx context.Context, + request *api.QueryRoomVersionForRoomRequest, + response *api.QueryRoomVersionForRoomResponse, +) error { + response.RoomVersion = testRoomVersion + return nil +} + +// Set a room alias +func (t *testRoomserverAPI) SetRoomAlias( + ctx context.Context, + req *api.SetRoomAliasRequest, + response *api.SetRoomAliasResponse, +) error { + return fmt.Errorf("not implemented") +} + +// Get the room ID for an alias +func (t *testRoomserverAPI) GetRoomIDForAlias( + ctx context.Context, + req *api.GetRoomIDForAliasRequest, + response *api.GetRoomIDForAliasResponse, +) error { + return fmt.Errorf("not implemented") +} + +// Get all known aliases for a room ID +func (t *testRoomserverAPI) GetAliasesForRoomID( + ctx context.Context, + req *api.GetAliasesForRoomIDRequest, + response *api.GetAliasesForRoomIDResponse, +) error { + return fmt.Errorf("not implemented") +} + +// Get the user ID of the creator of an alias +func (t *testRoomserverAPI) GetCreatorIDForAlias( + ctx context.Context, + req *api.GetCreatorIDForAliasRequest, + response *api.GetCreatorIDForAliasResponse, +) error { + return fmt.Errorf("not implemented") +} + +// Remove a room alias +func (t *testRoomserverAPI) RemoveRoomAlias( + ctx context.Context, + req *api.RemoveRoomAliasRequest, + response *api.RemoveRoomAliasResponse, +) error { + return fmt.Errorf("not implemented") +} + +type txnFedClient struct { + state map[string]gomatrixserverlib.RespState // event_id to response + stateIDs map[string]gomatrixserverlib.RespStateIDs // event_id to response + getEvent map[string]gomatrixserverlib.Transaction // event_id to response + getMissingEvents func(gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) +} + +func (c *txnFedClient) LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( + res gomatrixserverlib.RespState, err error, +) { + fmt.Println("testFederationClient.LookupState", eventID) + r, ok := c.state[eventID] + if !ok { + err = fmt.Errorf("txnFedClient: no /state for event %s", eventID) + return + } + res = r + return +} +func (c *txnFedClient) LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) { + fmt.Println("testFederationClient.LookupStateIDs", eventID) + r, ok := c.stateIDs[eventID] + if !ok { + err = fmt.Errorf("txnFedClient: no /state_ids for event %s", eventID) + return + } + res = r + return +} +func (c *txnFedClient) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) { + fmt.Println("testFederationClient.GetEvent", eventID) + r, ok := c.getEvent[eventID] + if !ok { + err = fmt.Errorf("txnFedClient: no /event for event ID %s", eventID) + return + } + res = r + return +} +func (c *txnFedClient) LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, + roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) { + return c.getMissingEvents(missing) +} + +func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederationClient, pdus []json.RawMessage) *txnReq { + t := &txnReq{ + context: context.Background(), + rsAPI: rsAPI, + eduAPI: &testEDUProducer{}, + keys: &test.NopJSONVerifier{}, + federation: fedClient, + haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent), + newEvents: make(map[string]bool), + } + t.PDUs = pdus + t.Origin = testOrigin + t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + t.Destination = testDestination + return t +} + +func mustProcessTransaction(t *testing.T, txn *txnReq, pdusWithErrors []string) { + res, err := txn.processTransaction() + if err != nil { + t.Errorf("txn.processTransaction returned an error: %v", err) + return + } + if len(res.PDUs) != len(txn.PDUs) { + t.Errorf("txn.processTransaction did not return results for all PDUs, got %d want %d", len(res.PDUs), len(txn.PDUs)) + return + } +NextPDU: + for eventID, result := range res.PDUs { + if result.Error == "" { + continue + } + for _, eventIDWantError := range pdusWithErrors { + if eventID == eventIDWantError { + break NextPDU + } + } + t.Errorf("txn.processTransaction PDU %s returned an error %s", eventID, result.Error) + } +} + +func fromStateTuples(tuples []gomatrixserverlib.StateKeyTuple, omitTuples []gomatrixserverlib.StateKeyTuple) (result []gomatrixserverlib.HeaderedEvent) { +NextTuple: + for _, t := range tuples { + for _, o := range omitTuples { + if t == o { + break NextTuple + } + } + h, ok := testStateEvents[t] + if ok { + result = append(result, h) + } + } + return +} + +func assertInputRoomEvents(t *testing.T, got []api.InputRoomEvent, want []gomatrixserverlib.HeaderedEvent) { + for _, g := range got { + fmt.Println("GOT ", g.Event.EventID()) + } + if len(got) != len(want) { + t.Errorf("wrong number of InputRoomEvents: got %d want %d", len(got), len(want)) + return + } + for i := range got { + if got[i].Event.EventID() != want[i].EventID() { + t.Errorf("InputRoomEvents[%d] got %s want %s", i, got[i].Event.EventID(), want[i].EventID()) + } + } +} + +// The purpose of this test is to check that receiving an event over federation for which we have the prev_events works correctly, and passes it on +// to the roomserver. It's the most basic test possible. +func TestBasicTransaction(t *testing.T) { + rsAPI := &testRoomserverAPI{ + queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { + return api.QueryStateAfterEventsResponse{ + PrevEventsExist: true, + RoomExists: true, + StateEvents: fromStateTuples(req.StateToFetch, nil), + } + }, + } + pdus := []json.RawMessage{ + testData[len(testData)-1], // a message event + } + txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) + mustProcessTransaction(t, txn, nil) + assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) +} + +// The purpose of this test is to check that if the event received fails auth checks the transaction is failed. +func TestTransactionFailAuthChecks(t *testing.T) { + rsAPI := &testRoomserverAPI{ + queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { + return api.QueryStateAfterEventsResponse{ + PrevEventsExist: true, + RoomExists: true, + // omit the create event so auth checks fail + StateEvents: fromStateTuples(req.StateToFetch, []gomatrixserverlib.StateKeyTuple{ + {EventType: gomatrixserverlib.MRoomCreate, StateKey: ""}, + }), + } + }, + } + pdus := []json.RawMessage{ + testData[len(testData)-1], // a message event + } + txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) + mustProcessTransaction(t, txn, []string{ + // expect the event to have an error + testEvents[len(testEvents)-1].EventID(), + }) + assertInputRoomEvents(t, rsAPI.inputRoomEvents, nil) // expect no messages to be sent to the roomserver +} + +// The purpose of this test is to make sure that when an event is received for which we do not know the prev_events, +// we request them from /get_missing_events. It works by setting PrevEventsExist=false in the roomserver query response, +// resulting in a call to /get_missing_events which returns the missing prev event. Both events should be processed in +// topological order and sent to the roomserver. +func TestTransactionFetchMissingPrevEvents(t *testing.T) { + haveEvent := testEvents[len(testEvents)-3] + prevEvent := testEvents[len(testEvents)-2] + inputEvent := testEvents[len(testEvents)-1] + + var rsAPI *testRoomserverAPI // ref here so we can refer to inputRoomEvents inside these functions + rsAPI = &testRoomserverAPI{ + queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { + // we expect this to be called three times: + // - first with input event to realise there's a gap + // - second with the prevEvent to realise there is no gap + // - third with the input event to realise there is no longer a gap + prevEventsExist := false + if len(req.PrevEventIDs) == 1 { + switch req.PrevEventIDs[0] { + case haveEvent.EventID(): + prevEventsExist = true + case prevEvent.EventID(): + // we only have this event if we've been send prevEvent + if len(rsAPI.inputRoomEvents) == 1 && rsAPI.inputRoomEvents[0].Event.EventID() == prevEvent.EventID() { + prevEventsExist = true + } + } + } + + return api.QueryStateAfterEventsResponse{ + PrevEventsExist: prevEventsExist, + RoomExists: true, + StateEvents: fromStateTuples(req.StateToFetch, nil), + } + }, + queryLatestEventsAndState: func(req *api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse { + return api.QueryLatestEventsAndStateResponse{ + RoomExists: true, + Depth: haveEvent.Depth(), + LatestEvents: []gomatrixserverlib.EventReference{ + haveEvent.EventReference(), + }, + StateEvents: fromStateTuples(req.StateToFetch, nil), + } + }, + } + + cli := &txnFedClient{ + getMissingEvents: func(missing gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) { + if !reflect.DeepEqual(missing.EarliestEvents, []string{haveEvent.EventID()}) { + t.Errorf("call to /get_missing_events wrong earliest events: got %v want %v", missing.EarliestEvents, haveEvent.EventID()) + } + if !reflect.DeepEqual(missing.LatestEvents, []string{inputEvent.EventID()}) { + t.Errorf("call to /get_missing_events wrong latest events: got %v want %v", missing.LatestEvents, inputEvent.EventID()) + } + return gomatrixserverlib.RespMissingEvents{ + Events: []gomatrixserverlib.Event{ + prevEvent.Unwrap(), + }, + }, nil + }, + } + + pdus := []json.RawMessage{ + inputEvent.JSON(), + } + txn := mustCreateTransaction(rsAPI, cli, pdus) + mustProcessTransaction(t, txn, nil) + assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{prevEvent, inputEvent}) +} + +// The purpose of this test is to check that when there are missing prev_events and we still haven't been able to fill +// in the hole with /get_missing_events that the state BEFORE the events we want to persist is fetched via /state_ids +// and /event. It works by setting PrevEventsExist=false in the roomserver query response, resulting in +// a call to /get_missing_events which returns 1 out of the 2 events it needs to fill in the gap. Synapse and Dendrite +// both give up after 1x /get_missing_events call, relying on requesting the state AFTER the missing event in order to +// continue. The DAG looks something like: +// FE GME TXN +// A ---> B ---> C ---> D +// TXN=event in the txn, GME=response to /get_missing_events, FE=roomserver's forward extremity. Should result in: +// - /state_ids?event=B is requested, then /event/B to get the state AFTER B. B is a state event. +// - state resolution is done to check C is allowed. +// This results in B being sent as an outlier FIRST, then C,D. +func TestTransactionFetchMissingStateByStateIDs(t *testing.T) { + eventA := testEvents[len(testEvents)-5] + // this is also len(testEvents)-4 + eventB := testStateEvents[gomatrixserverlib.StateKeyTuple{ + EventType: gomatrixserverlib.MRoomPowerLevels, + StateKey: "", + }] + eventC := testEvents[len(testEvents)-3] + eventD := testEvents[len(testEvents)-2] + fmt.Println("a:", eventA.EventID()) + fmt.Println("b:", eventB.EventID()) + fmt.Println("c:", eventC.EventID()) + fmt.Println("d:", eventD.EventID()) + var rsAPI *testRoomserverAPI + rsAPI = &testRoomserverAPI{ + queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { + omitTuples := []gomatrixserverlib.StateKeyTuple{ + { + EventType: gomatrixserverlib.MRoomPowerLevels, + StateKey: "", + }, + } + askingForEvent := req.PrevEventIDs[0] + haveEventB := false + haveEventC := false + for _, ev := range rsAPI.inputRoomEvents { + switch ev.Event.EventID() { + case eventB.EventID(): + haveEventB = true + omitTuples = nil // include event B now + case eventC.EventID(): + haveEventC = true + } + } + prevEventExists := false + if askingForEvent == eventC.EventID() { + prevEventExists = haveEventC + } else if askingForEvent == eventB.EventID() { + prevEventExists = haveEventB + } + var stateEvents []gomatrixserverlib.HeaderedEvent + if prevEventExists { + stateEvents = fromStateTuples(req.StateToFetch, omitTuples) + } + return api.QueryStateAfterEventsResponse{ + PrevEventsExist: prevEventExists, + RoomExists: true, + StateEvents: stateEvents, + } + }, + queryLatestEventsAndState: func(req *api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse { + omitTuples := []gomatrixserverlib.StateKeyTuple{ + {EventType: gomatrixserverlib.MRoomPowerLevels, StateKey: ""}, + } + return api.QueryLatestEventsAndStateResponse{ + RoomExists: true, + Depth: eventA.Depth(), + LatestEvents: []gomatrixserverlib.EventReference{ + eventA.EventReference(), + }, + StateEvents: fromStateTuples(req.StateToFetch, omitTuples), + } + }, + queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse { + var res api.QueryEventsByIDResponse + fmt.Println("queryEventsByID ", req.EventIDs) + for _, wantEventID := range req.EventIDs { + for _, ev := range testStateEvents { + // roomserver is missing the power levels event unless it's been sent to us recently as an outlier + if wantEventID == eventB.EventID() { + fmt.Println("Asked for pl event") + for _, inEv := range rsAPI.inputRoomEvents { + fmt.Println("recv ", inEv.Event.EventID()) + if inEv.Event.EventID() == wantEventID { + res.Events = append(res.Events, inEv.Event) + break + } + } + continue + } + if ev.EventID() == wantEventID { + res.Events = append(res.Events, ev) + } + } + } + return res + }, + } + // /state_ids for event B returns every state event but B (it's the state before) + var authEventIDs []string + var stateEventIDs []string + for _, ev := range testStateEvents { + if ev.EventID() == eventB.EventID() { + continue + } + // state res checks what auth events you give it, and this isn't a valid auth event + if ev.Type() != gomatrixserverlib.MRoomHistoryVisibility { + authEventIDs = append(authEventIDs, ev.EventID()) + } + stateEventIDs = append(stateEventIDs, ev.EventID()) + } + cli := &txnFedClient{ + stateIDs: map[string]gomatrixserverlib.RespStateIDs{ + eventB.EventID(): { + StateEventIDs: stateEventIDs, + AuthEventIDs: authEventIDs, + }, + }, + // /event for event B returns it + getEvent: map[string]gomatrixserverlib.Transaction{ + eventB.EventID(): { + PDUs: []json.RawMessage{ + eventB.JSON(), + }, + }, + }, + // /get_missing_events should be done exactly once + getMissingEvents: func(missing gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) { + if !reflect.DeepEqual(missing.EarliestEvents, []string{eventA.EventID()}) { + t.Errorf("call to /get_missing_events wrong earliest events: got %v want %v", missing.EarliestEvents, eventA.EventID()) + } + if !reflect.DeepEqual(missing.LatestEvents, []string{eventD.EventID()}) { + t.Errorf("call to /get_missing_events wrong latest events: got %v want %v", missing.LatestEvents, eventD.EventID()) + } + // just return event C, not event B so /state_ids logic kicks in as there will STILL be missing prev_events + return gomatrixserverlib.RespMissingEvents{ + Events: []gomatrixserverlib.Event{ + eventC.Unwrap(), + }, + }, nil + }, + } + + pdus := []json.RawMessage{ + eventD.JSON(), + } + txn := mustCreateTransaction(rsAPI, cli, pdus) + mustProcessTransaction(t, txn, nil) + assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{eventB, eventC, eventD}) +} diff --git a/federationapi/routing/state.go b/federationapi/routing/state.go index 548598dd7..28dfad846 100644 --- a/federationapi/routing/state.go +++ b/federationapi/routing/state.go @@ -27,7 +27,7 @@ import ( func GetState( ctx context.Context, request *gomatrixserverlib.FederationRequest, - query api.RoomserverQueryAPI, + rsAPI api.RoomserverInternalAPI, roomID string, ) util.JSONResponse { eventID, err := parseEventIDParam(request) @@ -35,7 +35,7 @@ func GetState( return *err } - state, err := getState(ctx, request, query, roomID, eventID) + state, err := getState(ctx, request, rsAPI, roomID, eventID) if err != nil { return *err } @@ -47,7 +47,7 @@ func GetState( func GetStateIDs( ctx context.Context, request *gomatrixserverlib.FederationRequest, - query api.RoomserverQueryAPI, + rsAPI api.RoomserverInternalAPI, roomID string, ) util.JSONResponse { eventID, err := parseEventIDParam(request) @@ -55,7 +55,7 @@ func GetStateIDs( return *err } - state, err := getState(ctx, request, query, roomID, eventID) + state, err := getState(ctx, request, rsAPI, roomID, eventID) if err != nil { return *err } @@ -94,28 +94,30 @@ func parseEventIDParam( func getState( ctx context.Context, request *gomatrixserverlib.FederationRequest, - query api.RoomserverQueryAPI, + rsAPI api.RoomserverInternalAPI, roomID string, eventID string, ) (*gomatrixserverlib.RespState, *util.JSONResponse) { - event, resErr := getEvent(ctx, request, query, eventID) + event, resErr := fetchEvent(ctx, rsAPI, eventID) if resErr != nil { return nil, resErr } if event.RoomID() != roomID { - return nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: nil} + return nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: jsonerror.NotFound("event does not belong to this room")} + } + resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID) + if resErr != nil { + return nil, resErr } - authEventIDs := getIDsFromEventRef(event.AuthEvents()) - var response api.QueryStateAndAuthChainResponse - err := query.QueryStateAndAuthChain( + err := rsAPI.QueryStateAndAuthChain( ctx, &api.QueryStateAndAuthChainRequest{ RoomID: roomID, PrevEventIDs: []string{eventID}, - AuthEventIDs: authEventIDs, + AuthEventIDs: event.AuthEventIDs(), }, &response, ) @@ -134,15 +136,6 @@ func getState( }, nil } -func getIDsFromEventRef(events []gomatrixserverlib.EventReference) []string { - IDs := make([]string, len(events)) - for i := range events { - IDs[i] = events[i].EventID - } - - return IDs -} - func getIDsFromEvent(events []gomatrixserverlib.Event) []string { IDs := make([]string, len(events)) for i := range events { diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index f93d934ed..61788010b 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -21,14 +21,11 @@ import ( "net/http" "time" - appserviceAPI "github.com/matrix-org/dendrite/appservice/api" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/clientapi/producers" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" - roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -58,10 +55,10 @@ var ( // CreateInvitesFrom3PIDInvites implements POST /_matrix/federation/v1/3pid/onbind func CreateInvitesFrom3PIDInvites( - req *http.Request, queryAPI roomserverAPI.RoomserverQueryAPI, - asAPI appserviceAPI.AppServiceQueryAPI, cfg *config.Dendrite, - producer *producers.RoomserverProducer, federation *gomatrixserverlib.FederationClient, - accountDB accounts.Database, + req *http.Request, rsAPI api.RoomserverInternalAPI, + cfg *config.Dendrite, + federation *gomatrixserverlib.FederationClient, + userAPI userapi.UserInternalAPI, ) util.JSONResponse { var body invites if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { @@ -72,7 +69,7 @@ func CreateInvitesFrom3PIDInvites( for _, inv := range body.Invites { verReq := api.QueryRoomVersionForRoomRequest{RoomID: inv.RoomID} verRes := api.QueryRoomVersionForRoomResponse{} - if err := queryAPI.QueryRoomVersionForRoom(req.Context(), &verReq, &verRes); err != nil { + if err := rsAPI.QueryRoomVersionForRoom(req.Context(), &verReq, &verRes); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.UnsupportedRoomVersion(err.Error()), @@ -80,7 +77,7 @@ func CreateInvitesFrom3PIDInvites( } event, err := createInviteFrom3PIDInvite( - req.Context(), queryAPI, asAPI, cfg, inv, federation, accountDB, + req.Context(), rsAPI, cfg, inv, federation, userAPI, ) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("createInviteFrom3PIDInvite failed") @@ -92,8 +89,8 @@ func CreateInvitesFrom3PIDInvites( } // Send all the events - if _, err := producer.SendEvents(req.Context(), evs, cfg.Matrix.ServerName, nil); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("producer.SendEvents failed") + if _, err := api.SendEvents(req.Context(), rsAPI, evs, cfg.Matrix.ServerName, nil); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } @@ -108,10 +105,9 @@ func ExchangeThirdPartyInvite( httpReq *http.Request, request *gomatrixserverlib.FederationRequest, roomID string, - queryAPI roomserverAPI.RoomserverQueryAPI, + rsAPI api.RoomserverInternalAPI, cfg *config.Dendrite, federation *gomatrixserverlib.FederationClient, - producer *producers.RoomserverProducer, ) util.JSONResponse { var builder gomatrixserverlib.EventBuilder if err := json.Unmarshal(request.Content(), &builder); err != nil { @@ -148,7 +144,7 @@ func ExchangeThirdPartyInvite( verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} verRes := api.QueryRoomVersionForRoomResponse{} - if err = queryAPI.QueryRoomVersionForRoom(httpReq.Context(), &verReq, &verRes); err != nil { + if err = rsAPI.QueryRoomVersionForRoom(httpReq.Context(), &verReq, &verRes); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.UnsupportedRoomVersion(err.Error()), @@ -156,7 +152,7 @@ func ExchangeThirdPartyInvite( } // Auth and build the event from what the remote server sent us - event, err := buildMembershipEvent(httpReq.Context(), &builder, queryAPI, cfg) + event, err := buildMembershipEvent(httpReq.Context(), &builder, rsAPI, cfg) if err == errNotInRoom { return util.JSONResponse{ Code: http.StatusNotFound, @@ -176,15 +172,15 @@ func ExchangeThirdPartyInvite( } // Send the event to the roomserver - if _, err = producer.SendEvents( - httpReq.Context(), + if _, err = api.SendEvents( + httpReq.Context(), rsAPI, []gomatrixserverlib.HeaderedEvent{ signedEvent.Event.Headered(verRes.RoomVersion), }, cfg.Matrix.ServerName, nil, ); err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("producer.SendEvents failed") + util.GetLogger(httpReq.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } @@ -199,14 +195,14 @@ func ExchangeThirdPartyInvite( // Returns an error if there was a problem building the event or fetching the // necessary data to do so. func createInviteFrom3PIDInvite( - ctx context.Context, queryAPI roomserverAPI.RoomserverQueryAPI, - asAPI appserviceAPI.AppServiceQueryAPI, cfg *config.Dendrite, + ctx context.Context, rsAPI api.RoomserverInternalAPI, + cfg *config.Dendrite, inv invite, federation *gomatrixserverlib.FederationClient, - accountDB accounts.Database, + userAPI userapi.UserInternalAPI, ) (*gomatrixserverlib.Event, error) { verReq := api.QueryRoomVersionForRoomRequest{RoomID: inv.RoomID} verRes := api.QueryRoomVersionForRoomResponse{} - if err := queryAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { + if err := rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { return nil, err } @@ -227,14 +223,17 @@ func createInviteFrom3PIDInvite( StateKey: &inv.MXID, } - profile, err := appserviceAPI.RetrieveUserProfile(ctx, inv.MXID, asAPI, accountDB) + var res userapi.QueryProfileResponse + err = userAPI.QueryProfile(ctx, &userapi.QueryProfileRequest{ + UserID: inv.MXID, + }, &res) if err != nil { return nil, err } content := gomatrixserverlib.MemberContent{ - AvatarURL: profile.AvatarURL, - DisplayName: profile.DisplayName, + AvatarURL: res.AvatarURL, + DisplayName: res.DisplayName, Membership: gomatrixserverlib.Invite, ThirdPartyInvite: &gomatrixserverlib.MemberThirdPartyInvite{ Signed: inv.Signed, @@ -245,7 +244,7 @@ func createInviteFrom3PIDInvite( return nil, err } - event, err := buildMembershipEvent(ctx, builder, queryAPI, cfg) + event, err := buildMembershipEvent(ctx, builder, rsAPI, cfg) if err == errNotInRoom { return nil, sendToRemoteServer(ctx, inv, federation, cfg, *builder) } @@ -263,7 +262,7 @@ func createInviteFrom3PIDInvite( // Returns an error if something failed during the process. func buildMembershipEvent( ctx context.Context, - builder *gomatrixserverlib.EventBuilder, queryAPI roomserverAPI.RoomserverQueryAPI, + builder *gomatrixserverlib.EventBuilder, rsAPI api.RoomserverInternalAPI, cfg *config.Dendrite, ) (*gomatrixserverlib.Event, error) { eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) @@ -276,12 +275,12 @@ func buildMembershipEvent( } // Ask the roomserver for information about this room - queryReq := roomserverAPI.QueryLatestEventsAndStateRequest{ + queryReq := api.QueryLatestEventsAndStateRequest{ RoomID: builder.RoomID, StateToFetch: eventsNeeded.Tuples(), } - var queryRes roomserverAPI.QueryLatestEventsAndStateResponse - if err = queryAPI.QueryLatestEventsAndState(ctx, &queryReq, &queryRes); err != nil { + var queryRes api.QueryLatestEventsAndStateResponse + if err = rsAPI.QueryLatestEventsAndState(ctx, &queryReq, &queryRes); err != nil { return nil, err } diff --git a/federationsender/api/api.go b/federationsender/api/api.go new file mode 100644 index 000000000..02c762582 --- /dev/null +++ b/federationsender/api/api.go @@ -0,0 +1,91 @@ +package api + +import ( + "context" + + "github.com/matrix-org/dendrite/federationsender/types" + "github.com/matrix-org/gomatrixserverlib" +) + +// FederationSenderInternalAPI is used to query information from the federation sender. +type FederationSenderInternalAPI interface { + // PerformDirectoryLookup looks up a remote room ID from a room alias. + PerformDirectoryLookup( + ctx context.Context, + request *PerformDirectoryLookupRequest, + response *PerformDirectoryLookupResponse, + ) error + // Query the server names of the joined hosts in a room. + // Unlike QueryJoinedHostsInRoom, this function returns a de-duplicated slice + // containing only the server names (without information for membership events). + QueryJoinedHostServerNamesInRoom( + ctx context.Context, + request *QueryJoinedHostServerNamesInRoomRequest, + response *QueryJoinedHostServerNamesInRoomResponse, + ) error + // Handle an instruction to make_join & send_join with a remote server. + PerformJoin( + ctx context.Context, + request *PerformJoinRequest, + response *PerformJoinResponse, + ) error + // Handle an instruction to make_leave & send_leave with a remote server. + PerformLeave( + ctx context.Context, + request *PerformLeaveRequest, + response *PerformLeaveResponse, + ) error + // Notifies the federation sender that these servers may be online and to retry sending messages. + PerformServersAlive( + ctx context.Context, + request *PerformServersAliveRequest, + response *PerformServersAliveResponse, + ) error +} + +type PerformDirectoryLookupRequest struct { + RoomAlias string `json:"room_alias"` + ServerName gomatrixserverlib.ServerName `json:"server_name"` +} + +type PerformDirectoryLookupResponse struct { + RoomID string `json:"room_id"` + ServerNames []gomatrixserverlib.ServerName `json:"server_names"` +} + +type PerformJoinRequest struct { + RoomID string `json:"room_id"` + UserID string `json:"user_id"` + // The sorted list of servers to try. Servers will be tried sequentially, after de-duplication. + ServerNames types.ServerNames `json:"server_names"` + Content map[string]interface{} `json:"content"` +} + +type PerformJoinResponse struct { +} + +type PerformLeaveRequest struct { + RoomID string `json:"room_id"` + UserID string `json:"user_id"` + ServerNames types.ServerNames `json:"server_names"` +} + +type PerformLeaveResponse struct { +} + +type PerformServersAliveRequest struct { + Servers []gomatrixserverlib.ServerName +} + +type PerformServersAliveResponse struct { +} + +// QueryJoinedHostServerNamesInRoomRequest is a request to QueryJoinedHostServerNames +type QueryJoinedHostServerNamesInRoomRequest struct { + RoomID string `json:"room_id"` +} + +// QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames +type QueryJoinedHostServerNamesInRoomResponse struct { + ServerNames []gomatrixserverlib.ServerName `json:"server_names"` +} diff --git a/federationsender/api/query.go b/federationsender/api/query.go deleted file mode 100644 index 7c0ca7ff2..000000000 --- a/federationsender/api/query.go +++ /dev/null @@ -1,99 +0,0 @@ -package api - -import ( - "context" - "errors" - "net/http" - - commonHTTP "github.com/matrix-org/dendrite/common/http" - "github.com/matrix-org/gomatrixserverlib" - - "github.com/matrix-org/dendrite/federationsender/types" - "github.com/opentracing/opentracing-go" -) - -// QueryJoinedHostsInRoomRequest is a request to QueryJoinedHostsInRoom -type QueryJoinedHostsInRoomRequest struct { - RoomID string `json:"room_id"` -} - -// QueryJoinedHostsInRoomResponse is a response to QueryJoinedHostsInRoom -type QueryJoinedHostsInRoomResponse struct { - JoinedHosts []types.JoinedHost `json:"joined_hosts"` -} - -// QueryJoinedHostServerNamesRequest is a request to QueryJoinedHostServerNames -type QueryJoinedHostServerNamesInRoomRequest struct { - RoomID string `json:"room_id"` -} - -// QueryJoinedHostServerNamesResponse is a response to QueryJoinedHostServerNames -type QueryJoinedHostServerNamesInRoomResponse struct { - ServerNames []gomatrixserverlib.ServerName `json:"server_names"` -} - -// FederationSenderQueryAPI is used to query information from the federation sender. -type FederationSenderQueryAPI interface { - // Query the joined hosts and the membership events accounting for their participation in a room. - // Note that if a server has multiple users in the room, it will have multiple entries in the returned slice. - // See `QueryJoinedHostServerNamesInRoom` for a de-duplicated version. - QueryJoinedHostsInRoom( - ctx context.Context, - request *QueryJoinedHostsInRoomRequest, - response *QueryJoinedHostsInRoomResponse, - ) error - // Query the server names of the joined hosts in a room. - // Unlike QueryJoinedHostsInRoom, this function returns a de-duplicated slice - // containing only the server names (without information for membership events). - QueryJoinedHostServerNamesInRoom( - ctx context.Context, - request *QueryJoinedHostServerNamesInRoomRequest, - response *QueryJoinedHostServerNamesInRoomResponse, - ) error -} - -// FederationSenderQueryJoinedHostsInRoomPath is the HTTP path for the QueryJoinedHostsInRoom API. -const FederationSenderQueryJoinedHostsInRoomPath = "/api/federationsender/queryJoinedHostsInRoom" - -// FederationSenderQueryJoinedHostServerNamesInRoomPath is the HTTP path for the QueryJoinedHostServerNamesInRoom API. -const FederationSenderQueryJoinedHostServerNamesInRoomPath = "/api/federationsender/queryJoinedHostServerNamesInRoom" - -// NewFederationSenderQueryAPIHTTP creates a FederationSenderQueryAPI implemented by talking to a HTTP POST API. -// If httpClient is nil an error is returned -func NewFederationSenderQueryAPIHTTP(federationSenderURL string, httpClient *http.Client) (FederationSenderQueryAPI, error) { - if httpClient == nil { - return nil, errors.New("NewFederationSenderQueryAPIHTTP: httpClient is ") - } - return &httpFederationSenderQueryAPI{federationSenderURL, httpClient}, nil -} - -type httpFederationSenderQueryAPI struct { - federationSenderURL string - httpClient *http.Client -} - -// QueryJoinedHostsInRoom implements FederationSenderQueryAPI -func (h *httpFederationSenderQueryAPI) QueryJoinedHostsInRoom( - ctx context.Context, - request *QueryJoinedHostsInRoomRequest, - response *QueryJoinedHostsInRoomResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryJoinedHostsInRoom") - defer span.Finish() - - apiURL := h.federationSenderURL + FederationSenderQueryJoinedHostsInRoomPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -// QueryJoinedHostServerNamesInRoom implements FederationSenderQueryAPI -func (h *httpFederationSenderQueryAPI) QueryJoinedHostServerNamesInRoom( - ctx context.Context, - request *QueryJoinedHostServerNamesInRoomRequest, - response *QueryJoinedHostServerNamesInRoomResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryJoinedHostServerNamesInRoom") - defer span.Finish() - - apiURL := h.federationSenderURL + FederationSenderQueryJoinedHostServerNamesInRoomPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} diff --git a/federationsender/consumers/eduserver.go b/federationsender/consumers/eduserver.go index 4d2445f3c..bcebb3ce1 100644 --- a/federationsender/consumers/eduserver.go +++ b/federationsender/consumers/eduserver.go @@ -16,19 +16,19 @@ import ( "context" "encoding/json" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" + "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/federationsender/queue" "github.com/matrix-org/dendrite/federationsender/storage" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" - "gopkg.in/Shopify/sarama.v1" ) // OutputTypingEventConsumer consumes events that originate in EDU server. type OutputTypingEventConsumer struct { - consumer *common.ContinualConsumer + consumer *internal.ContinualConsumer db storage.Database queues *queue.OutgoingQueues ServerName gomatrixserverlib.ServerName @@ -41,7 +41,7 @@ func NewOutputTypingEventConsumer( queues *queue.OutgoingQueues, store storage.Database, ) *OutputTypingEventConsumer { - consumer := common.ContinualConsumer{ + consumer := internal.ContinualConsumer{ Topic: string(cfg.Kafka.Topics.OutputTypingEvent), Consumer: kafkaConsumer, PartitionStore: store, diff --git a/federationsender/consumers/roomserver.go b/federationsender/consumers/roomserver.go index f59405af0..299c7b37a 100644 --- a/federationsender/consumers/roomserver.go +++ b/federationsender/consumers/roomserver.go @@ -19,24 +19,25 @@ import ( "encoding/json" "fmt" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" + "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/federationsender/queue" "github.com/matrix-org/dendrite/federationsender/storage" "github.com/matrix-org/dendrite/federationsender/types" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" - sarama "gopkg.in/Shopify/sarama.v1" + "github.com/tidwall/gjson" ) // OutputRoomEventConsumer consumes events that originated in the room server. type OutputRoomEventConsumer struct { - cfg *config.Dendrite - roomServerConsumer *common.ContinualConsumer - db storage.Database - queues *queue.OutgoingQueues - query api.RoomserverQueryAPI + cfg *config.Dendrite + rsAPI api.RoomserverInternalAPI + rsConsumer *internal.ContinualConsumer + db storage.Database + queues *queue.OutgoingQueues } // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. @@ -45,19 +46,19 @@ func NewOutputRoomEventConsumer( kafkaConsumer sarama.Consumer, queues *queue.OutgoingQueues, store storage.Database, - queryAPI api.RoomserverQueryAPI, + rsAPI api.RoomserverInternalAPI, ) *OutputRoomEventConsumer { - consumer := common.ContinualConsumer{ + consumer := internal.ContinualConsumer{ Topic: string(cfg.Kafka.Topics.OutputRoomEvent), Consumer: kafkaConsumer, PartitionStore: store, } s := &OutputRoomEventConsumer{ - cfg: cfg, - roomServerConsumer: &consumer, - db: store, - queues: queues, - query: queryAPI, + cfg: cfg, + rsConsumer: &consumer, + db: store, + queues: queues, + rsAPI: rsAPI, } consumer.ProcessMessage = s.onMessage @@ -66,7 +67,7 @@ func NewOutputRoomEventConsumer( // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { - return s.roomServerConsumer.Start() + return s.rsConsumer.Start() } // onMessage is called when the federation server receives a new event from the room server output log. @@ -85,11 +86,6 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { switch output.Type { case api.OutputTypeNewRoomEvent: ev := &output.NewRoomEvent.Event - log.WithFields(log.Fields{ - "event_id": ev.EventID(), - "room_id": ev.RoomID(), - "send_as_server": output.NewRoomEvent.SendAsServer, - }).Info("received room event from roomserver") if err := s.processMessage(*output.NewRoomEvent); err != nil { // panic rather than continue with an inconsistent database @@ -130,11 +126,7 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { // processMessage updates the list of currently joined hosts in the room // and then sends the event to the hosts that were joined before the event. func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) error { - addsStateEvents, err := s.lookupStateEvents(ore.AddsStateEventIDs, ore.Event.Event) - if err != nil { - return err - } - addsJoinedHosts, err := joinedHostsFromEvents(addsStateEvents) + addsJoinedHosts, err := joinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(ore.AddsState())) if err != nil { return err } @@ -187,49 +179,31 @@ func (s *OutputRoomEventConsumer) processInvite(oie api.OutputNewInviteEvent) er return nil } - // When sending a v2 invite, the inviting server should try and include - // a "stripped down" version of the room state. This is pretty much just - // enough information for the remote side to show something useful to the - // user, like the room name, aliases etc. - strippedState := []gomatrixserverlib.InviteV2StrippedState{} - stateWanted := []string{ - gomatrixserverlib.MRoomName, gomatrixserverlib.MRoomCanonicalAlias, - gomatrixserverlib.MRoomAliases, gomatrixserverlib.MRoomJoinRules, + // Ignore invites that don't have state keys - they are invalid. + if oie.Event.StateKey() == nil { + return fmt.Errorf("event %q doesn't have state key", oie.Event.EventID()) } - // For each of the state keys that we want to try and send, ask the - // roomserver if we have a state event for that room that matches the - // state key. - for _, wanted := range stateWanted { - queryReq := api.QueryLatestEventsAndStateRequest{ - RoomID: oie.Event.RoomID(), - StateToFetch: []gomatrixserverlib.StateKeyTuple{ - gomatrixserverlib.StateKeyTuple{ - EventType: wanted, - StateKey: "", - }, - }, - } - // If this fails then we just move onto the next event - we don't - // actually know at this point whether the room even has that type - // of state. - queryRes := api.QueryLatestEventsAndStateResponse{} - if err := s.query.QueryLatestEventsAndState(context.TODO(), &queryReq, &queryRes); err != nil { - log.WithFields(log.Fields{ - "room_id": queryReq.RoomID, - "event_type": wanted, - }).WithError(err).Info("couldn't find state to strip") - continue - } - // Append the stripped down copy of the state to our list. - for _, headeredEvent := range queryRes.StateEvents { - event := headeredEvent.Unwrap() - strippedState = append(strippedState, gomatrixserverlib.NewInviteV2StrippedState(&event)) + // Don't try to handle events that are actually destined for us. + stateKey := *oie.Event.StateKey() + _, destination, err := gomatrixserverlib.SplitID('@', stateKey) + if err != nil { + log.WithFields(log.Fields{ + "event_id": oie.Event.EventID(), + "state_key": stateKey, + }).Info("failed to split destination from state key") + return nil + } + if s.cfg.Matrix.ServerName == destination { + return nil + } - log.WithFields(log.Fields{ - "room_id": queryReq.RoomID, - "event_type": event.Type(), - }).Info("adding stripped state") + // Try to extract the room invite state. The roomserver will have stashed + // this for us in invite_room_state if it didn't already exist. + strippedState := []gomatrixserverlib.InviteV2StrippedState{} + if inviteRoomState := gjson.GetBytes(oie.Event.Unsigned(), "invite_room_state"); inviteRoomState.Exists() { + if err = json.Unmarshal([]byte(inviteRoomState.Raw), &strippedState); err != nil { + log.WithError(err).Warn("failed to extract invite_room_state from event unsigned") } } @@ -405,7 +379,7 @@ func (s *OutputRoomEventConsumer) lookupStateEvents( // from the roomserver using the query API. eventReq := api.QueryEventsByIDRequest{EventIDs: missing} var eventResp api.QueryEventsByIDResponse - if err := s.query.QueryEventsByID(context.TODO(), &eventReq, &eventResp); err != nil { + if err := s.rsAPI.QueryEventsByID(context.TODO(), &eventReq, &eventResp); err != nil { return nil, err } diff --git a/federationsender/federationsender.go b/federationsender/federationsender.go index a318d2099..10ac51c8a 100644 --- a/federationsender/federationsender.go +++ b/federationsender/federationsender.go @@ -15,36 +15,51 @@ package federationsender import ( - "net/http" - - "github.com/matrix-org/dendrite/common/basecomponent" + "github.com/gorilla/mux" "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/federationsender/consumers" - "github.com/matrix-org/dendrite/federationsender/query" + "github.com/matrix-org/dendrite/federationsender/internal" + "github.com/matrix-org/dendrite/federationsender/inthttp" "github.com/matrix-org/dendrite/federationsender/queue" "github.com/matrix-org/dendrite/federationsender/storage" + "github.com/matrix-org/dendrite/federationsender/types" + "github.com/matrix-org/dendrite/internal/setup" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) -// SetupFederationSenderComponent sets up and registers HTTP handlers for the -// FederationSender component. -func SetupFederationSenderComponent( - base *basecomponent.BaseDendrite, +// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions +// on the given input API. +func AddInternalRoutes(router *mux.Router, intAPI api.FederationSenderInternalAPI) { + inthttp.AddRoutes(intAPI, router) +} + +// NewInternalAPI returns a concerete implementation of the internal API. Callers +// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. +func NewInternalAPI( + base *setup.BaseDendrite, federation *gomatrixserverlib.FederationClient, - rsQueryAPI roomserverAPI.RoomserverQueryAPI, -) api.FederationSenderQueryAPI { - federationSenderDB, err := storage.NewDatabase(string(base.Cfg.Database.FederationSender)) + rsAPI roomserverAPI.RoomserverInternalAPI, + keyRing *gomatrixserverlib.KeyRing, +) api.FederationSenderInternalAPI { + federationSenderDB, err := storage.NewDatabase(string(base.Cfg.Database.FederationSender), base.Cfg.DbProperties()) if err != nil { logrus.WithError(err).Panic("failed to connect to federation sender db") } - queues := queue.NewOutgoingQueues(base.Cfg.Matrix.ServerName, federation) + statistics := &types.Statistics{} + queues := queue.NewOutgoingQueues( + base.Cfg.Matrix.ServerName, federation, rsAPI, statistics, &queue.SigningInfo{ + KeyID: base.Cfg.Matrix.KeyID, + PrivateKey: base.Cfg.Matrix.PrivateKey, + ServerName: base.Cfg.Matrix.ServerName, + }, + ) rsConsumer := consumers.NewOutputRoomEventConsumer( base.Cfg, base.KafkaConsumer, queues, - federationSenderDB, rsQueryAPI, + federationSenderDB, rsAPI, ) if err = rsConsumer.Start(); err != nil { logrus.WithError(err).Panic("failed to start room server consumer") @@ -57,10 +72,5 @@ func SetupFederationSenderComponent( logrus.WithError(err).Panic("failed to start typing server consumer") } - queryAPI := query.FederationSenderQueryAPI{ - DB: federationSenderDB, - } - queryAPI.SetupHTTP(http.DefaultServeMux) - - return &queryAPI + return internal.NewFederationSenderInternalAPI(federationSenderDB, base.Cfg, rsAPI, federation, keyRing, statistics, queues) } diff --git a/federationsender/internal/api.go b/federationsender/internal/api.go new file mode 100644 index 000000000..0dca32fc9 --- /dev/null +++ b/federationsender/internal/api.go @@ -0,0 +1,40 @@ +package internal + +import ( + "github.com/matrix-org/dendrite/federationsender/queue" + "github.com/matrix-org/dendrite/federationsender/storage" + "github.com/matrix-org/dendrite/federationsender/types" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrixserverlib" +) + +// FederationSenderInternalAPI is an implementation of api.FederationSenderInternalAPI +type FederationSenderInternalAPI struct { + db storage.Database + cfg *config.Dendrite + statistics *types.Statistics + rsAPI api.RoomserverInternalAPI + federation *gomatrixserverlib.FederationClient + keyRing *gomatrixserverlib.KeyRing + queues *queue.OutgoingQueues +} + +func NewFederationSenderInternalAPI( + db storage.Database, cfg *config.Dendrite, + rsAPI api.RoomserverInternalAPI, + federation *gomatrixserverlib.FederationClient, + keyRing *gomatrixserverlib.KeyRing, + statistics *types.Statistics, + queues *queue.OutgoingQueues, +) *FederationSenderInternalAPI { + return &FederationSenderInternalAPI{ + db: db, + cfg: cfg, + rsAPI: rsAPI, + federation: federation, + keyRing: keyRing, + statistics: statistics, + queues: queues, + } +} diff --git a/federationsender/internal/perform.go b/federationsender/internal/perform.go new file mode 100644 index 000000000..7ced4af86 --- /dev/null +++ b/federationsender/internal/perform.go @@ -0,0 +1,292 @@ +package internal + +import ( + "context" + "fmt" + "time" + + "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/federationsender/internal/perform" + roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/version" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +// PerformLeaveRequest implements api.FederationSenderInternalAPI +func (r *FederationSenderInternalAPI) PerformDirectoryLookup( + ctx context.Context, + request *api.PerformDirectoryLookupRequest, + response *api.PerformDirectoryLookupResponse, +) (err error) { + dir, err := r.federation.LookupRoomAlias( + ctx, + request.ServerName, + request.RoomAlias, + ) + if err != nil { + r.statistics.ForServer(request.ServerName).Failure() + return err + } + response.RoomID = dir.RoomID + response.ServerNames = dir.Servers + r.statistics.ForServer(request.ServerName).Success() + return nil +} + +// PerformJoinRequest implements api.FederationSenderInternalAPI +func (r *FederationSenderInternalAPI) PerformJoin( + ctx context.Context, + request *api.PerformJoinRequest, + response *api.PerformJoinResponse, +) (err error) { + // Look up the supported room versions. + var supportedVersions []gomatrixserverlib.RoomVersion + for version := range version.SupportedRoomVersions() { + supportedVersions = append(supportedVersions, version) + } + + // Deduplicate the server names we were provided but keep the ordering + // as this encodes useful information about which servers are most likely + // to respond. + seenSet := make(map[gomatrixserverlib.ServerName]bool) + var uniqueList []gomatrixserverlib.ServerName + for _, srv := range request.ServerNames { + if seenSet[srv] { + continue + } + seenSet[srv] = true + uniqueList = append(uniqueList, srv) + } + request.ServerNames = uniqueList + + // Try each server that we were provided until we land on one that + // successfully completes the make-join send-join dance. + for _, serverName := range request.ServerNames { + if err := r.performJoinUsingServer( + ctx, + request.RoomID, + request.UserID, + request.Content, + serverName, + supportedVersions, + ); err != nil { + logrus.WithError(err).WithFields(logrus.Fields{ + "server_name": serverName, + "room_id": request.RoomID, + }).Warnf("Failed to join room through server") + continue + } + + // We're all good. + return nil + } + + // If we reach here then we didn't complete a join for some reason. + return fmt.Errorf( + "failed to join user %q to room %q through %d server(s)", + request.UserID, request.RoomID, len(request.ServerNames), + ) +} + +func (r *FederationSenderInternalAPI) performJoinUsingServer( + ctx context.Context, + roomID, userID string, + content map[string]interface{}, + serverName gomatrixserverlib.ServerName, + supportedVersions []gomatrixserverlib.RoomVersion, +) error { + // Try to perform a make_join using the information supplied in the + // request. + respMakeJoin, err := r.federation.MakeJoin( + ctx, + serverName, + roomID, + userID, + supportedVersions, + ) + if err != nil { + // TODO: Check if the user was not allowed to join the room. + r.statistics.ForServer(serverName).Failure() + return fmt.Errorf("r.federation.MakeJoin: %w", err) + } + r.statistics.ForServer(serverName).Success() + + // Set all the fields to be what they should be, this should be a no-op + // but it's possible that the remote server returned us something "odd" + respMakeJoin.JoinEvent.Type = gomatrixserverlib.MRoomMember + respMakeJoin.JoinEvent.Sender = userID + respMakeJoin.JoinEvent.StateKey = &userID + respMakeJoin.JoinEvent.RoomID = roomID + respMakeJoin.JoinEvent.Redacts = "" + if content == nil { + content = map[string]interface{}{} + } + content["membership"] = "join" + if err = respMakeJoin.JoinEvent.SetContent(content); err != nil { + return fmt.Errorf("respMakeJoin.JoinEvent.SetContent: %w", err) + } + if err = respMakeJoin.JoinEvent.SetUnsigned(struct{}{}); err != nil { + return fmt.Errorf("respMakeJoin.JoinEvent.SetUnsigned: %w", err) + } + + // Work out if we support the room version that has been supplied in + // the make_join response. + if respMakeJoin.RoomVersion == "" { + respMakeJoin.RoomVersion = gomatrixserverlib.RoomVersionV1 + } + if _, err = respMakeJoin.RoomVersion.EventFormat(); err != nil { + return fmt.Errorf("respMakeJoin.RoomVersion.EventFormat: %w", err) + } + + // Build the join event. + event, err := respMakeJoin.JoinEvent.Build( + time.Now(), + r.cfg.Matrix.ServerName, + r.cfg.Matrix.KeyID, + r.cfg.Matrix.PrivateKey, + respMakeJoin.RoomVersion, + ) + if err != nil { + return fmt.Errorf("respMakeJoin.JoinEvent.Build: %w", err) + } + + // Try to perform a send_join using the newly built event. + respSendJoin, err := r.federation.SendJoin( + ctx, + serverName, + event, + respMakeJoin.RoomVersion, + ) + if err != nil { + r.statistics.ForServer(serverName).Failure() + return fmt.Errorf("r.federation.SendJoin: %w", err) + } + r.statistics.ForServer(serverName).Success() + + // Check that the send_join response was valid. + joinCtx := perform.JoinContext(r.federation, r.keyRing) + if err = joinCtx.CheckSendJoinResponse( + ctx, event, serverName, respMakeJoin, respSendJoin, + ); err != nil { + return fmt.Errorf("joinCtx.CheckSendJoinResponse: %w", err) + } + + // If we successfully performed a send_join above then the other + // server now thinks we're a part of the room. Send the newly + // returned state to the roomserver to update our local view. + respState := respSendJoin.ToRespState() + if err = roomserverAPI.SendEventWithState( + ctx, r.rsAPI, + &respState, + event.Headered(respMakeJoin.RoomVersion), nil, + ); err != nil { + return fmt.Errorf("r.producer.SendEventWithState: %w", err) + } + + return nil +} + +// PerformLeaveRequest implements api.FederationSenderInternalAPI +func (r *FederationSenderInternalAPI) PerformLeave( + ctx context.Context, + request *api.PerformLeaveRequest, + response *api.PerformLeaveResponse, +) (err error) { + // Deduplicate the server names we were provided. + util.SortAndUnique(request.ServerNames) + + // Try each server that we were provided until we land on one that + // successfully completes the make-leave send-leave dance. + for _, serverName := range request.ServerNames { + // Try to perform a make_leave using the information supplied in the + // request. + respMakeLeave, err := r.federation.MakeLeave( + ctx, + serverName, + request.RoomID, + request.UserID, + ) + if err != nil { + // TODO: Check if the user was not allowed to leave the room. + logrus.WithError(err).Warnf("r.federation.MakeLeave failed") + r.statistics.ForServer(serverName).Failure() + continue + } + + // Set all the fields to be what they should be, this should be a no-op + // but it's possible that the remote server returned us something "odd" + respMakeLeave.LeaveEvent.Type = gomatrixserverlib.MRoomMember + respMakeLeave.LeaveEvent.Sender = request.UserID + respMakeLeave.LeaveEvent.StateKey = &request.UserID + respMakeLeave.LeaveEvent.RoomID = request.RoomID + respMakeLeave.LeaveEvent.Redacts = "" + if respMakeLeave.LeaveEvent.Content == nil { + content := map[string]interface{}{ + "membership": "leave", + } + if err = respMakeLeave.LeaveEvent.SetContent(content); err != nil { + logrus.WithError(err).Warnf("respMakeLeave.LeaveEvent.SetContent failed") + continue + } + } + if err = respMakeLeave.LeaveEvent.SetUnsigned(struct{}{}); err != nil { + logrus.WithError(err).Warnf("respMakeLeave.LeaveEvent.SetUnsigned failed") + continue + } + + // Work out if we support the room version that has been supplied in + // the make_leave response. + if _, err = respMakeLeave.RoomVersion.EventFormat(); err != nil { + return gomatrixserverlib.UnsupportedRoomVersionError{} + } + + // Build the leave event. + event, err := respMakeLeave.LeaveEvent.Build( + time.Now(), + r.cfg.Matrix.ServerName, + r.cfg.Matrix.KeyID, + r.cfg.Matrix.PrivateKey, + respMakeLeave.RoomVersion, + ) + if err != nil { + logrus.WithError(err).Warnf("respMakeLeave.LeaveEvent.Build failed") + continue + } + + // Try to perform a send_leave using the newly built event. + err = r.federation.SendLeave( + ctx, + serverName, + event, + ) + if err != nil { + logrus.WithError(err).Warnf("r.federation.SendLeave failed") + r.statistics.ForServer(serverName).Failure() + continue + } + + r.statistics.ForServer(serverName).Success() + return nil + } + + // If we reach here then we didn't complete a leave for some reason. + return fmt.Errorf( + "Failed to leave room %q through %d server(s)", + request.RoomID, len(request.ServerNames), + ) +} + +// PerformServersAlive implements api.FederationSenderInternalAPI +func (r *FederationSenderInternalAPI) PerformServersAlive( + ctx context.Context, + request *api.PerformServersAliveRequest, + response *api.PerformServersAliveResponse, +) (err error) { + for _, srv := range request.Servers { + r.queues.RetryServer(srv) + } + + return nil +} diff --git a/federationsender/internal/perform/join.go b/federationsender/internal/perform/join.go new file mode 100644 index 000000000..9a505d15b --- /dev/null +++ b/federationsender/internal/perform/join.go @@ -0,0 +1,104 @@ +package perform + +import ( + "context" + "fmt" + + "github.com/matrix-org/gomatrixserverlib" +) + +// This file contains helpers for the PerformJoin function. + +type joinContext struct { + federation *gomatrixserverlib.FederationClient + keyRing *gomatrixserverlib.KeyRing +} + +// Returns a new join context. +func JoinContext(f *gomatrixserverlib.FederationClient, k *gomatrixserverlib.KeyRing) *joinContext { + return &joinContext{ + federation: f, + keyRing: k, + } +} + +// checkSendJoinResponse checks that all of the signatures are correct +// and that the join is allowed by the supplied state. +func (r joinContext) CheckSendJoinResponse( + ctx context.Context, + event gomatrixserverlib.Event, + server gomatrixserverlib.ServerName, + respMakeJoin gomatrixserverlib.RespMakeJoin, + respSendJoin gomatrixserverlib.RespSendJoin, +) error { + // A list of events that we have retried, if they were not included in + // the auth events supplied in the send_join. + retries := map[string][]gomatrixserverlib.Event{} + + // Define a function which we can pass to Check to retrieve missing + // auth events inline. This greatly increases our chances of not having + // to repeat the entire set of checks just for a missing event or two. + missingAuth := func(roomVersion gomatrixserverlib.RoomVersion, eventIDs []string) ([]gomatrixserverlib.Event, error) { + returning := []gomatrixserverlib.Event{} + + // See if we have retry entries for each of the supplied event IDs. + for _, eventID := range eventIDs { + // If we've already satisfied a request for this event ID before then + // just append the results. We won't retry the request. + if retry, ok := retries[eventID]; ok { + if retry == nil { + return nil, fmt.Errorf("missingAuth: not retrying failed event ID %q", eventID) + } + returning = append(returning, retry...) + continue + } + + // Make a note of the fact that we tried to do something with this + // event ID, even if we don't succeed. + retries[event.EventID()] = nil + + // Try to retrieve the event from the server that sent us the send + // join response. + tx, txerr := r.federation.GetEvent(ctx, server, eventID) + if txerr != nil { + return nil, fmt.Errorf("missingAuth r.federation.GetEvent: %w", txerr) + } + + // For each event returned, add it to the set of return events. We + // also will populate the retries, in case someone asks for this + // event ID again. + for _, pdu := range tx.PDUs { + // Try to parse the event. + ev, everr := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion) + if everr != nil { + return nil, fmt.Errorf("missingAuth gomatrixserverlib.NewEventFromUntrustedJSON: %w", everr) + } + + // Check the signatures of the event. + if res, err := gomatrixserverlib.VerifyEventSignatures(ctx, []gomatrixserverlib.Event{ev}, r.keyRing); err != nil { + return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err) + } else { + for _, err := range res { + if err != nil { + return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err) + } + } + } + + // If the event is OK then add it to the results and the retry map. + returning = append(returning, ev) + retries[event.EventID()] = append(retries[event.EventID()], ev) + retries[ev.EventID()] = append(retries[ev.EventID()], ev) + } + } + + return returning, nil + } + + // TODO: Can we expand Check here to return a list of missing auth + // events rather than failing one at a time? + if err := respSendJoin.Check(ctx, r.keyRing, event, missingAuth); err != nil { + return fmt.Errorf("respSendJoin: %w", err) + } + return nil +} diff --git a/federationsender/internal/query.go b/federationsender/internal/query.go new file mode 100644 index 000000000..253400a2d --- /dev/null +++ b/federationsender/internal/query.go @@ -0,0 +1,29 @@ +package internal + +import ( + "context" + + "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/gomatrixserverlib" +) + +// QueryJoinedHostServerNamesInRoom implements api.FederationSenderInternalAPI +func (f *FederationSenderInternalAPI) QueryJoinedHostServerNamesInRoom( + ctx context.Context, + request *api.QueryJoinedHostServerNamesInRoomRequest, + response *api.QueryJoinedHostServerNamesInRoomResponse, +) (err error) { + joinedHosts, err := f.db.GetJoinedHosts(ctx, request.RoomID) + if err != nil { + return + } + + response.ServerNames = make([]gomatrixserverlib.ServerName, 0, len(joinedHosts)) + for _, host := range joinedHosts { + response.ServerNames = append(response.ServerNames, host.ServerName) + } + + // TODO: remove duplicates? + + return +} diff --git a/federationsender/inthttp/client.go b/federationsender/inthttp/client.go new file mode 100644 index 000000000..5da4b35f9 --- /dev/null +++ b/federationsender/inthttp/client.go @@ -0,0 +1,99 @@ +package inthttp + +import ( + "context" + "errors" + "net/http" + + "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/opentracing/opentracing-go" +) + +// HTTP paths for the internal HTTP API +const ( + FederationSenderQueryJoinedHostServerNamesInRoomPath = "/federationsender/queryJoinedHostServerNamesInRoom" + + FederationSenderPerformDirectoryLookupRequestPath = "/federationsender/performDirectoryLookup" + FederationSenderPerformJoinRequestPath = "/federationsender/performJoinRequest" + FederationSenderPerformLeaveRequestPath = "/federationsender/performLeaveRequest" + FederationSenderPerformServersAlivePath = "/federationsender/performServersAlive" +) + +// NewFederationSenderClient creates a FederationSenderInternalAPI implemented by talking to a HTTP POST API. +// If httpClient is nil an error is returned +func NewFederationSenderClient(federationSenderURL string, httpClient *http.Client) (api.FederationSenderInternalAPI, error) { + if httpClient == nil { + return nil, errors.New("NewFederationSenderInternalAPIHTTP: httpClient is ") + } + return &httpFederationSenderInternalAPI{federationSenderURL, httpClient}, nil +} + +type httpFederationSenderInternalAPI struct { + federationSenderURL string + httpClient *http.Client +} + +// Handle an instruction to make_leave & send_leave with a remote server. +func (h *httpFederationSenderInternalAPI) PerformLeave( + ctx context.Context, + request *api.PerformLeaveRequest, + response *api.PerformLeaveResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLeaveRequest") + defer span.Finish() + + apiURL := h.federationSenderURL + FederationSenderPerformLeaveRequestPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpFederationSenderInternalAPI) PerformServersAlive( + ctx context.Context, + request *api.PerformServersAliveRequest, + response *api.PerformServersAliveResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformServersAlive") + defer span.Finish() + + apiURL := h.federationSenderURL + FederationSenderPerformServersAlivePath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryJoinedHostServerNamesInRoom implements FederationSenderInternalAPI +func (h *httpFederationSenderInternalAPI) QueryJoinedHostServerNamesInRoom( + ctx context.Context, + request *api.QueryJoinedHostServerNamesInRoomRequest, + response *api.QueryJoinedHostServerNamesInRoomResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryJoinedHostServerNamesInRoom") + defer span.Finish() + + apiURL := h.federationSenderURL + FederationSenderQueryJoinedHostServerNamesInRoomPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// Handle an instruction to make_join & send_join with a remote server. +func (h *httpFederationSenderInternalAPI) PerformJoin( + ctx context.Context, + request *api.PerformJoinRequest, + response *api.PerformJoinResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformJoinRequest") + defer span.Finish() + + apiURL := h.federationSenderURL + FederationSenderPerformJoinRequestPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// Handle an instruction to make_join & send_join with a remote server. +func (h *httpFederationSenderInternalAPI) PerformDirectoryLookup( + ctx context.Context, + request *api.PerformDirectoryLookupRequest, + response *api.PerformDirectoryLookupResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDirectoryLookup") + defer span.Finish() + + apiURL := h.federationSenderURL + FederationSenderPerformDirectoryLookupRequestPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/federationsender/inthttp/server.go b/federationsender/inthttp/server.go new file mode 100644 index 000000000..babd3ae13 --- /dev/null +++ b/federationsender/inthttp/server.go @@ -0,0 +1,81 @@ +package inthttp + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/util" +) + +// AddRoutes adds the FederationSenderInternalAPI handlers to the http.ServeMux. +func AddRoutes(intAPI api.FederationSenderInternalAPI, internalAPIMux *mux.Router) { + internalAPIMux.Handle( + FederationSenderQueryJoinedHostServerNamesInRoomPath, + httputil.MakeInternalAPI("QueryJoinedHostServerNamesInRoom", func(req *http.Request) util.JSONResponse { + var request api.QueryJoinedHostServerNamesInRoomRequest + var response api.QueryJoinedHostServerNamesInRoomResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := intAPI.QueryJoinedHostServerNamesInRoom(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(FederationSenderPerformJoinRequestPath, + httputil.MakeInternalAPI("PerformJoinRequest", func(req *http.Request) util.JSONResponse { + var request api.PerformJoinRequest + var response api.PerformJoinResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := intAPI.PerformJoin(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(FederationSenderPerformLeaveRequestPath, + httputil.MakeInternalAPI("PerformLeaveRequest", func(req *http.Request) util.JSONResponse { + var request api.PerformLeaveRequest + var response api.PerformLeaveResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := intAPI.PerformLeave(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(FederationSenderPerformDirectoryLookupRequestPath, + httputil.MakeInternalAPI("PerformDirectoryLookupRequest", func(req *http.Request) util.JSONResponse { + var request api.PerformDirectoryLookupRequest + var response api.PerformDirectoryLookupResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := intAPI.PerformDirectoryLookup(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(FederationSenderPerformServersAlivePath, + httputil.MakeInternalAPI("PerformServersAliveRequest", func(req *http.Request) util.JSONResponse { + var request api.PerformServersAliveRequest + var response api.PerformServersAliveResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := intAPI.PerformServersAlive(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) +} diff --git a/federationsender/query/query.go b/federationsender/query/query.go deleted file mode 100644 index 8c35bb29e..000000000 --- a/federationsender/query/query.go +++ /dev/null @@ -1,88 +0,0 @@ -package query - -import ( - "context" - "encoding/json" - "net/http" - - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/federationsender/api" - "github.com/matrix-org/dendrite/federationsender/types" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" -) - -// FederationSenderQueryDatabase has the APIs needed to implement the query API. -type FederationSenderQueryDatabase interface { - GetJoinedHosts( - ctx context.Context, roomID string, - ) ([]types.JoinedHost, error) -} - -// FederationSenderQueryAPI is an implementation of api.FederationSenderQueryAPI -type FederationSenderQueryAPI struct { - DB FederationSenderQueryDatabase -} - -// QueryJoinedHostsInRoom implements api.FederationSenderQueryAPI -func (f *FederationSenderQueryAPI) QueryJoinedHostsInRoom( - ctx context.Context, - request *api.QueryJoinedHostsInRoomRequest, - response *api.QueryJoinedHostsInRoomResponse, -) (err error) { - response.JoinedHosts, err = f.DB.GetJoinedHosts(ctx, request.RoomID) - return -} - -// QueryJoinedHostServerNamesInRoom implements api.FederationSenderQueryAPI -func (f *FederationSenderQueryAPI) QueryJoinedHostServerNamesInRoom( - ctx context.Context, - request *api.QueryJoinedHostServerNamesInRoomRequest, - response *api.QueryJoinedHostServerNamesInRoomResponse, -) (err error) { - joinedHosts, err := f.DB.GetJoinedHosts(ctx, request.RoomID) - if err != nil { - return - } - - response.ServerNames = make([]gomatrixserverlib.ServerName, 0, len(joinedHosts)) - for _, host := range joinedHosts { - response.ServerNames = append(response.ServerNames, host.ServerName) - } - - // TODO: remove duplicates? - - return -} - -// SetupHTTP adds the FederationSenderQueryAPI handlers to the http.ServeMux. -func (f *FederationSenderQueryAPI) SetupHTTP(servMux *http.ServeMux) { - servMux.Handle( - api.FederationSenderQueryJoinedHostsInRoomPath, - common.MakeInternalAPI("QueryJoinedHostsInRoom", func(req *http.Request) util.JSONResponse { - var request api.QueryJoinedHostsInRoomRequest - var response api.QueryJoinedHostsInRoomResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := f.QueryJoinedHostsInRoom(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - servMux.Handle( - api.FederationSenderQueryJoinedHostServerNamesInRoomPath, - common.MakeInternalAPI("QueryJoinedHostServerNamesInRoom", func(req *http.Request) util.JSONResponse { - var request api.QueryJoinedHostServerNamesInRoomRequest - var response api.QueryJoinedHostServerNamesInRoomResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := f.QueryJoinedHostServerNamesInRoom(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) -} diff --git a/federationsender/queue/destinationqueue.go b/federationsender/queue/destinationqueue.go index 7d4dc850b..4449f9e63 100644 --- a/federationsender/queue/destinationqueue.go +++ b/federationsender/queue/destinationqueue.go @@ -18,11 +18,13 @@ import ( "context" "encoding/json" "fmt" - "sync" "time" + "github.com/matrix-org/dendrite/federationsender/types" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "go.uber.org/atomic" ) @@ -32,91 +34,244 @@ import ( // ensures that only one request is in flight to a given destination // at a time. type destinationQueue struct { - client *gomatrixserverlib.FederationClient - origin gomatrixserverlib.ServerName - destination gomatrixserverlib.ServerName - running atomic.Bool - // The running mutex protects sentCounter, lastTransactionIDs and - // pendingEvents, pendingEDUs. - runningMutex sync.Mutex - sentCounter int - lastTransactionIDs []gomatrixserverlib.TransactionID - pendingEvents []*gomatrixserverlib.HeaderedEvent - pendingEDUs []*gomatrixserverlib.EDU - pendingInvites []*gomatrixserverlib.InviteV2Request + signing *SigningInfo + rsAPI api.RoomserverInternalAPI + client *gomatrixserverlib.FederationClient // federation client + origin gomatrixserverlib.ServerName // origin of requests + destination gomatrixserverlib.ServerName // destination of requests + running atomic.Bool // is the queue worker running? + backingOff atomic.Bool // true if we're backing off + statistics *types.ServerStatistics // statistics about this remote server + incomingPDUs chan *gomatrixserverlib.HeaderedEvent // PDUs to send + incomingEDUs chan *gomatrixserverlib.EDU // EDUs to send + incomingInvites chan *gomatrixserverlib.InviteV2Request // invites to send + lastTransactionIDs []gomatrixserverlib.TransactionID // last transaction ID + pendingPDUs []*gomatrixserverlib.HeaderedEvent // owned by backgroundSend + pendingEDUs []*gomatrixserverlib.EDU // owned by backgroundSend + pendingInvites []*gomatrixserverlib.InviteV2Request // owned by backgroundSend + retryServerCh chan bool // interrupts backoff +} + +// retry will clear the blacklist state and attempt to send built up events to the server, +// resetting and interrupting any backoff timers. +func (oq *destinationQueue) retry() { + // TODO: We don't send all events in the case where the server has been blacklisted as we + // drop events instead then. This means we will send the oldest N events (chan size, currently 128) + // and then skip ahead a lot which feels non-ideal but equally we can't persist thousands of events + // in-memory to maybe-send it one day. Ideally we would just shove these pending events in a database + // so we can send a lot of events. + // + // Interrupt the backoff. If the federation request that happens as a result of this is successful + // then the counters will be reset there and the backoff will cancel. If the federation request + // fails then we will retry at the current backoff interval, so as to prevent us from spamming + // homeservers which are behaving badly. + // We need to use an atomic bool here to prevent multiple calls to retry() blocking on the channel + // as it is unbuffered. + if oq.backingOff.CAS(true, false) { + oq.retryServerCh <- true + } + if !oq.running.Load() { + log.Infof("Restarting queue for %s", oq.destination) + go oq.backgroundSend() + } } // Send event adds the event to the pending queue for the destination. // If the queue is empty then it starts a background goroutine to // start sending events to that destination. func (oq *destinationQueue) sendEvent(ev *gomatrixserverlib.HeaderedEvent) { - oq.runningMutex.Lock() - defer oq.runningMutex.Unlock() - oq.pendingEvents = append(oq.pendingEvents, ev) + if oq.statistics.Blacklisted() { + // If the destination is blacklisted then drop the event. + return + } if !oq.running.Load() { go oq.backgroundSend() } + oq.incomingPDUs <- ev } // sendEDU adds the EDU event to the pending queue for the destination. // If the queue is empty then it starts a background goroutine to // start sending events to that destination. -func (oq *destinationQueue) sendEDU(e *gomatrixserverlib.EDU) { - oq.runningMutex.Lock() - defer oq.runningMutex.Unlock() - oq.pendingEDUs = append(oq.pendingEDUs, e) +func (oq *destinationQueue) sendEDU(ev *gomatrixserverlib.EDU) { + if oq.statistics.Blacklisted() { + // If the destination is blacklisted then drop the event. + return + } if !oq.running.Load() { go oq.backgroundSend() } + oq.incomingEDUs <- ev } // sendInvite adds the invite event to the pending queue for the // destination. If the queue is empty then it starts a background // goroutine to start sending events to that destination. func (oq *destinationQueue) sendInvite(ev *gomatrixserverlib.InviteV2Request) { - oq.runningMutex.Lock() - defer oq.runningMutex.Unlock() - oq.pendingInvites = append(oq.pendingInvites, ev) + if oq.statistics.Blacklisted() { + // If the destination is blacklisted then drop the event. + return + } if !oq.running.Load() { go oq.backgroundSend() } + oq.incomingInvites <- ev } // backgroundSend is the worker goroutine for sending events. +// nolint:gocyclo func (oq *destinationQueue) backgroundSend() { - oq.running.Store(true) + // Check if a worker is already running, and if it isn't, then + // mark it as started. + if !oq.running.CAS(false, true) { + return + } defer oq.running.Store(false) for { - transaction, invites := oq.nextTransaction(), oq.nextInvites() - if !transaction && !invites { - // If the queue is empty then stop processing for this destination. - // TODO: Remove this destination from the queue map. + // Wait either for incoming events, or until we hit an + // idle timeout. + select { + case pdu := <-oq.incomingPDUs: + // Ordering of PDUs is important so we add them to the end + // of the queue and they will all be added to transactions + // in order. + oq.pendingPDUs = append(oq.pendingPDUs, pdu) + // If there are any more things waiting in the channel queue + // then read them. This is safe because we guarantee only + // having one goroutine per destination queue, so the channel + // isn't being consumed anywhere else. + for len(oq.incomingPDUs) > 0 { + oq.pendingPDUs = append(oq.pendingPDUs, <-oq.incomingPDUs) + } + case edu := <-oq.incomingEDUs: + // Likewise for EDUs, although we should probably not try + // too hard with some EDUs (like typing notifications) after + // a certain amount of time has passed. + // TODO: think about EDU expiry some more + oq.pendingEDUs = append(oq.pendingEDUs, edu) + // If there are any more things waiting in the channel queue + // then read them. This is safe because we guarantee only + // having one goroutine per destination queue, so the channel + // isn't being consumed anywhere else. + for len(oq.incomingEDUs) > 0 { + oq.pendingEDUs = append(oq.pendingEDUs, <-oq.incomingEDUs) + } + case invite := <-oq.incomingInvites: + // There's no strict ordering requirement for invites like + // there is for transactions, so we put the invite onto the + // front of the queue. This means that if an invite that is + // stuck failing already, that it won't block our new invite + // from being sent. + oq.pendingInvites = append( + []*gomatrixserverlib.InviteV2Request{invite}, + oq.pendingInvites..., + ) + // If there are any more things waiting in the channel queue + // then read them. This is safe because we guarantee only + // having one goroutine per destination queue, so the channel + // isn't being consumed anywhere else. + for len(oq.incomingInvites) > 0 { + oq.pendingInvites = append(oq.pendingInvites, <-oq.incomingInvites) + } + case <-time.After(time.Second * 30): + // The worker is idle so stop the goroutine. It'll + // get restarted automatically the next time we + // get an event. return } - // TODO: handle retries. - // TODO: blacklist uncooperative servers. + // If we are backing off this server then wait for the + // backoff duration to complete first, or until explicitly + // told to retry. + if backoff, duration := oq.statistics.BackoffDuration(); backoff { + oq.backingOff.Store(true) + select { + case <-time.After(duration): + case <-oq.retryServerCh: + } + oq.backingOff.Store(false) + } + + // How many things do we have waiting? + numPDUs := len(oq.pendingPDUs) + numEDUs := len(oq.pendingEDUs) + numInvites := len(oq.pendingInvites) + + // If we have pending PDUs or EDUs then construct a transaction. + if numPDUs > 0 || numEDUs > 0 { + // Try sending the next transaction and see what happens. + transaction, terr := oq.nextTransaction(oq.pendingPDUs, oq.pendingEDUs, oq.statistics.SuccessCount()) + if terr != nil { + // We failed to send the transaction. + if giveUp := oq.statistics.Failure(); giveUp { + // It's been suggested that we should give up because + // the backoff has exceeded a maximum allowable value. + return + } + } else if transaction { + // If we successfully sent the transaction then clear out + // the pending events and EDUs. + oq.statistics.Success() + // Reallocate so that the underlying arrays can be GC'd, as + // opposed to growing forever. + for i := 0; i < numPDUs; i++ { + oq.pendingPDUs[i] = nil + } + for i := 0; i < numEDUs; i++ { + oq.pendingEDUs[i] = nil + } + oq.pendingPDUs = append( + []*gomatrixserverlib.HeaderedEvent{}, + oq.pendingPDUs[numPDUs:]..., + ) + oq.pendingEDUs = append( + []*gomatrixserverlib.EDU{}, + oq.pendingEDUs[numEDUs:]..., + ) + } + } + + // Try sending the next invite and see what happens. + if numInvites > 0 { + sent, ierr := oq.nextInvites(oq.pendingInvites) + if ierr != nil { + // We failed to send the transaction so increase the + // backoff and give it another go shortly. + if giveUp := oq.statistics.Failure(); giveUp { + // It's been suggested that we should give up because + // the backoff has exceeded a maximum allowable value. + return + } + } else if sent > 0 { + // If we successfully sent the invites then clear out + // the pending invites. + oq.statistics.Success() + // Reallocate so that the underlying array can be GC'd, as + // opposed to growing forever. + oq.pendingInvites = append( + []*gomatrixserverlib.InviteV2Request{}, + oq.pendingInvites[sent:]..., + ) + } + } } } // nextTransaction creates a new transaction from the pending event // queue and sends it. Returns true if a transaction was sent or // false otherwise. -func (oq *destinationQueue) nextTransaction() bool { - oq.runningMutex.Lock() - defer oq.runningMutex.Unlock() - - if len(oq.pendingEvents) == 0 && len(oq.pendingEDUs) == 0 { - return false - } - +func (oq *destinationQueue) nextTransaction( + pendingPDUs []*gomatrixserverlib.HeaderedEvent, + pendingEDUs []*gomatrixserverlib.EDU, + sentCounter uint32, +) (bool, error) { t := gomatrixserverlib.Transaction{ PDUs: []json.RawMessage{}, EDUs: []gomatrixserverlib.EDU{}, } now := gomatrixserverlib.AsTimestamp(time.Now()) - t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.sentCounter)) + t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, sentCounter)) t.Origin = oq.origin t.Destination = oq.destination t.OriginServerTS = now @@ -127,60 +282,106 @@ func (oq *destinationQueue) nextTransaction() bool { oq.lastTransactionIDs = []gomatrixserverlib.TransactionID{t.TransactionID} - for _, pdu := range oq.pendingEvents { + for _, pdu := range pendingPDUs { // Append the JSON of the event, since this is a json.RawMessage type in the // gomatrixserverlib.Transaction struct t.PDUs = append(t.PDUs, (*pdu).JSON()) } - oq.pendingEvents = nil - oq.sentCounter += len(t.PDUs) - for _, edu := range oq.pendingEDUs { + for _, edu := range pendingEDUs { t.EDUs = append(t.EDUs, *edu) } - oq.pendingEDUs = nil - oq.sentCounter += len(t.EDUs) - util.GetLogger(context.TODO()).Infof("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs)) + logrus.WithField("server_name", oq.destination).Infof("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs)) + // TODO: we should check for 500-ish fails vs 400-ish here, + // since we shouldn't queue things indefinitely in response + // to a 400-ish error _, err := oq.client.SendTransaction(context.TODO(), t) - if err != nil { + switch e := err.(type) { + case nil: + // No error was returned so the transaction looks to have + // been successfully sent. + return true, nil + case gomatrix.HTTPError: + // We received a HTTP error back. In this instance we only + // should report an error if + if e.Code >= 400 && e.Code <= 499 { + // We tried but the remote side has sent back a client error. + // It's no use retrying because it will happen again. + return true, nil + } + // Otherwise, report that we failed to send the transaction + // and we will retry again. + return false, err + default: log.WithFields(log.Fields{ "destination": oq.destination, log.ErrorKey: err, }).Info("problem sending transaction") + return false, err } - - return true } // nextInvite takes pending invite events from the queue and sends // them. Returns true if a transaction was sent or false otherwise. -func (oq *destinationQueue) nextInvites() bool { - oq.runningMutex.Lock() - defer oq.runningMutex.Unlock() +func (oq *destinationQueue) nextInvites( + pendingInvites []*gomatrixserverlib.InviteV2Request, +) (int, error) { + done := 0 + for _, inviteReq := range pendingInvites { + ev, roomVersion := inviteReq.Event(), inviteReq.RoomVersion() - if len(oq.pendingInvites) == 0 { - return false - } + log.WithFields(log.Fields{ + "event_id": ev.EventID(), + "room_version": roomVersion, + "destination": oq.destination, + }).Info("sending invite") - for _, inviteReq := range oq.pendingInvites { - ev := inviteReq.Event() - - if _, err := oq.client.SendInviteV2( + inviteRes, err := oq.client.SendInviteV2( context.TODO(), oq.destination, *inviteReq, - ); err != nil { + ) + switch e := err.(type) { + case nil: + done++ + case gomatrix.HTTPError: + log.WithFields(log.Fields{ + "event_id": ev.EventID(), + "state_key": ev.StateKey(), + "destination": oq.destination, + "status_code": e.Code, + }).WithError(err).Error("failed to send invite due to HTTP error") + // Check whether we should do something about the error or + // just accept it as unavoidable. + if e.Code >= 400 && e.Code <= 499 { + // We tried but the remote side has sent back a client error. + // It's no use retrying because it will happen again. + done++ + continue + } + return done, err + default: log.WithFields(log.Fields{ "event_id": ev.EventID(), "state_key": ev.StateKey(), "destination": oq.destination, }).WithError(err).Error("failed to send invite") + return done, err + } + + invEv := inviteRes.Event.Sign(string(oq.signing.ServerName), oq.signing.KeyID, oq.signing.PrivateKey).Headered(roomVersion) + _, err = api.SendEvents(context.TODO(), oq.rsAPI, []gomatrixserverlib.HeaderedEvent{invEv}, oq.signing.ServerName, nil) + if err != nil { + log.WithFields(log.Fields{ + "event_id": ev.EventID(), + "state_key": ev.StateKey(), + "destination": oq.destination, + }).WithError(err).Error("failed to return signed invite to roomserver") + return done, err } } - oq.pendingInvites = nil - - return true + return done, nil } diff --git a/federationsender/queue/queue.go b/federationsender/queue/queue.go index 88d47f120..240343559 100644 --- a/federationsender/queue/queue.go +++ b/federationsender/queue/queue.go @@ -15,32 +15,83 @@ package queue import ( + "crypto/ed25519" "fmt" "sync" + "github.com/matrix-org/dendrite/federationsender/types" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" log "github.com/sirupsen/logrus" ) // OutgoingQueues is a collection of queues for sending transactions to other // matrix servers type OutgoingQueues struct { - origin gomatrixserverlib.ServerName - client *gomatrixserverlib.FederationClient - // The queuesMutex protects queues - queuesMutex sync.Mutex + rsAPI api.RoomserverInternalAPI + origin gomatrixserverlib.ServerName + client *gomatrixserverlib.FederationClient + statistics *types.Statistics + signing *SigningInfo + queuesMutex sync.Mutex // protects the below queues map[gomatrixserverlib.ServerName]*destinationQueue } // NewOutgoingQueues makes a new OutgoingQueues -func NewOutgoingQueues(origin gomatrixserverlib.ServerName, client *gomatrixserverlib.FederationClient) *OutgoingQueues { +func NewOutgoingQueues( + origin gomatrixserverlib.ServerName, + client *gomatrixserverlib.FederationClient, + rsAPI api.RoomserverInternalAPI, + statistics *types.Statistics, + signing *SigningInfo, +) *OutgoingQueues { return &OutgoingQueues{ - origin: origin, - client: client, - queues: map[gomatrixserverlib.ServerName]*destinationQueue{}, + rsAPI: rsAPI, + origin: origin, + client: client, + statistics: statistics, + signing: signing, + queues: map[gomatrixserverlib.ServerName]*destinationQueue{}, } } +// TODO: Move this somewhere useful for other components as we often need to ferry these 3 variables +// around together +type SigningInfo struct { + ServerName gomatrixserverlib.ServerName + KeyID gomatrixserverlib.KeyID + PrivateKey ed25519.PrivateKey +} + +func (oqs *OutgoingQueues) getQueueIfExists(destination gomatrixserverlib.ServerName) *destinationQueue { + oqs.queuesMutex.Lock() + defer oqs.queuesMutex.Unlock() + return oqs.queues[destination] +} + +func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *destinationQueue { + oqs.queuesMutex.Lock() + defer oqs.queuesMutex.Unlock() + oq := oqs.queues[destination] + if oq == nil { + oq = &destinationQueue{ + rsAPI: oqs.rsAPI, + origin: oqs.origin, + destination: destination, + client: oqs.client, + statistics: oqs.statistics.ForServer(destination), + incomingPDUs: make(chan *gomatrixserverlib.HeaderedEvent, 128), + incomingEDUs: make(chan *gomatrixserverlib.EDU, 128), + incomingInvites: make(chan *gomatrixserverlib.InviteV2Request, 128), + retryServerCh: make(chan bool), + signing: oqs.signing, + } + oqs.queues[destination] = oq + } + return oq +} + // SendEvent sends an event to the destinations func (oqs *OutgoingQueues) SendEvent( ev *gomatrixserverlib.HeaderedEvent, origin gomatrixserverlib.ServerName, @@ -55,26 +106,17 @@ func (oqs *OutgoingQueues) SendEvent( } // Remove our own server from the list of destinations. - destinations = filterDestinations(oqs.origin, destinations) + destinations = filterAndDedupeDests(oqs.origin, destinations) + if len(destinations) == 0 { + return nil + } log.WithFields(log.Fields{ "destinations": destinations, "event": ev.EventID(), }).Info("Sending event") - oqs.queuesMutex.Lock() - defer oqs.queuesMutex.Unlock() for _, destination := range destinations { - oq := oqs.queues[destination] - if oq == nil { - oq = &destinationQueue{ - origin: oqs.origin, - destination: destination, - client: oqs.client, - } - oqs.queues[destination] = oq - } - - oq.sendEvent(ev) + oqs.getQueue(destination).sendEvent(ev) } return nil @@ -103,22 +145,11 @@ func (oqs *OutgoingQueues) SendInvite( } log.WithFields(log.Fields{ - "event_id": ev.EventID(), + "event_id": ev.EventID(), + "server_name": destination, }).Info("Sending invite") - oqs.queuesMutex.Lock() - defer oqs.queuesMutex.Unlock() - oq := oqs.queues[destination] - if oq == nil { - oq = &destinationQueue{ - origin: oqs.origin, - destination: destination, - client: oqs.client, - } - oqs.queues[destination] = oq - } - - oq.sendInvite(inviteReq) + oqs.getQueue(destination).sendInvite(inviteReq) return nil } @@ -137,7 +168,7 @@ func (oqs *OutgoingQueues) SendEDU( } // Remove our own server from the list of destinations. - destinations = filterDestinations(oqs.origin, destinations) + destinations = filterAndDedupeDests(oqs.origin, destinations) if len(destinations) > 0 { log.WithFields(log.Fields{ @@ -145,34 +176,36 @@ func (oqs *OutgoingQueues) SendEDU( }).Info("Sending EDU event") } - oqs.queuesMutex.Lock() - defer oqs.queuesMutex.Unlock() for _, destination := range destinations { - oq := oqs.queues[destination] - if oq == nil { - oq = &destinationQueue{ - origin: oqs.origin, - destination: destination, - client: oqs.client, - } - oqs.queues[destination] = oq - } - - oq.sendEDU(e) + oqs.getQueue(destination).sendEDU(e) } return nil } -// filterDestinations removes our own server from the list of destinations. -// Otherwise we could end up trying to talk to ourselves. -func filterDestinations(origin gomatrixserverlib.ServerName, destinations []gomatrixserverlib.ServerName) []gomatrixserverlib.ServerName { - var result []gomatrixserverlib.ServerName - for _, destination := range destinations { - if destination == origin { +// RetryServer attempts to resend events to the given server if we had given up. +func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) { + q := oqs.getQueueIfExists(srv) + if q == nil { + return + } + q.retry() +} + +// filterAndDedupeDests removes our own server from the list of destinations +// and deduplicates any servers in the list that may appear more than once. +func filterAndDedupeDests(origin gomatrixserverlib.ServerName, destinations []gomatrixserverlib.ServerName) ( + result []gomatrixserverlib.ServerName, +) { + strs := make([]string, len(destinations)) + for i, d := range destinations { + strs[i] = string(d) + } + for _, destination := range util.UniqueStrings(strs) { + if gomatrixserverlib.ServerName(destination) == origin { continue } - result = append(result, destination) + result = append(result, gomatrixserverlib.ServerName(destination)) } return result } diff --git a/federationsender/storage/interface.go b/federationsender/storage/interface.go index ae2956475..be195382b 100644 --- a/federationsender/storage/interface.go +++ b/federationsender/storage/interface.go @@ -17,12 +17,12 @@ package storage import ( "context" - "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/federationsender/types" + "github.com/matrix-org/dendrite/internal" ) type Database interface { - common.PartitionStorer + internal.PartitionStorer UpdateRoom(ctx context.Context, roomID, oldEventID, newEventID string, addHosts []types.JoinedHost, removeHosts []string) (joinedHosts []types.JoinedHost, err error) GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) } diff --git a/federationsender/storage/postgres/joined_hosts_table.go b/federationsender/storage/postgres/joined_hosts_table.go index b3c45abda..c0f9a7d5b 100644 --- a/federationsender/storage/postgres/joined_hosts_table.go +++ b/federationsender/storage/postgres/joined_hosts_table.go @@ -20,8 +20,9 @@ import ( "database/sql" "github.com/lib/pq" - "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/federationsender/types" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" ) @@ -85,7 +86,7 @@ func (s *joinedHostsStatements) insertJoinedHosts( roomID, eventID string, serverName gomatrixserverlib.ServerName, ) error { - stmt := common.TxStmt(txn, s.insertJoinedHostsStmt) + stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) _, err := stmt.ExecContext(ctx, roomID, eventID, serverName) return err } @@ -93,7 +94,7 @@ func (s *joinedHostsStatements) insertJoinedHosts( func (s *joinedHostsStatements) deleteJoinedHosts( ctx context.Context, txn *sql.Tx, eventIDs []string, ) error { - stmt := common.TxStmt(txn, s.deleteJoinedHostsStmt) + stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt) _, err := stmt.ExecContext(ctx, pq.StringArray(eventIDs)) return err } @@ -101,7 +102,7 @@ func (s *joinedHostsStatements) deleteJoinedHosts( func (s *joinedHostsStatements) selectJoinedHostsWithTx( ctx context.Context, txn *sql.Tx, roomID string, ) ([]types.JoinedHost, error) { - stmt := common.TxStmt(txn, s.selectJoinedHostsStmt) + stmt := sqlutil.TxStmt(txn, s.selectJoinedHostsStmt) return joinedHostsFromStmt(ctx, stmt, roomID) } @@ -118,7 +119,7 @@ func joinedHostsFromStmt( if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "joinedHostsFromStmt: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "joinedHostsFromStmt: rows.close() failed") var result []types.JoinedHost for rows.Next() { diff --git a/federationsender/storage/postgres/room_table.go b/federationsender/storage/postgres/room_table.go index a64424b44..e5266c635 100644 --- a/federationsender/storage/postgres/room_table.go +++ b/federationsender/storage/postgres/room_table.go @@ -19,7 +19,7 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal/sqlutil" ) const roomSchema = ` @@ -71,7 +71,7 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { func (s *roomStatements) insertRoom( ctx context.Context, txn *sql.Tx, roomID string, ) error { - _, err := common.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID) + _, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID) return err } @@ -82,7 +82,7 @@ func (s *roomStatements) selectRoomForUpdate( ctx context.Context, txn *sql.Tx, roomID string, ) (string, error) { var lastEventID string - stmt := common.TxStmt(txn, s.selectRoomForUpdateStmt) + stmt := sqlutil.TxStmt(txn, s.selectRoomForUpdateStmt) err := stmt.QueryRowContext(ctx, roomID).Scan(&lastEventID) if err != nil { return "", err @@ -95,7 +95,7 @@ func (s *roomStatements) selectRoomForUpdate( func (s *roomStatements) updateRoom( ctx context.Context, txn *sql.Tx, roomID, lastEventID string, ) error { - stmt := common.TxStmt(txn, s.updateRoomStmt) + stmt := sqlutil.TxStmt(txn, s.updateRoomStmt) _, err := stmt.ExecContext(ctx, roomID, lastEventID) return err } diff --git a/federationsender/storage/postgres/storage.go b/federationsender/storage/postgres/storage.go index b909a189b..8fd4c11a3 100644 --- a/federationsender/storage/postgres/storage.go +++ b/federationsender/storage/postgres/storage.go @@ -19,7 +19,6 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/federationsender/types" "github.com/matrix-org/dendrite/internal/sqlutil" ) @@ -28,15 +27,15 @@ import ( type Database struct { joinedHostsStatements roomStatements - common.PartitionOffsetStatements + sqlutil.PartitionOffsetStatements db *sql.DB } // NewDatabase opens a new database -func NewDatabase(dataSourceName string) (*Database, error) { +func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties) (*Database, error) { var result Database var err error - if result.db, err = sqlutil.Open("postgres", dataSourceName); err != nil { + if result.db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { return nil, err } if err = result.prepare(); err != nil { @@ -70,7 +69,7 @@ func (d *Database) UpdateRoom( addHosts []types.JoinedHost, removeHosts []string, ) (joinedHosts []types.JoinedHost, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { err = d.insertRoom(ctx, txn, roomID) if err != nil { return err diff --git a/federationsender/storage/sqlite3/joined_hosts_table.go b/federationsender/storage/sqlite3/joined_hosts_table.go index 466ae4991..d9824658c 100644 --- a/federationsender/storage/sqlite3/joined_hosts_table.go +++ b/federationsender/storage/sqlite3/joined_hosts_table.go @@ -19,8 +19,9 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/federationsender/types" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" ) @@ -84,7 +85,7 @@ func (s *joinedHostsStatements) insertJoinedHosts( roomID, eventID string, serverName gomatrixserverlib.ServerName, ) error { - stmt := common.TxStmt(txn, s.insertJoinedHostsStmt) + stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) _, err := stmt.ExecContext(ctx, roomID, eventID, serverName) return err } @@ -93,7 +94,7 @@ func (s *joinedHostsStatements) deleteJoinedHosts( ctx context.Context, txn *sql.Tx, eventIDs []string, ) error { for _, eventID := range eventIDs { - stmt := common.TxStmt(txn, s.deleteJoinedHostsStmt) + stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt) if _, err := stmt.ExecContext(ctx, eventID); err != nil { return err } @@ -104,7 +105,7 @@ func (s *joinedHostsStatements) deleteJoinedHosts( func (s *joinedHostsStatements) selectJoinedHostsWithTx( ctx context.Context, txn *sql.Tx, roomID string, ) ([]types.JoinedHost, error) { - stmt := common.TxStmt(txn, s.selectJoinedHostsStmt) + stmt := sqlutil.TxStmt(txn, s.selectJoinedHostsStmt) return joinedHostsFromStmt(ctx, stmt, roomID) } @@ -121,7 +122,7 @@ func joinedHostsFromStmt( if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "joinedHostsFromStmt: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "joinedHostsFromStmt: rows.close() failed") var result []types.JoinedHost for rows.Next() { diff --git a/federationsender/storage/sqlite3/room_table.go b/federationsender/storage/sqlite3/room_table.go index 6361400d3..ca0c4d0b6 100644 --- a/federationsender/storage/sqlite3/room_table.go +++ b/federationsender/storage/sqlite3/room_table.go @@ -19,7 +19,7 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal/sqlutil" ) const roomSchema = ` @@ -71,7 +71,7 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { func (s *roomStatements) insertRoom( ctx context.Context, txn *sql.Tx, roomID string, ) error { - _, err := common.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID) + _, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID) return err } @@ -82,7 +82,7 @@ func (s *roomStatements) selectRoomForUpdate( ctx context.Context, txn *sql.Tx, roomID string, ) (string, error) { var lastEventID string - stmt := common.TxStmt(txn, s.selectRoomForUpdateStmt) + stmt := sqlutil.TxStmt(txn, s.selectRoomForUpdateStmt) err := stmt.QueryRowContext(ctx, roomID).Scan(&lastEventID) if err != nil { return "", err @@ -95,7 +95,7 @@ func (s *roomStatements) selectRoomForUpdate( func (s *roomStatements) updateRoom( ctx context.Context, txn *sql.Tx, roomID, lastEventID string, ) error { - stmt := common.TxStmt(txn, s.updateRoomStmt) + stmt := sqlutil.TxStmt(txn, s.updateRoomStmt) _, err := stmt.ExecContext(ctx, roomID, lastEventID) return err } diff --git a/federationsender/storage/sqlite3/storage.go b/federationsender/storage/sqlite3/storage.go index 458d7d7e5..ac303f646 100644 --- a/federationsender/storage/sqlite3/storage.go +++ b/federationsender/storage/sqlite3/storage.go @@ -21,7 +21,6 @@ import ( _ "github.com/mattn/go-sqlite3" - "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/federationsender/types" "github.com/matrix-org/dendrite/internal/sqlutil" ) @@ -30,7 +29,7 @@ import ( type Database struct { joinedHostsStatements roomStatements - common.PartitionOffsetStatements + sqlutil.PartitionOffsetStatements db *sql.DB } @@ -38,7 +37,11 @@ type Database struct { func NewDatabase(dataSourceName string) (*Database, error) { var result Database var err error - if result.db, err = sqlutil.Open(common.SQLiteDriverName(), dataSourceName); err != nil { + cs, err := sqlutil.ParseFileURI(dataSourceName) + if err != nil { + return nil, err + } + if result.db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil); err != nil { return nil, err } if err = result.prepare(); err != nil { @@ -72,7 +75,7 @@ func (d *Database) UpdateRoom( addHosts []types.JoinedHost, removeHosts []string, ) (joinedHosts []types.JoinedHost, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { err = d.insertRoom(ctx, txn, roomID) if err != nil { return err diff --git a/federationsender/storage/storage.go b/federationsender/storage/storage.go index 2f018dff1..d37360056 100644 --- a/federationsender/storage/storage.go +++ b/federationsender/storage/storage.go @@ -21,20 +21,21 @@ import ( "github.com/matrix-org/dendrite/federationsender/storage/postgres" "github.com/matrix-org/dendrite/federationsender/storage/sqlite3" + "github.com/matrix-org/dendrite/internal/sqlutil" ) // NewDatabase opens a new database -func NewDatabase(dataSourceName string) (Database, error) { +func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { - return postgres.NewDatabase(dataSourceName) + return postgres.NewDatabase(dataSourceName, dbProperties) } switch uri.Scheme { case "file": return sqlite3.NewDatabase(dataSourceName) case "postgres": - return postgres.NewDatabase(dataSourceName) + return postgres.NewDatabase(dataSourceName, dbProperties) default: - return postgres.NewDatabase(dataSourceName) + return postgres.NewDatabase(dataSourceName, dbProperties) } } diff --git a/federationsender/storage/storage_wasm.go b/federationsender/storage/storage_wasm.go index f2c8ae1b4..e5c8f293b 100644 --- a/federationsender/storage/storage_wasm.go +++ b/federationsender/storage/storage_wasm.go @@ -19,10 +19,14 @@ import ( "net/url" "github.com/matrix-org/dendrite/federationsender/storage/sqlite3" + "github.com/matrix-org/dendrite/internal/sqlutil" ) // NewDatabase opens a new database -func NewDatabase(dataSourceName string) (Database, error) { +func NewDatabase( + dataSourceName string, + dbProperties sqlutil.DbProperties, // nolint:unparam +) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { return nil, fmt.Errorf("Cannot use postgres implementation") diff --git a/federationsender/types/statistics.go b/federationsender/types/statistics.go new file mode 100644 index 000000000..63f82756e --- /dev/null +++ b/federationsender/types/statistics.go @@ -0,0 +1,122 @@ +package types + +import ( + "math" + "sync" + "time" + + "github.com/matrix-org/gomatrixserverlib" + "go.uber.org/atomic" +) + +const ( + // How many times should we tolerate consecutive failures before we + // just blacklist the host altogether? Bear in mind that the backoff + // is exponential, so the max time here to attempt is 2**failures. + FailuresUntilBlacklist = 16 // 16 equates to roughly 18 hours. +) + +// Statistics contains information about all of the remote federated +// hosts that we have interacted with. It is basically a threadsafe +// wrapper. +type Statistics struct { + servers map[gomatrixserverlib.ServerName]*ServerStatistics + mutex sync.RWMutex +} + +// ForServer returns server statistics for the given server name. If it +// does not exist, it will create empty statistics and return those. +func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerStatistics { + // If the map hasn't been initialised yet then do that. + if s.servers == nil { + s.mutex.Lock() + s.servers = make(map[gomatrixserverlib.ServerName]*ServerStatistics) + s.mutex.Unlock() + } + // Look up if we have statistics for this server already. + s.mutex.RLock() + server, found := s.servers[serverName] + s.mutex.RUnlock() + // If we don't, then make one. + if !found { + s.mutex.Lock() + server = &ServerStatistics{} + s.servers[serverName] = server + s.mutex.Unlock() + } + return server +} + +// ServerStatistics contains information about our interactions with a +// remote federated host, e.g. how many times we were successful, how +// many times we failed etc. It also manages the backoff time and black- +// listing a remote host if it remains uncooperative. +type ServerStatistics struct { + blacklisted atomic.Bool // is the remote side dead? + backoffUntil atomic.Value // time.Time to wait until before sending requests + failCounter atomic.Uint32 // how many times have we failed? + successCounter atomic.Uint32 // how many times have we succeeded? +} + +// Success updates the server statistics with a new successful +// attempt, which increases the sent counter and resets the idle and +// failure counters. If a host was blacklisted at this point then +// we will unblacklist it. +func (s *ServerStatistics) Success() { + s.successCounter.Add(1) + s.failCounter.Store(0) + s.blacklisted.Store(false) +} + +// Failure marks a failure and works out when to backoff until. It +// returns true if the worker should give up altogether because of +// too many consecutive failures. At this point the host is marked +// as blacklisted. +func (s *ServerStatistics) Failure() bool { + // Increase the fail counter. + failCounter := s.failCounter.Add(1) + + // Check that we haven't failed more times than is acceptable. + if failCounter >= FailuresUntilBlacklist { + // We've exceeded the maximum amount of times we're willing + // to back off, which is probably in the region of hours by + // now. Mark the host as blacklisted and tell the caller to + // give up. + s.blacklisted.Store(true) + return true + } + + // We're still under the threshold so work out the exponential + // backoff based on how many times we have failed already. The + // worker goroutine will wait until this time before processing + // anything from the queue. + backoffSeconds := time.Second * time.Duration(math.Exp2(float64(failCounter))) + s.backoffUntil.Store( + time.Now().Add(backoffSeconds), + ) + return false +} + +// BackoffDuration returns both a bool stating whether to wait, +// and then if true, a duration to wait for. +func (s *ServerStatistics) BackoffDuration() (bool, time.Duration) { + backoff, until := false, time.Second + if b, ok := s.backoffUntil.Load().(time.Time); ok { + if b.After(time.Now()) { + backoff, until = true, time.Until(b) + } + } + return backoff, until +} + +// Blacklisted returns true if the server is blacklisted and false +// otherwise. +func (s *ServerStatistics) Blacklisted() bool { + return s.blacklisted.Load() +} + +// SuccessCount returns the number of successful requests. This is +// usually useful in constructing transaction IDs. +func (s *ServerStatistics) SuccessCount() uint32 { + return s.successCounter.Load() +} diff --git a/federationsender/types/types.go b/federationsender/types/types.go index 05ba92f77..398d32677 100644 --- a/federationsender/types/types.go +++ b/federationsender/types/types.go @@ -28,6 +28,12 @@ type JoinedHost struct { ServerName gomatrixserverlib.ServerName } +type ServerNames []gomatrixserverlib.ServerName + +func (s ServerNames) Len() int { return len(s) } +func (s ServerNames) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s ServerNames) Less(i, j int) bool { return s[i] < s[j] } + // A EventIDMismatchError indicates that we have got out of sync with the // room server. type EventIDMismatchError struct { diff --git a/go.mod b/go.mod index 8d91902d1..6bfce8441 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,11 @@ module github.com/matrix-org/dendrite require ( + github.com/Shopify/sarama v1.26.1 + github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd // indirect + github.com/gologme/log v1.2.0 github.com/gorilla/mux v1.7.3 github.com/hashicorp/golang-lru v0.5.4 - github.com/kr/pretty v0.2.0 // indirect github.com/lib/pq v1.2.0 github.com/libp2p/go-libp2p v0.6.0 github.com/libp2p/go-libp2p-circuit v0.1.4 @@ -13,30 +15,29 @@ require ( github.com/libp2p/go-libp2p-kad-dht v0.5.0 github.com/libp2p/go-libp2p-pubsub v0.2.5 github.com/libp2p/go-libp2p-record v0.1.2 + github.com/libp2p/go-yamux v1.3.7 github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5 - github.com/matrix-org/go-http-js-libp2p v0.0.0-20200318135427-31631a9ef51f - github.com/matrix-org/go-sqlite3-js v0.0.0-20200325174927-327088cdef10 + github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 + github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 - github.com/matrix-org/gomatrixserverlib v0.0.0-20200421090225-4ea81b29f5f7 - github.com/matrix-org/naffka v0.0.0-20200127221512-0716baaabaf1 + github.com/matrix-org/gomatrixserverlib v0.0.0-20200623103809-13ff8109e137 + github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 - github.com/mattn/go-sqlite3 v2.0.3+incompatible + github.com/mattn/go-sqlite3 v2.0.2+incompatible github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5 github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6 github.com/opentracing/opentracing-go v1.1.0 - github.com/pierrec/lz4 v2.5.0+incompatible // indirect github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.4.1 - github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 // indirect - github.com/sirupsen/logrus v1.4.2 - github.com/tidwall/gjson v1.6.0 // indirect - github.com/tidwall/pretty v1.0.1 // indirect - github.com/uber/jaeger-client-go v2.22.1+incompatible - github.com/uber/jaeger-lib v2.2.0+incompatible - go.uber.org/atomic v1.6.0 + github.com/sirupsen/logrus v1.6.0 + github.com/tidwall/gjson v1.6.0 + github.com/tidwall/sjson v1.0.3 + github.com/uber-go/atomic v1.3.0 // indirect + github.com/uber/jaeger-client-go v2.15.0+incompatible + github.com/uber/jaeger-lib v1.5.0 + github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200530233943-aec82d7a391b + go.uber.org/atomic v1.4.0 golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d - golang.org/x/tools v0.0.0-20200402223321-bcf690261a44 // indirect - gopkg.in/Shopify/sarama.v1 v1.20.1 gopkg.in/h2non/bimg.v1 v1.0.18 gopkg.in/yaml.v2 v2.2.8 ) diff --git a/go.sum b/go.sum index e76f6d007..6178f152b 100644 --- a/go.sum +++ b/go.sum @@ -1,13 +1,15 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/AndreasBriese/bbloom v0.0.0-20180913140656-343706a395b7/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= github.com/AndreasBriese/bbloom v0.0.0-20190306092124-e2d15f34fcf9/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= -github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= +github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0 h1:p3puK8Sl2xK+2FnnIvY/C0N1aqJo2kbEsdAzU+Tnv48= +github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0/go.mod h1:6Lkn+/zJilRMsKmbmG1RPoamiArC6HS73xbwRyp3UyI= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/DataDog/zstd v1.4.4 h1:+IawcoXhCBylN7ccwdwf8LOH2jKq7NavGpEPanrlTzE= -github.com/DataDog/zstd v1.4.4/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo= github.com/Kubuxu/go-os-helper v0.0.1/go.mod h1:N8B+I7vPCT80IcP58r50u4+gEEcsZETFUpAzWW2ep1Y= +github.com/Shopify/sarama v1.26.1 h1:3jnfWKD7gVwbB1KSy/lE0szA9duPuSFLViK0o/d3DgA= +github.com/Shopify/sarama v1.26.1/go.mod h1:NbSGBSSndYaIhRcBtY9V0U7AyH+x71bG668AuWys/yU= github.com/Shopify/toxiproxy v2.1.4+incompatible h1:TKdv8HiTLgE5wdJuEML90aBgNWsokNbMijUGhmcoBJc= github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= +github.com/VividCortex/ewma v1.1.1/go.mod h1:2Tkkvm3sRDVXaiyucHiACn4cqf7DpdyLvmxzcbUokwA= github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafoB+tBA3gMyHYHrpOtNuDiK/uB5uXxq5wM= @@ -28,16 +30,14 @@ github.com/btcsuite/btcd v0.20.1-beta/go.mod h1:wVuoA8VJLEcwgqHBwHmzLRazpKxTv13P github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f/go.mod h1:TdznJufoqS23FtqVCzL0ZqgP5MqXbb4fg/WgDys70nA= github.com/btcsuite/btcutil v0.0.0-20190207003914-4c204d697803/go.mod h1:+5NJ2+qvTyV9exUAL/rxXi3DcLg2Ts+ymUAY5y4NvMg= github.com/btcsuite/btcutil v0.0.0-20190425235716-9e5f4b9a998d/go.mod h1:+5NJ2+qvTyV9exUAL/rxXi3DcLg2Ts+ymUAY5y4NvMg= -github.com/btcsuite/btcutil v1.0.1 h1:GKOz8BnRjYrb/JTKgaOk+zh26NWNdSNvdvv0xoAZMSA= -github.com/btcsuite/btcutil v1.0.1/go.mod h1:j9HUFwoQRsZL3V4n+qG+CUnEGHOarIxfC3Le2Yhbcts= github.com/btcsuite/go-socks v0.0.0-20170105172521-4720035b7bfd/go.mod h1:HHNXQzUsZCxOoE+CPiyCTO6x34Zs86zZUiwtpXoGdtg= github.com/btcsuite/goleveldb v0.0.0-20160330041536-7834afc9e8cd/go.mod h1:F+uVaaLLH7j4eDXPRvw78tMflu7Ie2bzYOH4Y8rRKBY= github.com/btcsuite/snappy-go v0.0.0-20151229074030-0bdef8d06723/go.mod h1:8woku9dyThutzjeg+3xrA5iCpBRH8XEEg3lh6TiUghc= github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792/go.mod h1:ghJtEyQwv5/p4Mg4C0fgbePVuGr935/5ddU9Z3TmDRY= github.com/btcsuite/winsvc v1.0.0/go.mod h1:jsenWakMcC0zFBFurPLEAyrnc/teJEM1O46fmI40EZs= -github.com/cespare/xxhash/v2 v2.1.0/go.mod h1:dgIUBU3pDso/gPgZ1osOZ0iQf77oPR28Tjxl5dIMyVM= github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cheggaaa/pb/v3 v3.0.4/go.mod h1:7rgWxLrAUcFMkvJuv09+DYi7mMUYi8nO9iOWcvGJPfw= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd h1:qMd81Ts1T2OTKmB4acZcyKaMtRnY5Y44NuXGX2GFJ1w= github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd/go.mod h1:sE/e/2PUdi/liOCUjSTXgM1o87ZssimdTWN964YiIeI= @@ -61,19 +61,19 @@ github.com/dgryski/go-farm v0.0.0-20190104051053-3adb47b1fb0f/go.mod h1:SqUrOPUn github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.2.0 h1:v7g92e/KSN71Rq7vSThKaWIq68fL4YHvWyiUKorFR1Q= -github.com/eapache/go-resiliency v1.2.0 h1:v7g92e/KSN71Rq7vSThKaWIq68fL4YHvWyiUKorFR1Q= -github.com/eapache/go-resiliency v1.2.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-resiliency v1.2.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 h1:YEetp8/yCZMuEPMUDHG0CW/brkkEp8mzqk2+ODEitlw= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc= github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= +github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= +github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= github.com/frankban/quicktest v1.0.0/go.mod h1:R98jIehRai+d1/3Hv2//jOVCTJhW1VBavT6B6CuGq2k= github.com/frankban/quicktest v1.7.2 h1:2QxQoC1TS09S7fhCPsrvqYdvP1H5M1P1ih5ABm3BTYk= github.com/frankban/quicktest v1.7.2/go.mod h1:jaStnuzAqU1AJdCO0l53JDCJrVDKcS03DbaAcR7Ks/o= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI= github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= @@ -97,6 +97,9 @@ github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/gologme/log v0.0.0-20181207131047-4e5d8ccb38e8/go.mod h1:gq31gQ8wEHkR+WekdWsqDuf8pXTUZA9BnnzTuPz1Y9U= +github.com/gologme/log v1.2.0 h1:Ya5Ip/KD6FX7uH0S31QO87nCCSucKtF44TLbTtO7V4c= +github.com/gologme/log v1.2.0/go.mod h1:gq31gQ8wEHkR+WekdWsqDuf8pXTUZA9BnnzTuPz1Y9U= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg= @@ -119,12 +122,16 @@ github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/U github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.0.0 h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uPribsnS6o= github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= +github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4= +github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE= +github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.3/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/hjson/hjson-go v3.0.2-0.20200316202735-d5d0e8b0617d+incompatible/go.mod h1:qsetwF8NlsTsOTwZTApNlTCerV+b2GjYRRcIk4JMFio= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/huin/goupnp v1.0.0 h1:wg75sLpL6DZqwHQN6E1Cfk6mtfzS45z8OV+ic+DtHRo= @@ -140,7 +147,6 @@ github.com/ipfs/go-cid v0.0.5/go.mod h1:plgt+Y5MnOey4vO4UlUazGqdbEXuFYitED67Fexh github.com/ipfs/go-datastore v0.0.1/go.mod h1:d4KVXhMt913cLBEI/PXAy6ko+W7e9AhyAKBGh803qeE= github.com/ipfs/go-datastore v0.1.0/go.mod h1:d4KVXhMt913cLBEI/PXAy6ko+W7e9AhyAKBGh803qeE= github.com/ipfs/go-datastore v0.1.1/go.mod h1:w38XXW9kVFNp57Zj5knbKWM2T+KOZCGDRVNdgPHtbHw= -github.com/ipfs/go-datastore v0.3.1 h1:SS1t869a6cctoSYmZXUk8eL6AzVXgASmKIWFNQkQ1jU= github.com/ipfs/go-datastore v0.3.1/go.mod h1:w38XXW9kVFNp57Zj5knbKWM2T+KOZCGDRVNdgPHtbHw= github.com/ipfs/go-datastore v0.4.0/go.mod h1:SX/xMIKoCszPqp+z9JhPYCmoOoXTvaa13XEbGtsFUhA= github.com/ipfs/go-datastore v0.4.1/go.mod h1:SX/xMIKoCszPqp+z9JhPYCmoOoXTvaa13XEbGtsFUhA= @@ -158,7 +164,6 @@ github.com/ipfs/go-ds-leveldb v0.4.1/go.mod h1:jpbku/YqBSsBc1qgME8BkWS4AxzF2cEu1 github.com/ipfs/go-ipfs-delay v0.0.0-20181109222059-70721b86a9a8/go.mod h1:8SP1YXK1M1kXuc4KJZINY3TQQ03J2rwBG9QfXmbRPrw= github.com/ipfs/go-ipfs-util v0.0.1 h1:Wz9bL2wB2YBJqggkA4dD7oSmqB4cAnpNbGrlHJulv50= github.com/ipfs/go-ipfs-util v0.0.1/go.mod h1:spsl5z8KUnrve+73pOhSVZND1SIxPW5RyBCNzQxlJBc= -github.com/ipfs/go-log v0.0.1 h1:9XTUN/rW64BCG1YhPK9Hoy3q8nr4gOmHHBpgFdfw6Lc= github.com/ipfs/go-log v0.0.1/go.mod h1:kL1d2/hzSpI0thNYjiKfjanbVNU+IIGA/WnNESY9leM= github.com/ipfs/go-log v1.0.2 h1:s19ZwJxH8rPWzypjcDpqPLIyV7BnbLqvpli3iZoqYK0= github.com/ipfs/go-log v1.0.2/go.mod h1:1MNjMxe0u6xvJZgeqbJ8vdo2TKaGwZ1a0Bpza+sr2Sk= @@ -178,25 +183,30 @@ github.com/jbenet/go-temp-err-catcher v0.0.0-20150120210811-aac704a3f4f2/go.mod github.com/jbenet/goprocess v0.0.0-20160826012719-b497e2f366b8/go.mod h1:Ly/wlsjFq/qrU3Rar62tu1gASgGw6chQbSh/XgIIXCY= github.com/jbenet/goprocess v0.1.3 h1:YKyIEECS/XvcfHtBzxtjBBbWK+MbvA6dG8ASiqwvr10= github.com/jbenet/goprocess v0.1.3/go.mod h1:5yspPrukOVuOLORacaBi858NqyClJPQxYZlqdZVfqY4= +github.com/jcmturner/gofork v1.0.0 h1:J7uCkflzTEhUZ64xqKnkDxq3kzc96ajM1Gli5ktUem8= +github.com/jcmturner/gofork v1.0.0/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o= github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jrick/logrotate v1.0.0/go.mod h1:LNinyqDIJnpAur+b8yyulnQw/wDuN1+BYKlTRt3OuAQ= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= -github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/kami-zh/go-capturer v0.0.0-20171211120116-e492ea43421d/go.mod h1:P2viExyCEfeWGU259JnaQ34Inuec4R38JCyBx2edgD0= +github.com/kardianos/minwinsvc v0.0.0-20151122163309-cad6b2b879b0/go.mod h1:rUi0/YffDo1oXBOGn1KRq7Fr07LX48XEBecQnmwjsAo= github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= -github.com/kisielk/gotool v1.0.0 h1:AV2c/EiW3KqPNT9ZKl07ehoAGi4C5/01Cfbblndcapg= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kkdai/bstream v0.0.0-20161212061736-f391b8402d23/go.mod h1:J+Gs4SYgM6CZQHDETBtE9HaSEkGmuNXF86RwHhHUvq4= +github.com/klauspost/compress v1.9.8 h1:VMAMUUOh+gaxKTMk+zqbjsSjsIcUcL/LF4o63i82QyA= +github.com/klauspost/compress v1.9.8/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.3 h1:CE8S1cTafDpPvMhIxNJKvHsGVBgn1xWYf1NbHQhywc8= +github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/koron/go-ssdp v0.0.0-20191105050749-2e1c40ed0b5d h1:68u9r4wEvL3gYg2jvAOgROwZ3H+Y3hIDk4tbbmIjcYQ= github.com/koron/go-ssdp v0.0.0-20191105050749-2e1c40ed0b5d/go.mod h1:5Ky9EC2xfoUKUor0Hjgi2BJhCSXJfMOFlmyYrVKGQMk= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= @@ -344,49 +354,44 @@ github.com/libp2p/go-yamux v1.2.3 h1:xX8A36vpXb59frIzWFdEgptLMsOANMFq2K7fPRlunYI github.com/libp2p/go-yamux v1.2.3/go.mod h1:FGTiPvoV/3DVdgWpX+tM0OW3tsM+W5bSE3gZwqQTcow= github.com/libp2p/go-yamux v1.3.0 h1:FsYzT16Wq2XqUGJsBbOxoz9g+dFklvNi7jN6YFPfl7U= github.com/libp2p/go-yamux v1.3.0/go.mod h1:FGTiPvoV/3DVdgWpX+tM0OW3tsM+W5bSE3gZwqQTcow= +github.com/libp2p/go-yamux v1.3.7 h1:v40A1eSPJDIZwz2AvrV3cxpTZEGDP11QJbukmEhYyQI= +github.com/libp2p/go-yamux v1.3.7/go.mod h1:fr7aVgmdNGJK+N1g+b6DW6VxzbRCjCOejR/hkmpooHE= +github.com/lxn/walk v0.0.0-20191128110447-55ccb3a9f5c1/go.mod h1:E23UucZGqpuUANJooIbHWCufXvOcT6E7Stq81gU+CSQ= +github.com/lxn/win v0.0.0-20191128105842-2da648fda5b4/go.mod h1:ouWl4wViUNh8tPSIwxTVMuS014WakR1hqvBc2I0bMoA= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/mailru/easyjson v0.0.0-20180823135443-60711f1a8329/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.0.0-20180823135443-60711f1a8329/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= -github.com/matrix-org/dendrite v0.0.0-20200220135450-0352f250b857/go.mod h1:DZ35IoR+ViBNVPe9umdlOSnjvKl7wfyRmZg4QfWGvTo= github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5 h1:nMX2t7hbGF0NYDYySx0pCqEKGKAeZIiSqlWSspetlhY= github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5 h1:nMX2t7hbGF0NYDYySx0pCqEKGKAeZIiSqlWSspetlhY= github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5/go.mod h1:NgPCr+UavRGH6n5jmdX8DuqFZ4JiCWIJoZiuhTRLSUg= github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5/go.mod h1:NgPCr+UavRGH6n5jmdX8DuqFZ4JiCWIJoZiuhTRLSUg= -github.com/matrix-org/go-http-js-libp2p v0.0.0-20200318135427-31631a9ef51f h1:5TOte9uk/epk8L+Pbp6qwaV8YsKYXKjyECPHUhJTWQc= -github.com/matrix-org/go-http-js-libp2p v0.0.0-20200318135427-31631a9ef51f/go.mod h1:qK3LUW7RCLhFM7gC3pabj3EXT9A1DsCK33MHstUhhbk= -github.com/matrix-org/go-sqlite3-js v0.0.0-20200304164012-aa524245b658 h1:UlhTKClOgWnSB25Rv+BS/Vc1mRinjNUErfyGEVOBP04= -github.com/matrix-org/go-sqlite3-js v0.0.0-20200304164012-aa524245b658/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= -github.com/matrix-org/go-sqlite3-js v0.0.0-20200325174927-327088cdef10 h1:SnhC7/o87ueVwEWI3mUYtrs+s8VnYq3KZtpWsFQOLFE= -github.com/matrix-org/go-sqlite3-js v0.0.0-20200325174927-327088cdef10/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= -github.com/matrix-org/gomatrix v0.0.0-20190130130140-385f072fe9af h1:piaIBNQGIHnni27xRB7VKkEwoWCgAmeuYf8pxAyG0bI= -github.com/matrix-org/gomatrix v0.0.0-20190130130140-385f072fe9af/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= +github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 h1:eqE5OnGx9ZMWmrRbD3KF/3KtTunw0iQulI7YxOIdxo4= +github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4/go.mod h1:3WluEZ9QXSwU30tWYqktnpC1x9mwZKx1r8uAv8Iq+a4= +github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 h1:Yb+Wlf/iHhWlLWd+kCgG+Fsg4Dc+xBl7hptfK7lD0zY= +github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bhrnp3Ky1qgx/fzCtCALOoGYylh2tpS9K4= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= -github.com/matrix-org/gomatrixserverlib v0.0.0-20200124100636-0c2ec91d1df5 h1:kmRjpmFOenVpOaV/DRlo9p6z/IbOKlUC+hhKsAAh8Qg= -github.com/matrix-org/gomatrixserverlib v0.0.0-20200124100636-0c2ec91d1df5/go.mod h1:FsKa2pWE/bpQql9H7U4boOPXFoJX/QcqaZZ6ijLkaZI= -github.com/matrix-org/gomatrixserverlib v0.0.0-20200421090225-4ea81b29f5f7 h1:4vE84tE3r7BitCt2HQvT231JrhMjDfjDVDqVoiVPv0w= -github.com/matrix-org/gomatrixserverlib v0.0.0-20200421090225-4ea81b29f5f7/go.mod h1:FsKa2pWE/bpQql9H7U4boOPXFoJX/QcqaZZ6ijLkaZI= -github.com/matrix-org/naffka v0.0.0-20200127221512-0716baaabaf1 h1:osLoFdOy+ChQqVUn2PeTDETFftVkl4w9t/OW18g3lnk= -github.com/matrix-org/naffka v0.0.0-20200127221512-0716baaabaf1/go.mod h1:cXoYQIENbdWIQHt1SyCo6Bl3C3raHwJ0wgVrXHSqf+A= -github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5 h1:W7l5CP4V7wPyPb4tYE11dbmeAOwtFQBTW0rf4OonOS8= -github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5/go.mod h1:lePuOiXLNDott7NZfnQvJk0lAZ5HgvIuWGhel6J+RLA= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200623103809-13ff8109e137 h1:+eBh4L04+08IslvFM071TNrQTggU317GsQKzZ1SGEVo= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200623103809-13ff8109e137/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= +github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f h1:pRz4VTiRCO4zPlEMc3ESdUOcW4PXHH4Kj+YDz1XyE+Y= +github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f/go.mod h1:y0oDTjZDv5SM9a2rp3bl+CU+bvTRINQsdb7YlDql5Go= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.1 h1:G1f5SKeVxmagw/IyvzvtZE4Gybcc4Tr1tf7I8z0XgOg= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= +github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.5 h1:tHXDdz1cpzGaovsTB+TVB8q90WEokoVmfMqoVcrLUgw= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84= +github.com/mattn/go-runewidth v0.0.7/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-sqlite3 v2.0.2+incompatible h1:qzw9c2GNT8UFrgWNDhCTqRqYUSmu/Dav/9Z58LGpk7U= github.com/mattn/go-sqlite3 v2.0.2+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= -github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= -github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= -github.com/miekg/dns v1.1.4 h1:rCMZsU2ScVSYcAsOXgmC6+AKOK+6pmQTOcw03nfwYV0= -github.com/miekg/dns v1.1.4/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/miekg/dns v1.1.12 h1:WMhc1ik4LNkTg8U9l3hI1LvxKmIL+f1+WV/SZtCbDDA= github.com/miekg/dns v1.1.12/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1 h1:lYpkrQH5ajf0OXOcUbGjvZxxijuBwbbmlSxLiuofa+g= @@ -475,8 +480,6 @@ github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFSt github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pierrec/lz4 v2.4.1+incompatible h1:mFe7ttWaflA46Mhqh+jUfjp2qTbPYxLB2/OyBppH9dg= github.com/pierrec/lz4 v2.4.1+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= -github.com/pierrec/lz4 v2.5.0+incompatible h1:MbdIZ43A//duwOjQqK3nP+up+65yraNFyX3Vp6Rwues= -github.com/pierrec/lz4 v2.5.0+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -486,35 +489,27 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= -github.com/prometheus/client_golang v1.2.1/go.mod h1:XMU6Z2MjaRKVu/dC1qupJI9SiNkDYzz3xecMgSW/F+U= github.com/prometheus/client_golang v1.4.1 h1:FFSuS004yOQEtDdTq+TAOLP5xUq63KqAFYyOi8zA+Y8= github.com/prometheus/client_golang v1.4.1/go.mod h1:e9GMxYsXl05ICDXkRhurwBS4Q3OK1iX/F2sw+iXX5zU= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0 h1:uq5h0d+GuxiXLJLNABMgp2qUWDPiLvgCzz2dUR+/W/M= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= -github.com/prometheus/common v0.7.0/go.mod h1:DjGbpBbp5NYNiECxcL/VnbXCCaQpKd3tt26CguLLsqA= github.com/prometheus/common v0.9.1 h1:KOMtN28tlbam3/7ZKEYKHhKoJZYYj3gMH4uc62x7X7U= github.com/prometheus/common v0.9.1/go.mod h1:yhUN8i9wzaXS3w1O07YhxHEBxD+W35wd8bs7vj7HSQ4= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= -github.com/prometheus/procfs v0.0.5/go.mod h1:4A/X28fw3Fc593LaREMrKMqOKvUAntwMDaekg4FpcdQ= github.com/prometheus/procfs v0.0.8 h1:+fpWZdT24pJBiqJdAwYBjPSk+5YmQzYNPYzQsdzLkt8= github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563 h1:dY6ETXrvDG7Sa4vE8ZQG4yqWg6UnOcbqTAahkV813vQ= github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= -github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 h1:MkV+77GLUNo5oJ0jf870itWm3D0Sjh7+Za9gazKc5LQ= -github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 h1:MkV+77GLUNo5oJ0jf870itWm3D0Sjh7+Za9gazKc5LQ= -github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= -github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= -github.com/sirupsen/logrus v1.3.0 h1:hI/7Q+DtNZ2kINb6qt/lS+IyXnHQe9e90POfeewL/ME= -github.com/sirupsen/logrus v1.3.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I= +github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/smola/gocompat v0.2.0/go.mod h1:1B0MlxbmoZNo3h8guHp8HztB3BSYR5itql9qtVc0ypY= github.com/spacemonkeygo/openssl v0.0.0-20181017203307-c2dcc5cca94a/go.mod h1:7AyxJNCJ7SBZ1MfVQCWD6Uqo2oubI2Eq2y2eqf+A5r0= github.com/spacemonkeygo/spacelog v0.0.0-20180420211403-2296661a0572 h1:RC6RW7j+1+HkWaX/Yh71Ee5ZHaHYt7ZP4sQgUrm6cDU= @@ -535,8 +530,6 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= -github.com/tidwall/gjson v1.1.5 h1:QysILxBeUEY3GTLA0fQVgkQG1zme8NxGvhh2SSqWNwI= -github.com/tidwall/gjson v1.1.5/go.mod h1:c/nTNbUr0E0OrXEhq1pwa8iEgc2DOt4ZZqAt1HtCkPA= github.com/tidwall/gjson v1.6.0 h1:9VEQWz6LLMUsUl6PueE49ir4Ka6CzLymOAZDxpFsTDc= github.com/tidwall/gjson v1.6.0/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls= github.com/tidwall/match v1.0.1 h1:PnKP62LPNxHKTwvHHZZzdOAOCtsJTjo6dZLCwpKm5xc= @@ -550,17 +543,14 @@ github.com/uber-go/atomic v1.3.0 h1:ylWoWcs+jXihgo3Us1Sdsatf2R6+OlBGm8fexR3oFG4= github.com/uber-go/atomic v1.3.0/go.mod h1:/Ct5t2lcmbJ4OSe/waGBoaVvVqtO0bmtfVNex1PFV8g= github.com/uber/jaeger-client-go v2.15.0+incompatible h1:NP3qsSqNxh8VYr956ur1N/1C1PjvOJnJykCzcD5QHbk= github.com/uber/jaeger-client-go v2.15.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk= -github.com/uber/jaeger-client-go v2.22.1+incompatible h1:NHcubEkVbahf9t3p75TOCR83gdUHXjRJvjoBh1yACsM= -github.com/uber/jaeger-client-go v2.22.1+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk= github.com/uber/jaeger-lib v1.5.0 h1:OHbgr8l656Ub3Fw5k9SWnBfIEwvoHQ+W2y+Aa9D1Uyo= github.com/uber/jaeger-lib v1.5.0/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U= -github.com/uber/jaeger-lib v2.2.0+incompatible h1:MxZXOiR2JuoANZ3J6DE/U0kSFv/eJ/GfSYVCjK7dyaw= -github.com/uber/jaeger-lib v2.2.0+incompatible/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= +github.com/vishvananda/netlink v1.0.0/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk= +github.com/vishvananda/netns v0.0.0-20190625233234-7109fa855b0f/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI= github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1 h1:EKhdznlJHPMoKr0XTrX+IlJs1LH3lyx2nfr1dOlZ79k= github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1/go.mod h1:8UvriyWtv5Q5EOgjHaSseUEdkQfvwFv1I/In/O2M9gc= github.com/whyrusleeping/go-logging v0.0.0-20170515211332-0457bb6b88fc/go.mod h1:bopw91TMyo8J3tvftk8xmU2kPmlrt4nScJQZU2hE5EM= -github.com/whyrusleeping/go-logging v0.0.1 h1:fwpzlmT0kRC/Fmd0MdmGgJG/CXIZ6gFq46FQZjprUcc= github.com/whyrusleeping/go-logging v0.0.1/go.mod h1:lDPYj54zutzG1XYfHAhcc7oNXEburHQBn+Iqd4yS4vE= github.com/whyrusleeping/mafmt v1.2.8 h1:TCghSl5kkwEE0j+sU/gudyhVMRlpBin8fMBBHg59EbA= github.com/whyrusleeping/mafmt v1.2.8/go.mod h1:faQJFPbLSxzD9xpA02ttW/tS9vZykNvXwGvqIpk20FA= @@ -571,18 +561,19 @@ github.com/whyrusleeping/multiaddr-filter v0.0.0-20160516205228-e903e4adabd7/go. github.com/whyrusleeping/timecache v0.0.0-20160911033111-cfcb2f1abfee h1:lYbXeSvJi5zk5GLKVuid9TVjS9a0OmLIDKTfoZBL6Ow= github.com/whyrusleeping/timecache v0.0.0-20160911033111-cfcb2f1abfee/go.mod h1:m2aV4LZI4Aez7dP5PMyVKEHhUyEJ/RjmPEDOpDvudHg= github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE= +github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I= +github.com/xdg/stringprep v1.0.0/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= -github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yggdrasil-network/yggdrasil-extras v0.0.0-20200525205615-6c8a4a2e8855/go.mod h1:xQdsh08Io6nV4WRnOVTe6gI8/2iTvfLDQ0CYa5aMt+I= +github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200530233943-aec82d7a391b h1:ELOisSxFXCcptRs4LFub+Hz5fYUvV12wZrTps99Eb3E= +github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200530233943-aec82d7a391b/go.mod h1:d+Nz6SPeG6kmeSPFL0cvfWfgwEql75fUnZiAONgvyBE= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.1/go.mod h1:Ap50jQcDJrx6rB6VgeeFPtuPIf3wMRvRfrfYDO6+BmA= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3 h1:8sGtKOrtQqkN1bp2AtX+misvLIlOmsEsNd+9NIcPEm8= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.uber.org/atomic v1.3.0 h1:vs7fgriifsPbGdK3bNuMWapNn3qnZhCRXc19NRdq010= -go.uber.org/atomic v1.3.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.4.0 h1:cxzIVoETapQEqDhQu3QfnvXAV4AlzcvUCxkVUFw3+EU= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= -go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk= -go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/multierr v1.1.0 h1:HoEmRHQPVSqub6w2z2d2EOVs2fjyFRGyofhKuyDq0QI= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/zap v1.10.0 h1:ORx85nbTijNz8ljznvCMR1ZBIPKFn3jQrag10X2AsuM= @@ -591,8 +582,6 @@ golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnf golang.org/x/crypto v0.0.0-20180723164146-c126467f60eb/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20190131182504-b8fe1690c613 h1:MQ/ZZiDsUapFFiMS+vzwXkCTeEKaum+Do5rINYJDmxc= -golang.org/x/crypto v0.0.0-20190131182504-b8fe1690c613/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190225124518-7f87c0fbb88b/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= @@ -601,20 +590,16 @@ golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190618222545-ea8f1a30c443/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200115085410-6d4e4cb37c7d h1:2+ZP7EfsZV7Vvmx3TIqSlSzATMkTAKqM14YGFPoSKjI= -golang.org/x/crypto v0.0.0-20200115085410-6d4e4cb37c7d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200204104054-c9f3fb736b72/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d h1:1ZiEyfaQIg3Qh0EoqpwAakHVhecoE5wlSg5GjnafJGw= golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= -golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/mod v0.2.0 h1:KU7oHjnv3XNWfa5COkzUifxZmxp1TyI7ImMXqFxLwvQ= -golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -622,15 +607,17 @@ golang.org/x/net v0.0.0-20181011144130-49bb7cea24b1/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190227160552-c95aed5357e7/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190301231341-16b79f2e4e95 h1:fY7Dsw114eJN4boqzVSbpVHO6rTdhq6/GnXeu+PKnzU= -golang.org/x/net v0.0.0-20190301231341-16b79f2e4e95/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859 h1:R/3boaszxrf1GEUWTVDzSKVwLmSJpwZ1yqXm8j0v2QI= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20191003171128-d98b1b443823/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b h1:0mm1VjtFUOIlE1SbDlwjYaDxZVDP2S5ou6y0gSgXHu8= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200301022130-244492dfa37a h1:GuSPYbZzB5/dcLNCwLQLsg3obCJtX9IJhpXkvY7kzk0= +golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -655,33 +642,40 @@ golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/p golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190922100055-0a153f010e69/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191003212358-c178f38b412c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200122134326-e047566fdf82 h1:ywK/j/KkyTHcdyYSZNXGjMwgmDSfjglYZ3vStQ/gSCU= golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae h1:/WDfKMnPU+m5M4xB+6x4kaepxRw6jWvR5iDRdvjHgy8= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200301040627-c5d0d7b4ec88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527 h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So= +golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3-0.20191230102452-929e72ca90de h1:aYKJLPSrddB2N7/6OKyFqJ337SXpo61bBuvO5p1+7iY= +golang.org/x/text v0.3.3-0.20191230102452-929e72ca90de/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181130052023-1c3d964395ce/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd h1:/e+gpKk9r3dJobndpTytxS2gOy6m5uvpg+ISQoEcusQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c h1:IGkKhmfzcztjm6gYkykvu/NiS8kaqbCWAEWWAyf8J5U= -golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20200402223321-bcf690261a44 h1:bMm0eoDiGkM5VfIyKjxDvoflW5GLp7+VCo+60n8F+TE= -golang.org/x/tools v0.0.0-20200402223321-bcf690261a44/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.zx2c4.com/wireguard v0.0.20200122-0.20200214175355-9cbcff10dd3e/go.mod h1:P2HsVp8SKwZEufsnezXZA4GRX/T49/HlU7DGuelXsU4= +golang.zx2c4.com/wireguard v0.0.20200320/go.mod h1:lDian4Sw4poJ04SgHh35nzMVwGSYlPumkdnHcucAQoY= +golang.zx2c4.com/wireguard/windows v0.1.0/go.mod h1:EK7CxrFnicmYJ0ZCF6crBh2/EMMeSxMlqgLlwN0Kv9s= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= @@ -690,8 +684,6 @@ google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRn google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= -gopkg.in/Shopify/sarama.v1 v1.20.1 h1:Gi09A3fJXm0Jgt8kuKZ8YK+r60GfYn7MQuEmI3oq6hE= -gopkg.in/Shopify/sarama.v1 v1.20.1/go.mod h1:AxnvoaevB2nBjNK17cG61A3LleFcWFwVBHBt+cot4Oc= gopkg.in/alecthomas/kingpin.v2 v2.2.6 h1:jMFz6MfLP0/4fUyZle81rXUoxOBFi19VUFKVDOQfozc= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= @@ -707,6 +699,16 @@ gopkg.in/h2non/bimg.v1 v1.0.18 h1:qn6/RpBHt+7WQqoBcK+aF2puc6nC78eZj5LexxoalT4= gopkg.in/h2non/bimg.v1 v1.0.18/go.mod h1:PgsZL7dLwUbsGm1NYps320GxGgvQNTnecMCZqxV11So= gopkg.in/h2non/gock.v1 v1.0.14 h1:fTeu9fcUvSnLNacYvYI54h+1/XEteDyHvrVCZEEEYNM= gopkg.in/h2non/gock.v1 v1.0.14/go.mod h1:sX4zAkdYX1TRGJ2JY156cFspQn4yRWn6p9EMdODlynE= +gopkg.in/jcmturner/aescts.v1 v1.0.1 h1:cVVZBK2b1zY26haWB4vbBiZrfFQnfbTVrE3xZq6hrEw= +gopkg.in/jcmturner/aescts.v1 v1.0.1/go.mod h1:nsR8qBOg+OucoIW+WMhB3GspUQXq9XorLnQb9XtvcOo= +gopkg.in/jcmturner/dnsutils.v1 v1.0.1 h1:cIuC1OLRGZrld+16ZJvvZxVJeKPsvd5eUIvxfoN5hSM= +gopkg.in/jcmturner/dnsutils.v1 v1.0.1/go.mod h1:m3v+5svpVOhtFAP/wSz+yzh4Mc0Fg7eRhxkJMWSIz9Q= +gopkg.in/jcmturner/goidentity.v3 v3.0.0 h1:1duIyWiTaYvVx3YX2CYtpJbUFd7/UuPYCfgXtQ3VTbI= +gopkg.in/jcmturner/goidentity.v3 v3.0.0/go.mod h1:oG2kH0IvSYNIu80dVAyu/yoefjq1mNfM5bm88whjWx4= +gopkg.in/jcmturner/gokrb5.v7 v7.5.0 h1:a9tsXlIDD9SKxotJMK3niV7rPZAJeX2aD/0yg3qlIrg= +gopkg.in/jcmturner/gokrb5.v7 v7.5.0/go.mod h1:l8VISx+WGYp+Fp7KRbsiUuXTTOnxIc3Tuvyavf11/WM= +gopkg.in/jcmturner/rpc.v1 v1.1.0 h1:QHIUxTX1ISuAv9dD2wJ9HWQVuWDX/Zc0PfeC2tjc4rU= +gopkg.in/jcmturner/rpc.v1 v1.1.0/go.mod h1:YIdkC4XfD6GXbzje11McwsDuOlZQSb9W4vfLvuNnlv8= gopkg.in/macaroon.v2 v2.1.0 h1:HZcsjBCzq9t0eBPMKqTN/uSN6JOm78ZJ2INbqcBQOUI= gopkg.in/macaroon.v2 v2.1.0/go.mod h1:OUb+TQP/OP0WOerC2Jp/3CwhIKyIa9kQjuc7H24e6/o= gopkg.in/src-d/go-cli.v0 v0.0.0-20181105080154-d492247bbc0d/go.mod h1:z+K8VcOYVYcSwSjGebuDL6176A1XskgbtNl64NSg+n8= diff --git a/hooks/install.sh b/hooks/install.sh deleted file mode 100755 index f8aa331ff..000000000 --- a/hooks/install.sh +++ /dev/null @@ -1,5 +0,0 @@ -#! /bin/bash - -DOT_GIT="$(dirname $0)/../.git" - -ln -s "../../hooks/pre-commit" "$DOT_GIT/hooks/pre-commit" \ No newline at end of file diff --git a/hooks/pre-commit b/hooks/pre-commit deleted file mode 100755 index 6f98b813c..000000000 --- a/hooks/pre-commit +++ /dev/null @@ -1,22 +0,0 @@ -#! /bin/bash - -set -eu - -# make the GIT_DIR and GIT_INDEX_FILE absolute, before we change dir -export GIT_DIR=$(readlink -f `git rev-parse --git-dir`) -if [ -n "${GIT_INDEX_FILE:+x}" ]; then - export GIT_INDEX_FILE=$(readlink -f "$GIT_INDEX_FILE") -fi - -# create a temp dir. The `trap` incantation will ensure that it is removed -# again when this script completes. -tmpdir=`mktemp -d` -trap 'rm -rf "$tmpdir"' EXIT -cd "$tmpdir" - -# get a clean copy of the index (ie, what has been `git add`ed), so that we can -# run the checks against what we are about to commit, rather than what is in -# the working copy. -git checkout-index -a - -./scripts/find-lint.sh fast diff --git a/internal/caching/cache_roomversions.go b/internal/caching/cache_roomversions.go new file mode 100644 index 000000000..0b46d3d4b --- /dev/null +++ b/internal/caching/cache_roomversions.go @@ -0,0 +1,30 @@ +package caching + +import "github.com/matrix-org/gomatrixserverlib" + +const ( + RoomVersionCacheName = "room_versions" + RoomVersionCacheMaxEntries = 1024 + RoomVersionCacheMutable = false +) + +// RoomVersionsCache contains the subset of functions needed for +// a room version cache. +type RoomVersionCache interface { + GetRoomVersion(roomID string) (roomVersion gomatrixserverlib.RoomVersion, ok bool) + StoreRoomVersion(roomID string, roomVersion gomatrixserverlib.RoomVersion) +} + +func (c Caches) GetRoomVersion(roomID string) (gomatrixserverlib.RoomVersion, bool) { + val, found := c.RoomVersions.Get(roomID) + if found && val != nil { + if roomVersion, ok := val.(gomatrixserverlib.RoomVersion); ok { + return roomVersion, true + } + } + return "", false +} + +func (c Caches) StoreRoomVersion(roomID string, roomVersion gomatrixserverlib.RoomVersion) { + c.RoomVersions.Set(roomID, roomVersion) +} diff --git a/internal/caching/cache_serverkeys.go b/internal/caching/cache_serverkeys.go new file mode 100644 index 000000000..4697fb4d2 --- /dev/null +++ b/internal/caching/cache_serverkeys.go @@ -0,0 +1,56 @@ +package caching + +import ( + "fmt" + + "github.com/matrix-org/gomatrixserverlib" +) + +const ( + ServerKeyCacheName = "server_key" + ServerKeyCacheMaxEntries = 4096 + ServerKeyCacheMutable = true +) + +// ServerKeyCache contains the subset of functions needed for +// a server key cache. +type ServerKeyCache interface { + // request -> timestamp is emulating gomatrixserverlib.FetchKeys: + // https://github.com/matrix-org/gomatrixserverlib/blob/f69539c86ea55d1e2cc76fd8e944e2d82d30397c/keyring.go#L95 + // The timestamp should be the timestamp of the event that is being + // verified. We will not return keys from the cache that are not valid + // at this timestamp. + GetServerKey(request gomatrixserverlib.PublicKeyLookupRequest, timestamp gomatrixserverlib.Timestamp) (response gomatrixserverlib.PublicKeyLookupResult, ok bool) + + // request -> result is emulating gomatrixserverlib.StoreKeys: + // https://github.com/matrix-org/gomatrixserverlib/blob/f69539c86ea55d1e2cc76fd8e944e2d82d30397c/keyring.go#L112 + StoreServerKey(request gomatrixserverlib.PublicKeyLookupRequest, response gomatrixserverlib.PublicKeyLookupResult) +} + +func (c Caches) GetServerKey( + request gomatrixserverlib.PublicKeyLookupRequest, + timestamp gomatrixserverlib.Timestamp, +) (gomatrixserverlib.PublicKeyLookupResult, bool) { + key := fmt.Sprintf("%s/%s", request.ServerName, request.KeyID) + val, found := c.ServerKeys.Get(key) + if found && val != nil { + if keyLookupResult, ok := val.(gomatrixserverlib.PublicKeyLookupResult); ok { + if !keyLookupResult.WasValidAt(timestamp, true) { + // The key wasn't valid at the requested timestamp so don't + // return it. The caller will have to work out what to do. + c.ServerKeys.Unset(key) + return gomatrixserverlib.PublicKeyLookupResult{}, false + } + return keyLookupResult, true + } + } + return gomatrixserverlib.PublicKeyLookupResult{}, false +} + +func (c Caches) StoreServerKey( + request gomatrixserverlib.PublicKeyLookupRequest, + response gomatrixserverlib.PublicKeyLookupResult, +) { + key := fmt.Sprintf("%s/%s", request.ServerName, request.KeyID) + c.ServerKeys.Set(key, response) +} diff --git a/internal/caching/caches.go b/internal/caching/caches.go new file mode 100644 index 000000000..419623e27 --- /dev/null +++ b/internal/caching/caches.go @@ -0,0 +1,16 @@ +package caching + +// Caches contains a set of references to caches. They may be +// different implementations as long as they satisfy the Cache +// interface. +type Caches struct { + RoomVersions Cache // implements RoomVersionCache + ServerKeys Cache // implements ServerKeyCache +} + +// Cache is the interface that an implementation must satisfy. +type Cache interface { + Get(key string) (value interface{}, ok bool) + Set(key string, value interface{}) + Unset(key string) +} diff --git a/internal/caching/impl_inmemorylru.go b/internal/caching/impl_inmemorylru.go new file mode 100644 index 000000000..7bb791dd8 --- /dev/null +++ b/internal/caching/impl_inmemorylru.go @@ -0,0 +1,84 @@ +package caching + +import ( + "fmt" + + lru "github.com/hashicorp/golang-lru" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { + roomVersions, err := NewInMemoryLRUCachePartition( + RoomVersionCacheName, + RoomVersionCacheMutable, + RoomVersionCacheMaxEntries, + enablePrometheus, + ) + if err != nil { + return nil, err + } + serverKeys, err := NewInMemoryLRUCachePartition( + ServerKeyCacheName, + ServerKeyCacheMutable, + ServerKeyCacheMaxEntries, + enablePrometheus, + ) + if err != nil { + return nil, err + } + return &Caches{ + RoomVersions: roomVersions, + ServerKeys: serverKeys, + }, nil +} + +type InMemoryLRUCachePartition struct { + name string + mutable bool + maxEntries int + lru *lru.Cache +} + +func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, enablePrometheus bool) (*InMemoryLRUCachePartition, error) { + var err error + cache := InMemoryLRUCachePartition{ + name: name, + mutable: mutable, + maxEntries: maxEntries, + } + cache.lru, err = lru.New(maxEntries) + if err != nil { + return nil, err + } + if enablePrometheus { + promauto.NewGaugeFunc(prometheus.GaugeOpts{ + Namespace: "dendrite", + Subsystem: "caching_in_memory_lru", + Name: name, + }, func() float64 { + return float64(cache.lru.Len()) + }) + } + return &cache, nil +} + +func (c *InMemoryLRUCachePartition) Set(key string, value interface{}) { + if !c.mutable { + if peek, ok := c.lru.Peek(key); ok && peek != value { + panic(fmt.Sprintf("invalid use of immutable cache tries to mutate existing value of %q", key)) + } + } + c.lru.Add(key, value) +} + +func (c *InMemoryLRUCachePartition) Unset(key string) { + if !c.mutable { + panic(fmt.Sprintf("invalid use of immutable cache tries to unset value of %q", key)) + } + c.lru.Remove(key) +} + +func (c *InMemoryLRUCachePartition) Get(key string) (value interface{}, ok bool) { + return c.lru.Get(key) +} diff --git a/common/config/appservice.go b/internal/config/appservice.go similarity index 100% rename from common/config/appservice.go rename to internal/config/appservice.go diff --git a/common/config/config.go b/internal/config/config.go similarity index 89% rename from common/config/config.go rename to internal/config/config.go index 6b61fda7c..baa82be23 100644 --- a/common/config/config.go +++ b/internal/config/config.go @@ -152,8 +152,8 @@ type Dendrite struct { OutputClientData Topic `yaml:"output_client_data"` // Topic for eduserver/api.OutputTypingEvent events. OutputTypingEvent Topic `yaml:"output_typing_event"` - // Topic for user updates (profile, presence) - UserUpdates Topic `yaml:"user_updates"` + // Topic for eduserver/api.OutputSendToDeviceEvent events. + OutputSendToDeviceEvent Topic `yaml:"output_send_to_device_event"` } } `yaml:"kafka"` @@ -188,6 +188,12 @@ type Dendrite struct { PublicRoomsAPI DataSource `yaml:"public_rooms_api"` // The Naffka database is used internally by the naffka library, if used. Naffka DataSource `yaml:"naffka,omitempty"` + // Maximum open connections to the DB (0 = use default, negative means unlimited) + MaxOpenConns int `yaml:"max_open_conns"` + // Maximum idle connections to the DB (0 = use default, negative means unlimited) + MaxIdleConns int `yaml:"max_idle_conns"` + // maximum amount of time (in seconds) a connection may be reused (<= 0 means unlimited) + ConnMaxLifetimeSec int `yaml:"conn_max_lifetime"` } `yaml:"database"` // TURN Server Config @@ -217,12 +223,15 @@ type Dendrite struct { MediaAPI Address `yaml:"media_api"` ClientAPI Address `yaml:"client_api"` FederationAPI Address `yaml:"federation_api"` + ServerKeyAPI Address `yaml:"server_key_api"` AppServiceAPI Address `yaml:"appservice_api"` SyncAPI Address `yaml:"sync_api"` + UserAPI Address `yaml:"user_api"` RoomServer Address `yaml:"room_server"` FederationSender Address `yaml:"federation_sender"` PublicRoomsAPI Address `yaml:"public_rooms_api"` EDUServer Address `yaml:"edu_server"` + KeyServer Address `yaml:"key_server"` } `yaml:"bind"` // The addresses for talking to other microservices. @@ -230,12 +239,15 @@ type Dendrite struct { MediaAPI Address `yaml:"media_api"` ClientAPI Address `yaml:"client_api"` FederationAPI Address `yaml:"federation_api"` + ServerKeyAPI Address `yaml:"server_key_api"` AppServiceAPI Address `yaml:"appservice_api"` SyncAPI Address `yaml:"sync_api"` + UserAPI Address `yaml:"user_api"` RoomServer Address `yaml:"room_server"` FederationSender Address `yaml:"federation_sender"` PublicRoomsAPI Address `yaml:"public_rooms_api"` EDUServer Address `yaml:"edu_server"` + KeyServer Address `yaml:"key_server"` } `yaml:"listen"` // The config for tracing the dendrite servers. @@ -256,6 +268,16 @@ type Dendrite struct { // The config for logging informations. Each hook will be added to logrus. Logging []LogrusHook `yaml:"logging"` + // The config for setting a proxy to use for server->server requests + Proxy *struct { + // The protocol for the proxy (http / https / socks5) + Protocol string `yaml:"protocol"` + // The host where the proxy is listening + Host string `yaml:"host"` + // The port on which the proxy is listening + Port uint16 `yaml:"port"` + } `yaml:"proxy"` + // Any information derived from the configuration options for later use. Derived struct { Registration struct { @@ -348,11 +370,9 @@ type LogrusHook struct { // It implements the error interface. type configErrors []string -// Load a yaml config file for a server run as multiple processes. +// Load a yaml config file for a server run as multiple processes or as a monolith. // Checks the config to ensure that it is valid. -// The checks are different if the server is run as a monolithic process instead -// of being split into multiple components -func Load(configPath string) (*Dendrite, error) { +func Load(configPath string, monolith bool) (*Dendrite, error) { configData, err := ioutil.ReadFile(configPath) if err != nil { return nil, err @@ -363,27 +383,7 @@ func Load(configPath string) (*Dendrite, error) { } // Pass the current working directory and ioutil.ReadFile so that they can // be mocked in the tests - monolithic := false - return loadConfig(basePath, configData, ioutil.ReadFile, monolithic) -} - -// LoadMonolithic loads a yaml config file for a server run as a single monolith. -// Checks the config to ensure that it is valid. -// The checks are different if the server is run as a monolithic process instead -// of being split into multiple components -func LoadMonolithic(configPath string) (*Dendrite, error) { - configData, err := ioutil.ReadFile(configPath) - if err != nil { - return nil, err - } - basePath, err := filepath.Abs(".") - if err != nil { - return nil, err - } - // Pass the current working directory and ioutil.ReadFile so that they can - // be mocked in the tests - monolithic := true - return loadConfig(basePath, configData, ioutil.ReadFile, monolithic) + return loadConfig(basePath, configData, ioutil.ReadFile, monolith) } func loadConfig( @@ -484,6 +484,15 @@ func (config *Dendrite) SetDefaults() { defaultMaxFileSizeBytes := FileSizeBytes(10485760) config.Media.MaxFileSizeBytes = &defaultMaxFileSizeBytes } + + if config.Database.MaxIdleConns == 0 { + config.Database.MaxIdleConns = 2 + } + + if config.Database.MaxOpenConns == 0 { + config.Database.MaxOpenConns = 100 + } + } // Error returns a string detailing how many errors were contained within a @@ -582,7 +591,6 @@ func (config *Dendrite) checkKafka(configErrs *configErrors, monolithic bool) { checkNotEmpty(configErrs, "kafka.topics.output_room_event", string(config.Kafka.Topics.OutputRoomEvent)) checkNotEmpty(configErrs, "kafka.topics.output_client_data", string(config.Kafka.Topics.OutputClientData)) checkNotEmpty(configErrs, "kafka.topics.output_typing_event", string(config.Kafka.Topics.OutputTypingEvent)) - checkNotEmpty(configErrs, "kafka.topics.user_updates", string(config.Kafka.Topics.UserUpdates)) } // checkDatabase verifies the parameters database.* are valid. @@ -603,6 +611,8 @@ func (config *Dendrite) checkListen(configErrs *configErrors) { checkNotEmpty(configErrs, "listen.sync_api", string(config.Listen.SyncAPI)) checkNotEmpty(configErrs, "listen.room_server", string(config.Listen.RoomServer)) checkNotEmpty(configErrs, "listen.edu_server", string(config.Listen.EDUServer)) + checkNotEmpty(configErrs, "listen.server_key_api", string(config.Listen.EDUServer)) + checkNotEmpty(configErrs, "listen.user_api", string(config.Listen.UserAPI)) } // checkLogging verifies the parameters logging.* are valid. @@ -716,6 +726,15 @@ func (config *Dendrite) RoomServerURL() string { return "http://" + string(config.Listen.RoomServer) } +// UserAPIURL returns an HTTP URL for where the userapi is listening. +func (config *Dendrite) UserAPIURL() string { + // Hard code the userapi to talk HTTP for now. + // If we support HTTPS we need to think of a practical way to do certificate validation. + // People setting up servers shouldn't need to get a certificate valid for the public + // internet for an internal API. + return "http://" + string(config.Listen.UserAPI) +} + // EDUServerURL returns an HTTP URL for where the EDU server is listening. func (config *Dendrite) EDUServerURL() string { // Hard code the EDU server to talk HTTP for now. @@ -734,6 +753,15 @@ func (config *Dendrite) FederationSenderURL() string { return "http://" + string(config.Listen.FederationSender) } +// FederationSenderURL returns an HTTP URL for where the federation sender is listening. +func (config *Dendrite) ServerKeyAPIURL() string { + // Hard code the server key API server to talk HTTP for now. + // If we support HTTPS we need to think of a practical way to do certificate validation. + // People setting up servers shouldn't need to get a certificate valid for the public + // internet for an internal API. + return "http://" + string(config.Listen.ServerKeyAPI) +} + // SetupTracing configures the opentracing using the supplied configuration. func (config *Dendrite) SetupTracing(serviceName string) (closer io.Closer, err error) { if !config.Tracing.Enabled { @@ -746,6 +774,33 @@ func (config *Dendrite) SetupTracing(serviceName string) (closer io.Closer, err ) } +// MaxIdleConns returns maximum idle connections to the DB +func (config Dendrite) MaxIdleConns() int { + return config.Database.MaxIdleConns +} + +// MaxOpenConns returns maximum open connections to the DB +func (config Dendrite) MaxOpenConns() int { + return config.Database.MaxOpenConns +} + +// ConnMaxLifetime returns maximum amount of time a connection may be reused +func (config Dendrite) ConnMaxLifetime() time.Duration { + return time.Duration(config.Database.ConnMaxLifetimeSec) * time.Second +} + +// DbProperties functions return properties used by database/sql/DB +type DbProperties interface { + MaxIdleConns() int + MaxOpenConns() int + ConnMaxLifetime() time.Duration +} + +// DbProperties returns cfg as a DbProperties interface +func (config Dendrite) DbProperties() DbProperties { + return config +} + // logrusLogger is a small wrapper that implements jaeger.Logger using logrus. type logrusLogger struct { l *logrus.Logger diff --git a/common/config/config_test.go b/internal/config/config_test.go similarity index 99% rename from common/config/config_test.go rename to internal/config/config_test.go index b72f5fad0..9a543e763 100644 --- a/common/config/config_test.go +++ b/internal/config/config_test.go @@ -63,6 +63,7 @@ listen: media_api: "localhost:7774" appservice_api: "localhost:7777" edu_server: "localhost:7778" + user_api: "localhost:7779" logging: - type: "file" level: "info" diff --git a/common/consumers.go b/internal/consumers.go similarity index 93% rename from common/consumers.go rename to internal/consumers.go index f33993494..d7917f235 100644 --- a/common/consumers.go +++ b/internal/consumers.go @@ -12,27 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -package common +package internal import ( "context" "fmt" - sarama "gopkg.in/Shopify/sarama.v1" + "github.com/Shopify/sarama" + "github.com/matrix-org/dendrite/internal/sqlutil" ) -// A PartitionOffset is the offset into a partition of the input log. -type PartitionOffset struct { - // The ID of the partition. - Partition int32 - // The offset into the partition. - Offset int64 -} - // A PartitionStorer has the storage APIs needed by the consumer. type PartitionStorer interface { // PartitionOffsets returns the offsets the consumer has reached for each partition. - PartitionOffsets(ctx context.Context, topic string) ([]PartitionOffset, error) + PartitionOffsets(ctx context.Context, topic string) ([]sqlutil.PartitionOffset, error) // SetPartitionOffset records where the consumer has reached for a partition. SetPartitionOffset(ctx context.Context, topic string, partition int32, offset int64) error } diff --git a/common/eventcontent.go b/internal/eventutil/eventcontent.go similarity index 92% rename from common/eventcontent.go rename to internal/eventutil/eventcontent.go index f3817ba68..873e20a8e 100644 --- a/common/eventcontent.go +++ b/internal/eventutil/eventcontent.go @@ -1,4 +1,4 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package common +package eventutil import "github.com/matrix-org/gomatrixserverlib" @@ -36,6 +36,11 @@ type HistoryVisibilityContent struct { HistoryVisibility string `json:"history_visibility"` } +// CanonicalAlias is the event content for https://matrix.org/docs/spec/client_server/r0.6.0#m-room-canonical-alias +type CanonicalAlias struct { + Alias string `json:"alias"` +} + // InitialPowerLevelsContent returns the initial values for m.room.power_levels on room creation // if they have not been specified. // http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-power-levels diff --git a/common/events.go b/internal/eventutil/events.go similarity index 90% rename from common/events.go rename to internal/eventutil/events.go index 556b7b671..e6c7a4ff7 100644 --- a/common/events.go +++ b/internal/eventutil/events.go @@ -1,4 +1,4 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package common +package eventutil import ( "context" @@ -20,7 +20,7 @@ import ( "fmt" "time" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" @@ -39,13 +39,13 @@ var ErrRoomNoExists = errors.New("Room does not exist") func BuildEvent( ctx context.Context, builder *gomatrixserverlib.EventBuilder, cfg *config.Dendrite, evTime time.Time, - queryAPI api.RoomserverQueryAPI, queryRes *api.QueryLatestEventsAndStateResponse, + rsAPI api.RoomserverInternalAPI, queryRes *api.QueryLatestEventsAndStateResponse, ) (*gomatrixserverlib.Event, error) { if queryRes == nil { queryRes = &api.QueryLatestEventsAndStateResponse{} } - err := AddPrevEventsToEvent(ctx, builder, queryAPI, queryRes) + err := AddPrevEventsToEvent(ctx, builder, rsAPI, queryRes) if err != nil { // This can pass through a ErrRoomNoExists to the caller return nil, err @@ -66,7 +66,7 @@ func BuildEvent( func AddPrevEventsToEvent( ctx context.Context, builder *gomatrixserverlib.EventBuilder, - queryAPI api.RoomserverQueryAPI, queryRes *api.QueryLatestEventsAndStateResponse, + rsAPI api.RoomserverInternalAPI, queryRes *api.QueryLatestEventsAndStateResponse, ) error { eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) if err != nil { @@ -82,8 +82,8 @@ func AddPrevEventsToEvent( RoomID: builder.RoomID, StateToFetch: eventsNeeded.Tuples(), } - if err = queryAPI.QueryLatestEventsAndState(ctx, &queryReq, queryRes); err != nil { - return fmt.Errorf("queryAPI.QueryLatestEventsAndState: %w", err) + if err = rsAPI.QueryLatestEventsAndState(ctx, &queryReq, queryRes); err != nil { + return fmt.Errorf("rsAPI.QueryLatestEventsAndState: %w", err) } if !queryRes.RoomExists { diff --git a/common/types.go b/internal/eventutil/types.go similarity index 96% rename from common/types.go rename to internal/eventutil/types.go index 91765be00..6d119ce6d 100644 --- a/common/types.go +++ b/internal/eventutil/types.go @@ -1,4 +1,4 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package common +package eventutil import ( "errors" diff --git a/common/http/http.go b/internal/httputil/http.go similarity index 52% rename from common/http/http.go rename to internal/httputil/http.go index 3c6475443..9197371aa 100644 --- a/common/http/http.go +++ b/internal/httputil/http.go @@ -1,4 +1,18 @@ -package http +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httputil import ( "bytes" @@ -6,6 +20,8 @@ import ( "encoding/json" "fmt" "net/http" + "net/url" + "strings" opentracing "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/ext" @@ -21,6 +37,14 @@ func PostJSON( return err } + parsedAPIURL, err := url.Parse(apiURL) + if err != nil { + return err + } + + parsedAPIURL.Path = InternalPathPrefix + strings.TrimLeft(parsedAPIURL.Path, "/") + apiURL = parsedAPIURL.String() + req, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewReader(jsonBytes)) if err != nil { return err @@ -48,10 +72,10 @@ func PostJSON( var errorBody struct { Message string `json:"message"` } - if err = json.NewDecoder(res.Body).Decode(&errorBody); err != nil { - return err + if msgerr := json.NewDecoder(res.Body).Decode(&errorBody); msgerr == nil { + return fmt.Errorf("Internal API: %d from %s: %s", res.StatusCode, apiURL, errorBody.Message) } - return fmt.Errorf("api: %d: %s", res.StatusCode, errorBody.Message) + return fmt.Errorf("Internal API: %d from %s", res.StatusCode, apiURL) } return json.NewDecoder(res.Body).Decode(response) } diff --git a/common/httpapi.go b/internal/httputil/httpapi.go similarity index 74% rename from common/httpapi.go rename to internal/httputil/httpapi.go index e5324bd17..d371d1728 100644 --- a/common/httpapi.go +++ b/internal/httputil/httpapi.go @@ -1,17 +1,35 @@ -package common +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httputil import ( + "context" "io" "net/http" "net/http/httptest" "net/http/httputil" "os" "strings" + "sync" "time" + "github.com/gorilla/mux" "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/common/config" + federationsenderAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/config" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" opentracing "github.com/opentracing/opentracing-go" @@ -30,11 +48,11 @@ type BasicAuth struct { // MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request. func MakeAuthAPI( - metricsName string, data auth.Data, - f func(*http.Request, *authtypes.Device) util.JSONResponse, + metricsName string, userAPI userapi.UserInternalAPI, + f func(*http.Request, *userapi.Device) util.JSONResponse, ) http.Handler { h := func(req *http.Request) util.JSONResponse { - device, err := auth.VerifyUserFromRequest(req, data) + device, err := auth.VerifyUserFromRequest(req, userAPI) if err != nil { return *err } @@ -167,8 +185,9 @@ func MakeInternalAPI(metricsName string, f func(*http.Request) util.JSONResponse func MakeFedAPI( metricsName string, serverName gomatrixserverlib.ServerName, - keyRing gomatrixserverlib.KeyRing, - f func(*http.Request, *gomatrixserverlib.FederationRequest) util.JSONResponse, + keyRing gomatrixserverlib.JSONVerifier, + wakeup *FederationWakeups, + f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse, ) http.Handler { h := func(req *http.Request) util.JSONResponse { fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( @@ -177,18 +196,52 @@ func MakeFedAPI( if fedReq == nil { return errResp } - return f(req, fedReq) + go wakeup.Wakeup(req.Context(), fedReq.Origin()) + vars, err := URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + + return f(req, fedReq, vars) } return MakeExternalAPI(metricsName, h) } -// SetupHTTPAPI registers an HTTP API mux under /api and sets up a metrics -// listener. -func SetupHTTPAPI(servMux *http.ServeMux, apiMux http.Handler, cfg *config.Dendrite) { +type FederationWakeups struct { + FsAPI federationsenderAPI.FederationSenderInternalAPI + origins sync.Map +} + +func (f *FederationWakeups) Wakeup(ctx context.Context, origin gomatrixserverlib.ServerName) { + key, keyok := f.origins.Load(origin) + if keyok { + lastTime, ok := key.(time.Time) + if ok && time.Since(lastTime) < time.Minute { + return + } + } + aliveReq := federationsenderAPI.PerformServersAliveRequest{ + Servers: []gomatrixserverlib.ServerName{origin}, + } + aliveRes := federationsenderAPI.PerformServersAliveResponse{} + if err := f.FsAPI.PerformServersAlive(ctx, &aliveReq, &aliveRes); err != nil { + util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ + "origin": origin, + }).Warn("incoming federation request failed to notify server alive") + } else { + f.origins.Store(origin, time.Now()) + } +} + +// SetupHTTPAPI registers an HTTP API mux under /api and sets up a metrics listener +func SetupHTTPAPI(servMux, publicApiMux, internalApiMux *mux.Router, cfg *config.Dendrite, enableHTTPAPIs bool) { if cfg.Metrics.Enabled { servMux.Handle("/metrics", WrapHandlerInBasicAuth(promhttp.Handler(), cfg.Metrics.BasicAuth)) } - servMux.Handle("/api/", http.StripPrefix("/api", apiMux)) + if enableHTTPAPIs { + servMux.Handle(InternalPathPrefix, internalApiMux) + } + servMux.Handle(PublicPathPrefix, WrapHandlerInCORS(publicApiMux)) } // WrapHandlerInBasicAuth adds basic auth to a handler. Only used for /metrics diff --git a/common/httpapi_test.go b/internal/httputil/httpapi_test.go similarity index 75% rename from common/httpapi_test.go rename to internal/httputil/httpapi_test.go index 7de7ce33c..de6ccf9b4 100644 --- a/common/httpapi_test.go +++ b/internal/httputil/httpapi_test.go @@ -1,4 +1,18 @@ -package common +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httputil import ( "net/http" diff --git a/internal/httputil/paths.go b/internal/httputil/paths.go new file mode 100644 index 000000000..728b5a871 --- /dev/null +++ b/internal/httputil/paths.go @@ -0,0 +1,20 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httputil + +const ( + PublicPathPrefix = "/_matrix/" + InternalPathPrefix = "/api/" +) diff --git a/common/routing.go b/internal/httputil/routing.go similarity index 90% rename from common/routing.go rename to internal/httputil/routing.go index 330912cde..0bd3655ec 100644 --- a/common/routing.go +++ b/internal/httputil/routing.go @@ -1,4 +1,4 @@ -// Copyright 2019 The Matrix.org Foundation C.I.C. +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package common +package httputil import ( "net/url" @@ -24,7 +24,7 @@ import ( func URLDecodeMapValues(vmap map[string]string) (map[string]string, error) { decoded := make(map[string]string, len(vmap)) for key, value := range vmap { - decodedVal, err := url.QueryUnescape(value) + decodedVal, err := url.PathUnescape(value) if err != nil { return make(map[string]string), err } diff --git a/common/log.go b/internal/log.go similarity index 91% rename from common/log.go rename to internal/log.go index 11339ada4..fd2b84ab9 100644 --- a/common/log.go +++ b/internal/log.go @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -package common +package internal import ( "context" "fmt" "io" + "net/http" "os" "path" "path/filepath" @@ -26,7 +27,7 @@ import ( "github.com/matrix-org/util" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dugong" "github.com/sirupsen/logrus" ) @@ -79,6 +80,17 @@ func callerPrettyfier(f *runtime.Frame) (string, string) { return funcname, filename } +// SetupPprof starts a pprof listener. We use the DefaultServeMux here because it is +// simplest, and it gives us the freedom to run pprof on a separate port. +func SetupPprof() { + if hostPort := os.Getenv("PPROFLISTEN"); hostPort != "" { + logrus.Warn("Starting pprof on ", hostPort) + go func() { + logrus.WithError(http.ListenAndServe(hostPort, nil)).Error("Failed to setup pprof listener") + }() + } +} + // SetupStdLogging configures the logging format to standard output. Typically, it is called when the config is not yet loaded. func SetupStdLogging() { logrus.SetReportCaller(true) diff --git a/common/basecomponent/base.go b/internal/setup/base.go similarity index 53% rename from common/basecomponent/base.go rename to internal/setup/base.go index 68a77cf99..66424a609 100644 --- a/common/basecomponent/base.go +++ b/internal/setup/base.go @@ -1,4 +1,4 @@ -// Copyright 2017 New Vector Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,36 +12,45 @@ // See the License for the specific language governing permissions and // limitations under the License. -package basecomponent +package setup import ( "database/sql" + "fmt" "io" "net/http" "net/url" "time" - "golang.org/x/crypto/ed25519" - - "github.com/matrix-org/dendrite/common/caching" - "github.com/matrix-org/dendrite/common/keydb" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/naffka" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" + "github.com/Shopify/sarama" "github.com/gorilla/mux" - sarama "gopkg.in/Shopify/sarama.v1" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" - "github.com/matrix-org/dendrite/common/config" + asinthttp "github.com/matrix-org/dendrite/appservice/inthttp" eduServerAPI "github.com/matrix-org/dendrite/eduserver/api" + eduinthttp "github.com/matrix-org/dendrite/eduserver/inthttp" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" + fsinthttp "github.com/matrix-org/dendrite/federationsender/inthttp" + "github.com/matrix-org/dendrite/internal/config" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + rsinthttp "github.com/matrix-org/dendrite/roomserver/inthttp" + serverKeyAPI "github.com/matrix-org/dendrite/serverkeyapi/api" + skinthttp "github.com/matrix-org/dendrite/serverkeyapi/inthttp" + userapi "github.com/matrix-org/dendrite/userapi/api" + userapiinthttp "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/sirupsen/logrus" + + _ "net/http/pprof" ) // BaseDendrite is a base for creating new instances of dendrite. It parses @@ -53,11 +62,14 @@ type BaseDendrite struct { componentName string tracerCloser io.Closer - // APIMux should be used to register new public matrix api endpoints - APIMux *mux.Router + // PublicAPIMux should be used to register new public matrix api endpoints + PublicAPIMux *mux.Router + InternalAPIMux *mux.Router + BaseMux *mux.Router // base router which created public/internal subrouters + UseHTTPAPIs bool httpClient *http.Client Cfg *config.Dendrite - ImmutableCache caching.ImmutableCache + Caches *caching.Caches KafkaConsumer sarama.Consumer KafkaProducer sarama.SyncProducer } @@ -68,9 +80,10 @@ const HTTPClientTimeout = time.Second * 30 // NewBaseDendrite creates a new instance to be used by a component. // The componentName is used for logging purposes, and should be a friendly name // of the compontent running, e.g. "SyncAPI" -func NewBaseDendrite(cfg *config.Dendrite, componentName string) *BaseDendrite { - common.SetupStdLogging() - common.SetupHookLogging(cfg.Logging, componentName) +func NewBaseDendrite(cfg *config.Dendrite, componentName string, useHTTPAPIs bool) *BaseDendrite { + internal.SetupStdLogging() + internal.SetupHookLogging(cfg.Logging, componentName) + internal.SetupPprof() closer, err := cfg.SetupTracing("Dendrite" + componentName) if err != nil { @@ -85,18 +98,42 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string) *BaseDendrite { kafkaConsumer, kafkaProducer = setupKafka(cfg) } - cache, err := caching.NewImmutableInMemoryLRUCache() + cache, err := caching.NewInMemoryLRUCache(true) if err != nil { logrus.WithError(err).Warnf("Failed to create cache") } + client := http.Client{Timeout: HTTPClientTimeout} + if cfg.Proxy != nil { + client.Transport = &http.Transport{Proxy: http.ProxyURL(&url.URL{ + Scheme: cfg.Proxy.Protocol, + Host: fmt.Sprintf("%s:%d", cfg.Proxy.Host, cfg.Proxy.Port), + })} + } + + // Ideally we would only use SkipClean on routes which we know can allow '/' but due to + // https://github.com/gorilla/mux/issues/460 we have to attach this at the top router. + // When used in conjunction with UseEncodedPath() we get the behaviour we want when parsing + // path parameters: + // /foo/bar%2Fbaz == [foo, bar%2Fbaz] (from UseEncodedPath) + // /foo/bar%2F%2Fbaz == [foo, bar%2F%2Fbaz] (from SkipClean) + // In particular, rooms v3 event IDs are not urlsafe and can include '/' and because they + // are randomly generated it results in flakey tests. + // We need to be careful with media APIs if they read from a filesystem to make sure they + // are not inadvertently reading paths without cleaning, else this could introduce a + // directory traversal attack e.g /../../../etc/passwd + httpmux := mux.NewRouter().SkipClean(true) + return &BaseDendrite{ componentName: componentName, + UseHTTPAPIs: useHTTPAPIs, tracerCloser: closer, Cfg: cfg, - ImmutableCache: cache, - APIMux: mux.NewRouter().UseEncodedPath(), - httpClient: &http.Client{Timeout: HTTPClientTimeout}, + Caches: cache, + BaseMux: httpmux, + PublicAPIMux: httpmux.PathPrefix(httputil.PublicPathPrefix).Subrouter().UseEncodedPath(), + InternalAPIMux: httpmux.PathPrefix(httputil.InternalPathPrefix).Subrouter().UseEncodedPath(), + httpClient: &client, KafkaConsumer: kafkaConsumer, KafkaProducer: kafkaProducer, } @@ -107,54 +144,61 @@ func (b *BaseDendrite) Close() error { return b.tracerCloser.Close() } -// CreateHTTPAppServiceAPIs returns the QueryAPI for hitting the appservice -// component over HTTP. -func (b *BaseDendrite) CreateHTTPAppServiceAPIs() appserviceAPI.AppServiceQueryAPI { - a, err := appserviceAPI.NewAppServiceQueryAPIHTTP(b.Cfg.AppServiceURL(), b.httpClient) +// AppserviceHTTPClient returns the AppServiceQueryAPI for hitting the appservice component over HTTP. +func (b *BaseDendrite) AppserviceHTTPClient() appserviceAPI.AppServiceQueryAPI { + a, err := asinthttp.NewAppserviceClient(b.Cfg.AppServiceURL(), b.httpClient) if err != nil { logrus.WithError(err).Panic("CreateHTTPAppServiceAPIs failed") } return a } -// CreateHTTPRoomserverAPIs returns the AliasAPI, InputAPI and QueryAPI for hitting -// the roomserver over HTTP. -func (b *BaseDendrite) CreateHTTPRoomserverAPIs() ( - roomserverAPI.RoomserverAliasAPI, - roomserverAPI.RoomserverInputAPI, - roomserverAPI.RoomserverQueryAPI, -) { - alias, err := roomserverAPI.NewRoomserverAliasAPIHTTP(b.Cfg.RoomServerURL(), b.httpClient) +// RoomserverHTTPClient returns RoomserverInternalAPI for hitting the roomserver over HTTP. +func (b *BaseDendrite) RoomserverHTTPClient() roomserverAPI.RoomserverInternalAPI { + rsAPI, err := rsinthttp.NewRoomserverClient(b.Cfg.RoomServerURL(), b.httpClient, b.Caches) if err != nil { - logrus.WithError(err).Panic("NewRoomserverAliasAPIHTTP failed") + logrus.WithError(err).Panic("RoomserverHTTPClient failed", b.httpClient) } - input, err := roomserverAPI.NewRoomserverInputAPIHTTP(b.Cfg.RoomServerURL(), b.httpClient) - if err != nil { - logrus.WithError(err).Panic("NewRoomserverInputAPIHTTP failed", b.httpClient) - } - query, err := roomserverAPI.NewRoomserverQueryAPIHTTP(b.Cfg.RoomServerURL(), b.httpClient, b.ImmutableCache) - if err != nil { - logrus.WithError(err).Panic("NewRoomserverQueryAPIHTTP failed", b.httpClient) - } - return alias, input, query + return rsAPI } -// CreateHTTPEDUServerAPIs returns eduInputAPI for hitting the EDU -// server over HTTP -func (b *BaseDendrite) CreateHTTPEDUServerAPIs() eduServerAPI.EDUServerInputAPI { - e, err := eduServerAPI.NewEDUServerInputAPIHTTP(b.Cfg.EDUServerURL(), b.httpClient) +// UserAPIClient returns UserInternalAPI for hitting the userapi over HTTP. +func (b *BaseDendrite) UserAPIClient() userapi.UserInternalAPI { + userAPI, err := userapiinthttp.NewUserAPIClient(b.Cfg.UserAPIURL(), b.httpClient) if err != nil { - logrus.WithError(err).Panic("NewEDUServerInputAPIHTTP failed", b.httpClient) + logrus.WithError(err).Panic("UserAPIClient failed", b.httpClient) + } + return userAPI +} + +// EDUServerClient returns EDUServerInputAPI for hitting the EDU server over HTTP +func (b *BaseDendrite) EDUServerClient() eduServerAPI.EDUServerInputAPI { + e, err := eduinthttp.NewEDUServerClient(b.Cfg.EDUServerURL(), b.httpClient) + if err != nil { + logrus.WithError(err).Panic("EDUServerClient failed", b.httpClient) } return e } -// CreateHTTPFederationSenderAPIs returns FederationSenderQueryAPI for hitting +// FederationSenderHTTPClient returns FederationSenderInternalAPI for hitting // the federation sender over HTTP -func (b *BaseDendrite) CreateHTTPFederationSenderAPIs() federationSenderAPI.FederationSenderQueryAPI { - f, err := federationSenderAPI.NewFederationSenderQueryAPIHTTP(b.Cfg.FederationSenderURL(), b.httpClient) +func (b *BaseDendrite) FederationSenderHTTPClient() federationSenderAPI.FederationSenderInternalAPI { + f, err := fsinthttp.NewFederationSenderClient(b.Cfg.FederationSenderURL(), b.httpClient) if err != nil { - logrus.WithError(err).Panic("NewFederationSenderQueryAPIHTTP failed", b.httpClient) + logrus.WithError(err).Panic("FederationSenderHTTPClient failed", b.httpClient) + } + return f +} + +// ServerKeyAPIClient returns ServerKeyInternalAPI for hitting the server key API over HTTP +func (b *BaseDendrite) ServerKeyAPIClient() serverKeyAPI.ServerKeyInternalAPI { + f, err := skinthttp.NewServerKeyClient( + b.Cfg.ServerKeyAPIURL(), + b.httpClient, + b.Caches, + ) + if err != nil { + logrus.WithError(err).Panic("NewServerKeyInternalAPIHTTP failed", b.httpClient) } return f } @@ -162,7 +206,7 @@ func (b *BaseDendrite) CreateHTTPFederationSenderAPIs() federationSenderAPI.Fede // CreateDeviceDB creates a new instance of the device database. Should only be // called once per component. func (b *BaseDendrite) CreateDeviceDB() devices.Database { - db, err := devices.NewDatabase(string(b.Cfg.Database.Device), b.Cfg.Matrix.ServerName) + db, err := devices.NewDatabase(string(b.Cfg.Database.Device), b.Cfg.DbProperties(), b.Cfg.Matrix.ServerName) if err != nil { logrus.WithError(err).Panicf("failed to connect to devices db") } @@ -173,7 +217,7 @@ func (b *BaseDendrite) CreateDeviceDB() devices.Database { // CreateAccountsDB creates a new instance of the accounts database. Should only // be called once per component. func (b *BaseDendrite) CreateAccountsDB() accounts.Database { - db, err := accounts.NewDatabase(string(b.Cfg.Database.Account), b.Cfg.Matrix.ServerName) + db, err := accounts.NewDatabase(string(b.Cfg.Database.Account), b.Cfg.DbProperties(), b.Cfg.Matrix.ServerName) if err != nil { logrus.WithError(err).Panicf("failed to connect to accounts db") } @@ -181,22 +225,6 @@ func (b *BaseDendrite) CreateAccountsDB() accounts.Database { return db } -// CreateKeyDB creates a new instance of the key database. Should only be called -// once per component. -func (b *BaseDendrite) CreateKeyDB() keydb.Database { - db, err := keydb.NewDatabase( - string(b.Cfg.Database.ServerKey), - b.Cfg.Matrix.ServerName, - b.Cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey), - b.Cfg.Matrix.KeyID, - ) - if err != nil { - logrus.WithError(err).Panicf("failed to connect to keys db") - } - - return db -} - // CreateFederationClient creates a new federation client. Should only be called // once per component. func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationClient { @@ -222,7 +250,14 @@ func (b *BaseDendrite) SetupAndServeHTTP(bindaddr string, listenaddr string) { WriteTimeout: HTTPServerTimeout, } - common.SetupHTTPAPI(http.DefaultServeMux, common.WrapHandlerInCORS(b.APIMux), b.Cfg) + httputil.SetupHTTPAPI( + b.BaseMux, + b.PublicAPIMux, + b.InternalAPIMux, + b.Cfg, + b.UseHTTPAPIs, + ) + serv.Handler = b.BaseMux logrus.Infof("Starting %s server on %s", b.componentName, serv.Addr) err := serv.ListenAndServe() @@ -256,7 +291,12 @@ func setupNaffka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) { uri, err := url.Parse(string(cfg.Database.Naffka)) if err != nil || uri.Scheme == "file" { - db, err = sqlutil.Open(common.SQLiteDriverName(), string(cfg.Database.Naffka)) + var cs string + cs, err = sqlutil.ParseFileURI(string(cfg.Database.Naffka)) + if err != nil { + logrus.WithError(err).Panic("Failed to parse naffka database file URI") + } + db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil) if err != nil { logrus.WithError(err).Panic("Failed to open naffka database") } @@ -266,7 +306,7 @@ func setupNaffka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) { logrus.WithError(err).Panic("Failed to setup naffka database") } } else { - db, err = sqlutil.Open("postgres", string(cfg.Database.Naffka)) + db, err = sqlutil.Open("postgres", string(cfg.Database.Naffka), nil) if err != nil { logrus.WithError(err).Panic("Failed to open naffka database") } diff --git a/common/basecomponent/flags.go b/internal/setup/flags.go similarity index 60% rename from common/basecomponent/flags.go rename to internal/setup/flags.go index 6dcb5601a..e4fc58d60 100644 --- a/common/basecomponent/flags.go +++ b/internal/setup/flags.go @@ -1,4 +1,4 @@ -// Copyright 2017 New Vector Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -package basecomponent +package setup import ( "flag" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" "github.com/sirupsen/logrus" ) @@ -25,33 +25,14 @@ import ( var configPath = flag.String("config", "dendrite.yaml", "The path to the config file. For more information, see the config file in this repository.") // ParseFlags parses the commandline flags and uses them to create a config. -// If running as a monolith use `ParseMonolithFlags` instead. -func ParseFlags() *config.Dendrite { +func ParseFlags(monolith bool) *config.Dendrite { flag.Parse() if *configPath == "" { logrus.Fatal("--config must be supplied") } - cfg, err := config.Load(*configPath) - - if err != nil { - logrus.Fatalf("Invalid config file: %s", err) - } - - return cfg -} - -// ParseMonolithFlags parses the commandline flags and uses them to create a -// config. Should only be used if running a monolith. See `ParseFlags`. -func ParseMonolithFlags() *config.Dendrite { - flag.Parse() - - if *configPath == "" { - logrus.Fatal("--config must be supplied") - } - - cfg, err := config.LoadMonolithic(*configPath) + cfg, err := config.Load(*configPath, monolith) if err != nil { logrus.Fatalf("Invalid config file: %s", err) diff --git a/internal/setup/monolith.go b/internal/setup/monolith.go new file mode 100644 index 000000000..24bee9502 --- /dev/null +++ b/internal/setup/monolith.go @@ -0,0 +1,92 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package setup + +import ( + "github.com/Shopify/sarama" + "github.com/gorilla/mux" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" + "github.com/matrix-org/dendrite/clientapi" + eduServerAPI "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/federationapi" + federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/transactions" + "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/mediaapi" + "github.com/matrix-org/dendrite/publicroomsapi" + "github.com/matrix-org/dendrite/publicroomsapi/storage" + "github.com/matrix-org/dendrite/publicroomsapi/types" + roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + serverKeyAPI "github.com/matrix-org/dendrite/serverkeyapi/api" + "github.com/matrix-org/dendrite/syncapi" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" + "github.com/matrix-org/gomatrixserverlib" +) + +// Monolith represents an instantiation of all dependencies required to build +// all components of Dendrite, for use in monolith mode. +type Monolith struct { + Config *config.Dendrite + DeviceDB devices.Database + AccountDB accounts.Database + KeyRing *gomatrixserverlib.KeyRing + Client *gomatrixserverlib.Client + FedClient *gomatrixserverlib.FederationClient + KafkaConsumer sarama.Consumer + KafkaProducer sarama.SyncProducer + + AppserviceAPI appserviceAPI.AppServiceQueryAPI + EDUInternalAPI eduServerAPI.EDUServerInputAPI + FederationSenderAPI federationSenderAPI.FederationSenderInternalAPI + RoomserverAPI roomserverAPI.RoomserverInternalAPI + ServerKeyAPI serverKeyAPI.ServerKeyInternalAPI + UserAPI userapi.UserInternalAPI + + // TODO: can we remove this? It's weird that we are required the database + // yet every other component can do that on its own. libp2p-demo uses a custom + // database though annoyingly. + PublicRoomsDB storage.Database + + // Optional + ExtPublicRoomsProvider types.ExternalPublicRoomsProvider +} + +// AddAllPublicRoutes attaches all public paths to the given router +func (m *Monolith) AddAllPublicRoutes(publicMux *mux.Router) { + clientapi.AddPublicRoutes( + publicMux, m.Config, m.KafkaConsumer, m.KafkaProducer, m.DeviceDB, m.AccountDB, + m.FedClient, m.RoomserverAPI, + m.EDUInternalAPI, m.AppserviceAPI, transactions.New(), + m.FederationSenderAPI, m.UserAPI, + ) + + keyserver.AddPublicRoutes(publicMux, m.Config, m.UserAPI) + federationapi.AddPublicRoutes( + publicMux, m.Config, m.UserAPI, m.FedClient, + m.KeyRing, m.RoomserverAPI, m.FederationSenderAPI, + m.EDUInternalAPI, + ) + mediaapi.AddPublicRoutes(publicMux, m.Config, m.UserAPI, m.Client) + publicroomsapi.AddPublicRoutes( + publicMux, m.Config, m.KafkaConsumer, m.UserAPI, m.PublicRoomsDB, m.RoomserverAPI, m.FedClient, + m.ExtPublicRoomsProvider, + ) + syncapi.AddPublicRoutes( + publicMux, m.KafkaConsumer, m.UserAPI, m.RoomserverAPI, m.FedClient, m.Config, + ) +} diff --git a/common/partition_offset_table.go b/internal/sqlutil/partition_offset_table.go similarity index 88% rename from common/partition_offset_table.go rename to internal/sqlutil/partition_offset_table.go index aa799f8a0..348829025 100644 --- a/common/partition_offset_table.go +++ b/internal/sqlutil/partition_offset_table.go @@ -1,4 +1,4 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,14 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -package common +package sqlutil import ( "context" "database/sql" "strings" + + "github.com/matrix-org/util" ) +// A PartitionOffset is the offset into a partition of the input log. +type PartitionOffset struct { + // The ID of the partition. + Partition int32 + // The offset into the partition. + Offset int64 +} + const partitionOffsetsSchema = ` -- The offsets that the server has processed up to. CREATE TABLE IF NOT EXISTS ${prefix}_partition_offsets ( @@ -90,7 +100,12 @@ func (s *PartitionOffsetStatements) selectPartitionOffsets( if err != nil { return nil, err } - defer CloseAndLogIfError(ctx, rows, "selectPartitionOffsets: rows.close() failed") + defer func() { + err2 := rows.Close() + if err2 != nil { + util.GetLogger(ctx).WithError(err2).Error("selectPartitionOffsets: rows.close() failed") + } + }() var results []PartitionOffset for rows.Next() { var offset PartitionOffset diff --git a/common/postgres.go b/internal/sqlutil/postgres.go similarity index 98% rename from common/postgres.go rename to internal/sqlutil/postgres.go index f8daf5783..41a5508a1 100644 --- a/common/postgres.go +++ b/internal/sqlutil/postgres.go @@ -14,7 +14,7 @@ // +build !wasm -package common +package sqlutil import "github.com/lib/pq" diff --git a/common/postgres_wasm.go b/internal/sqlutil/postgres_wasm.go similarity index 97% rename from common/postgres_wasm.go rename to internal/sqlutil/postgres_wasm.go index dcc07b31d..c45842f0c 100644 --- a/common/postgres_wasm.go +++ b/internal/sqlutil/postgres_wasm.go @@ -14,7 +14,7 @@ // +build wasm -package common +package sqlutil // IsUniqueConstraintViolationErr no-ops for this architecture func IsUniqueConstraintViolationErr(err error) bool { diff --git a/common/sql.go b/internal/sqlutil/sql.go similarity index 57% rename from common/sql.go rename to internal/sqlutil/sql.go index f50a58969..a25a4a5b6 100644 --- a/common/sql.go +++ b/internal/sqlutil/sql.go @@ -1,4 +1,4 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,14 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -package common +package sqlutil import ( "database/sql" + "errors" "fmt" "runtime" + "time" + + "go.uber.org/atomic" ) +// ErrUserExists is returned if a username already exists in the database. +var ErrUserExists = errors.New("Username already exists") + // A Transaction is something that can be committed or rolledback. type Transaction interface { // Commit the transaction @@ -99,3 +106,67 @@ func SQLiteDriverName() string { } return "sqlite3" } + +// DbProperties functions return properties used by database/sql/DB +type DbProperties interface { + MaxIdleConns() int + MaxOpenConns() int + ConnMaxLifetime() time.Duration +} + +// TransactionWriter allows queuing database writes so that you don't +// contend on database locks in, e.g. SQLite. Only one task will run +// at a time on a given TransactionWriter. +type TransactionWriter struct { + running atomic.Bool + todo chan transactionWriterTask +} + +func NewTransactionWriter() *TransactionWriter { + return &TransactionWriter{ + todo: make(chan transactionWriterTask), + } +} + +// transactionWriterTask represents a specific task. +type transactionWriterTask struct { + db *sql.DB + f func(txn *sql.Tx) error + wait chan error +} + +// Do queues a task to be run by a TransactionWriter. The function +// provided will be ran within a transaction as supplied by the +// database parameter. This will block until the task is finished. +func (w *TransactionWriter) Do(db *sql.DB, f func(txn *sql.Tx) error) error { + if w.todo == nil { + return errors.New("not initialised") + } + if !w.running.Load() { + go w.run() + } + task := transactionWriterTask{ + db: db, + f: f, + wait: make(chan error, 1), + } + w.todo <- task + return <-task.wait +} + +// run processes the tasks for a given transaction writer. Only one +// of these goroutines will run at a time. A transaction will be +// opened using the database object from the task and then this will +// be passed as a parameter to the task function. +func (w *TransactionWriter) run() { + if !w.running.CAS(false, true) { + return + } + defer w.running.Store(false) + for task := range w.todo { + task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error { + return task.f(txn) + }) + close(task.wait) + } +} diff --git a/internal/sqlutil/trace.go b/internal/sqlutil/trace.go index 3d5fa7dc7..f6644d591 100644 --- a/internal/sqlutil/trace.go +++ b/internal/sqlutil/trace.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "os" + "regexp" "strings" "time" @@ -76,12 +77,27 @@ func (in *traceInterceptor) RowsNext(c context.Context, rows driver.Rows, dest [ // Open opens a database specified by its database driver name and a driver-specific data source name, // usually consisting of at least a database name and connection information. Includes tracing driver // if DENDRITE_TRACE_SQL=1 -func Open(driverName, dsn string) (*sql.DB, error) { +func Open(driverName, dsn string, dbProperties DbProperties) (*sql.DB, error) { if tracingEnabled { // install the wrapped driver driverName += "-trace" } - return sql.Open(driverName, dsn) + db, err := sql.Open(driverName, dsn) + if err != nil { + return nil, err + } + if driverName != SQLiteDriverName() && dbProperties != nil { + logrus.WithFields(logrus.Fields{ + "MaxOpenConns": dbProperties.MaxOpenConns(), + "MaxIdleConns": dbProperties.MaxIdleConns(), + "ConnMaxLifetime": dbProperties.ConnMaxLifetime(), + "dataSourceName": regexp.MustCompile(`://[^@]*@`).ReplaceAllLiteralString(dsn, "://"), + }).Debug("Setting DB connection limits") + db.SetMaxOpenConns(dbProperties.MaxOpenConns()) + db.SetMaxIdleConns(dbProperties.MaxIdleConns()) + db.SetConnMaxLifetime(dbProperties.ConnMaxLifetime()) + } + return db, nil } func init() { diff --git a/internal/sqlutil/uri.go b/internal/sqlutil/uri.go new file mode 100644 index 000000000..703258e6c --- /dev/null +++ b/internal/sqlutil/uri.go @@ -0,0 +1,38 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlutil + +import ( + "fmt" + "net/url" +) + +// ParseFileURI returns the filepath in the given file: URI. Specifically, this will handle +// both relative (file:foo.db) and absolute (file:///path/to/foo) paths. +func ParseFileURI(dataSourceName string) (string, error) { + uri, err := url.Parse(dataSourceName) + if err != nil { + return "", err + } + var cs string + if uri.Opaque != "" { // file:filename.db + cs = uri.Opaque + } else if uri.Path != "" { // file:///path/to/filename.db + cs = uri.Path + } else { + return "", fmt.Errorf("invalid file uri: %s", dataSourceName) + } + return cs, nil +} diff --git a/common/test/client.go b/internal/test/client.go similarity index 100% rename from common/test/client.go rename to internal/test/client.go diff --git a/common/test/config.go b/internal/test/config.go similarity index 98% rename from common/test/config.go rename to internal/test/config.go index f88e45125..951f65a12 100644 --- a/common/test/config.go +++ b/internal/test/config.go @@ -27,7 +27,7 @@ import ( "path/filepath" "time" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/gomatrixserverlib" "gopkg.in/yaml.v2" ) @@ -84,7 +84,6 @@ func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*con cfg.Kafka.Topics.OutputRoomEvent = "test.room.output" cfg.Kafka.Topics.OutputClientData = "test.clientapi.output" cfg.Kafka.Topics.OutputTypingEvent = "test.typing.output" - cfg.Kafka.Topics.UserUpdates = "test.user.output" // TODO: Use different databases for the different schemas. // Using the same database for every schema currently works because diff --git a/common/test/kafka.go b/internal/test/kafka.go similarity index 100% rename from common/test/kafka.go rename to internal/test/kafka.go diff --git a/internal/test/keyring.go b/internal/test/keyring.go new file mode 100644 index 000000000..ed9c34849 --- /dev/null +++ b/internal/test/keyring.go @@ -0,0 +1,31 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "context" + + "github.com/matrix-org/gomatrixserverlib" +) + +// NopJSONVerifier is a JSONVerifier that verifies nothing and returns no errors. +type NopJSONVerifier struct { + // this verifier verifies nothing +} + +func (t *NopJSONVerifier) VerifyJSONs(ctx context.Context, requests []gomatrixserverlib.VerifyJSONRequest) ([]gomatrixserverlib.VerifyJSONResult, error) { + result := make([]gomatrixserverlib.VerifyJSONResult, len(requests)) + return result, nil +} diff --git a/common/test/server.go b/internal/test/server.go similarity index 70% rename from common/test/server.go rename to internal/test/server.go index 4fdd5e638..c3348d533 100644 --- a/common/test/server.go +++ b/internal/test/server.go @@ -1,4 +1,4 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,13 +15,18 @@ package test import ( + "context" "fmt" + "net" + "net/http" "os" "os/exec" "path/filepath" "strings" + "sync" + "testing" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" ) // Defaulting allows assignment of string variables with a fallback default value @@ -103,3 +108,46 @@ func StartProxy(bindAddr string, cfg *config.Dendrite) (*exec.Cmd, chan error) { proxyArgs, ) } + +// ListenAndServe will listen on a random high-numbered port and attach the given router. +// Returns the base URL to send requests to. Call `cancel` to shutdown the server, which will block until it has closed. +func ListenAndServe(t *testing.T, router http.Handler, useTLS bool) (apiURL string, cancel func()) { + listener, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("failed to listen: %s", err) + } + port := listener.Addr().(*net.TCPAddr).Port + srv := http.Server{} + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + srv.Handler = router + var err error + if useTLS { + certFile := filepath.Join(os.TempDir(), "dendrite.cert") + keyFile := filepath.Join(os.TempDir(), "dendrite.key") + err = NewTLSKey(keyFile, certFile) + if err != nil { + t.Logf("failed to generate tls key/cert: %s", err) + return + } + err = srv.ServeTLS(listener, certFile, keyFile) + } else { + err = srv.Serve(listener) + } + if err != nil && err != http.ErrServerClosed { + t.Logf("Listen failed: %s", err) + } + }() + + secure := "" + if useTLS { + secure = "s" + } + return fmt.Sprintf("http%s://localhost:%d", secure, port), func() { + _ = srv.Shutdown(context.Background()) + wg.Wait() + } +} diff --git a/common/test/slice.go b/internal/test/slice.go similarity index 100% rename from common/test/slice.go rename to internal/test/slice.go diff --git a/common/transactions/transactions.go b/internal/transactions/transactions.go similarity index 100% rename from common/transactions/transactions.go rename to internal/transactions/transactions.go diff --git a/common/transactions/transactions_test.go b/internal/transactions/transactions_test.go similarity index 100% rename from common/transactions/transactions_test.go rename to internal/transactions/transactions_test.go diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go new file mode 100644 index 000000000..bedc4dfb8 --- /dev/null +++ b/keyserver/keyserver.go @@ -0,0 +1,29 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package keyserver + +import ( + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/keyserver/routing" + userapi "github.com/matrix-org/dendrite/userapi/api" +) + +// AddPublicRoutes registers HTTP handlers for CS API calls +func AddPublicRoutes( + router *mux.Router, cfg *config.Dendrite, userAPI userapi.UserInternalAPI, +) { + routing.Setup(router, cfg, userAPI) +} diff --git a/clientapi/auth/authtypes/account.go b/keyserver/routing/keys.go similarity index 57% rename from clientapi/auth/authtypes/account.go rename to keyserver/routing/keys.go index fd3c15a84..a279a747c 100644 --- a/clientapi/auth/authtypes/account.go +++ b/keyserver/routing/keys.go @@ -1,4 +1,4 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,20 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -package authtypes +package routing import ( - "github.com/matrix-org/gomatrixserverlib" + "net/http" + + "github.com/matrix-org/util" ) -// Account represents a Matrix account on this home server. -type Account struct { - UserID string - Localpart string - ServerName gomatrixserverlib.ServerName - Profile *Profile - AppServiceID string - // TODO: Other flags like IsAdmin, IsGuest - // TODO: Devices - // TODO: Associations (e.g. with application services) +func QueryKeys( + req *http.Request, +) util.JSONResponse { + return util.JSONResponse{ + Code: http.StatusOK, + JSON: map[string]interface{}{ + "failures": map[string]interface{}{}, + "device_keys": map[string]interface{}{}, + }, + } } diff --git a/keyserver/routing/routing.go b/keyserver/routing/routing.go new file mode 100644 index 000000000..dba43528f --- /dev/null +++ b/keyserver/routing/routing.go @@ -0,0 +1,54 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "net/http" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/httputil" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +const pathPrefixR0 = "/client/r0" + +// Setup registers HTTP handlers with the given ServeMux. It also supplies the given http.Client +// to clients which need to make outbound HTTP requests. +// +// Due to Setup being used to call many other functions, a gocyclo nolint is +// applied: +// nolint: gocyclo +func Setup( + publicAPIMux *mux.Router, cfg *config.Dendrite, userAPI userapi.UserInternalAPI, +) { + r0mux := publicAPIMux.PathPrefix(pathPrefixR0).Subrouter() + + r0mux.Handle("/keys/query", + httputil.MakeAuthAPI("queryKeys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return QueryKeys(req) + }), + ).Methods(http.MethodPost, http.MethodOptions) + + r0mux.Handle("/keys/upload/{keyID}", + httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return util.JSONResponse{ + Code: 200, + JSON: map[string]interface{}{}, + } + }), + ).Methods(http.MethodPost, http.MethodOptions) +} diff --git a/mediaapi/fileutils/fileutils.go b/mediaapi/fileutils/fileutils.go index d6badcb94..42d07e900 100644 --- a/mediaapi/fileutils/fileutils.go +++ b/mediaapi/fileutils/fileutils.go @@ -25,7 +25,7 @@ import ( "path/filepath" "strings" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/mediaapi/types" log "github.com/sirupsen/logrus" ) diff --git a/mediaapi/mediaapi.go b/mediaapi/mediaapi.go index f2e614c17..290ef46e1 100644 --- a/mediaapi/mediaapi.go +++ b/mediaapi/mediaapi.go @@ -15,26 +15,27 @@ package mediaapi import ( - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" - "github.com/matrix-org/dendrite/common/basecomponent" + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/mediaapi/routing" "github.com/matrix-org/dendrite/mediaapi/storage" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) -// SetupMediaAPIComponent sets up and registers HTTP handlers for the MediaAPI -// component. -func SetupMediaAPIComponent( - base *basecomponent.BaseDendrite, - deviceDB devices.Database, +// AddPublicRoutes sets up and registers HTTP handlers for the MediaAPI component. +func AddPublicRoutes( + router *mux.Router, cfg *config.Dendrite, + userAPI userapi.UserInternalAPI, + client *gomatrixserverlib.Client, ) { - mediaDB, err := storage.Open(string(base.Cfg.Database.MediaAPI)) + mediaDB, err := storage.Open(string(cfg.Database.MediaAPI), cfg.DbProperties()) if err != nil { logrus.WithError(err).Panicf("failed to connect to media db") } routing.Setup( - base.APIMux, base.Cfg, mediaDB, deviceDB, gomatrixserverlib.NewClient(), + router, cfg, mediaDB, userAPI, client, ) } diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go index 9b23556d5..b5b640242 100644 --- a/mediaapi/routing/download.go +++ b/mediaapi/routing/download.go @@ -21,15 +21,17 @@ import ( "io" "mime" "net/http" + "net/url" "os" "path/filepath" "regexp" "strconv" "strings" "sync" + "unicode" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/mediaapi/fileutils" "github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/thumbnailer" @@ -43,7 +45,11 @@ import ( const mediaIDCharacters = "A-Za-z0-9_=-" // Note: unfortunately regex.MustCompile() cannot be assigned to a const -var mediaIDRegex = regexp.MustCompile("[" + mediaIDCharacters + "]+") +var mediaIDRegex = regexp.MustCompile("^[" + mediaIDCharacters + "]+$") + +// Regular expressions to help us cope with Content-Disposition parsing +var rfc2183 = regexp.MustCompile(`filename\=utf-8\"(.*)\"`) +var rfc6266 = regexp.MustCompile(`filename\*\=utf-8\'\'(.*)`) // downloadRequest metadata included in or derivable from a download or thumbnail request // https://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-media-r0-download-servername-mediaid @@ -53,6 +59,7 @@ type downloadRequest struct { IsThumbnailRequest bool ThumbnailSize types.ThumbnailSize Logger *log.Entry + DownloadFilename string } // Download implements GET /download and GET /thumbnail @@ -72,6 +79,7 @@ func Download( activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration, isThumbnailRequest bool, + customFilename string, ) { dReq := &downloadRequest{ MediaMetadata: &types.MediaMetadata{ @@ -83,6 +91,7 @@ func Download( "Origin": origin, "MediaID": mediaID, }), + DownloadFilename: customFilename, } if dReq.IsThumbnailRequest { @@ -118,7 +127,10 @@ func Download( ) if err != nil { // TODO: Handle the fact we might have started writing the response - dReq.jsonErrorResponse(w, util.ErrorResponse(err)) + dReq.jsonErrorResponse(w, util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("Failed to download: " + err.Error()), + }) return } @@ -138,7 +150,7 @@ func (r *downloadRequest) jsonErrorResponse(w http.ResponseWriter, res util.JSON if err != nil { r.Logger.WithError(err).Error("Failed to marshal JSONResponse") // this should never fail to be marshalled so drop err to the floor - res = util.MessageResponse(http.StatusInternalServerError, "Internal Server Error") + res = util.MessageResponse(http.StatusNotFound, "Download request failed: "+err.Error()) resBytes, _ = json.Marshal(res.JSON) } @@ -297,9 +309,8 @@ func (r *downloadRequest) respondFromLocalFile( }).Info("Responding with file") responseFile = file responseMetadata = r.MediaMetadata - - if len(responseMetadata.UploadName) > 0 { - w.Header().Set("Content-Disposition", fmt.Sprintf(`inline; filename*=utf-8"%s"`, responseMetadata.UploadName)) + if err := r.addDownloadFilenameToHeaders(w, responseMetadata); err != nil { + return nil, err } } @@ -318,6 +329,67 @@ func (r *downloadRequest) respondFromLocalFile( return responseMetadata, nil } +func (r *downloadRequest) addDownloadFilenameToHeaders( + w http.ResponseWriter, + responseMetadata *types.MediaMetadata, +) error { + // If the requestor supplied a filename to name the download then + // use that, otherwise use the filename from the response metadata. + filename := string(responseMetadata.UploadName) + if r.DownloadFilename != "" { + filename = r.DownloadFilename + } + + if len(filename) == 0 { + return nil + } + + unescaped, err := url.PathUnescape(filename) + if err != nil { + return fmt.Errorf("url.PathUnescape: %w", err) + } + + isASCII := true // Is the string ASCII or UTF-8? + quote := `` // Encloses the string (ASCII only) + for i := 0; i < len(unescaped); i++ { + if unescaped[i] > unicode.MaxASCII { + isASCII = false + } + if unescaped[i] == 0x20 || unescaped[i] == 0x3B { + // If the filename contains a space or a semicolon, which + // are special characters in Content-Disposition + quote = `"` + } + } + + // We don't necessarily want a full escape as the Content-Disposition + // can take many of the characters that PathEscape would otherwise and + // browser support for encoding is a bit wild, so we'll escape only + // the characters that we know will mess up the parsing of the + // Content-Disposition header elements themselves + unescaped = strings.ReplaceAll(unescaped, `\`, `\\"`) + unescaped = strings.ReplaceAll(unescaped, `"`, `\"`) + + if isASCII { + // For ASCII filenames, we should only quote the filename if + // it needs to be done, e.g. it contains a space or a character + // that would otherwise be parsed as a control character in the + // Content-Disposition header + w.Header().Set("Content-Disposition", fmt.Sprintf( + `inline; filename=%s%s%s`, + quote, unescaped, quote, + )) + } else { + // For UTF-8 filenames, we quote always, as that's the standard + w.Header().Set("Content-Disposition", fmt.Sprintf( + `inline; filename*=utf-8''%s`, + url.QueryEscape(unescaped), + )) + } + + return nil +} + // Note: Thumbnail generation may be ongoing asynchronously. // If no thumbnail was found then returns nil, nil, nil func (r *downloadRequest) getThumbnailFile( @@ -632,9 +704,22 @@ func (r *downloadRequest) fetchRemoteFile( } r.MediaMetadata.FileSizeBytes = types.FileSizeBytes(contentLength) r.MediaMetadata.ContentType = types.ContentType(resp.Header.Get("Content-Type")) - _, params, err := mime.ParseMediaType(resp.Header.Get("Content-Disposition")) - if err == nil && params["filename"] != "" { - r.MediaMetadata.UploadName = types.Filename(params["filename"]) + + dispositionHeader := resp.Header.Get("Content-Disposition") + if _, params, e := mime.ParseMediaType(dispositionHeader); e == nil { + if params["filename"] != "" { + r.MediaMetadata.UploadName = types.Filename(params["filename"]) + } else if params["filename*"] != "" { + r.MediaMetadata.UploadName = types.Filename(params["filename*"]) + } + } else { + if matches := rfc6266.FindStringSubmatch(dispositionHeader); len(matches) > 1 { + // Always prefer the RFC6266 UTF-8 name if possible + r.MediaMetadata.UploadName = types.Filename(matches[1]) + } else if matches := rfc2183.FindStringSubmatch(dispositionHeader); len(matches) > 1 { + // Otherwise, see if an RFC2183 name was provided (ASCII only) + r.MediaMetadata.UploadName = types.Filename(matches[1]) + } } r.Logger.Info("Transferring remote file") diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go index e27e98b5f..eaccbdc62 100644 --- a/mediaapi/routing/routing.go +++ b/mediaapi/routing/routing.go @@ -16,14 +16,13 @@ package routing import ( "net/http" + "strings" - "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -33,7 +32,8 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" ) -const pathPrefixR0 = "/_matrix/media/r0" +const pathPrefixR0 = "/media/r0" +const pathPrefixV1 = "/media/v1" // TODO: remove when synapse is fixed // Setup registers the media API HTTP handlers // @@ -41,27 +41,21 @@ const pathPrefixR0 = "/_matrix/media/r0" // applied: // nolint: gocyclo func Setup( - apiMux *mux.Router, + publicAPIMux *mux.Router, cfg *config.Dendrite, db storage.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, client *gomatrixserverlib.Client, ) { - r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() + r0mux := publicAPIMux.PathPrefix(pathPrefixR0).Subrouter() + v1mux := publicAPIMux.PathPrefix(pathPrefixV1).Subrouter() activeThumbnailGeneration := &types.ActiveThumbnailGeneration{ PathToResult: map[string]*types.ThumbnailGenerationResult{}, } - authData := auth.Data{ - AccountDB: nil, - DeviceDB: deviceDB, - AppServices: nil, - } - - // TODO: Add AS support - r0mux.Handle("/upload", common.MakeAuthAPI( - "upload", authData, - func(req *http.Request, dev *authtypes.Device) util.JSONResponse { + r0mux.Handle("/upload", httputil.MakeAuthAPI( + "upload", userAPI, + func(req *http.Request, dev *userapi.Device) util.JSONResponse { return Upload(req, cfg, dev, db, activeThumbnailGeneration) }, )).Methods(http.MethodPost, http.MethodOptions) @@ -69,9 +63,13 @@ func Setup( activeRemoteRequests := &types.ActiveRemoteRequests{ MXCToResult: map[string]*types.RemoteRequestResult{}, } - r0mux.Handle("/download/{serverName}/{mediaId}", - makeDownloadAPI("download", cfg, db, client, activeRemoteRequests, activeThumbnailGeneration), - ).Methods(http.MethodGet, http.MethodOptions) + + downloadHandler := makeDownloadAPI("download", cfg, db, client, activeRemoteRequests, activeThumbnailGeneration) + r0mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) + v1mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) // TODO: remove when synapse is fixed + v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) // TODO: remove when synapse is fixed + r0mux.Handle("/thumbnail/{serverName}/{mediaId}", makeDownloadAPI("thumbnail", cfg, db, client, activeRemoteRequests, activeThumbnailGeneration), ).Methods(http.MethodGet, http.MethodOptions) @@ -95,15 +93,28 @@ func makeDownloadAPI( httpHandler := func(w http.ResponseWriter, req *http.Request) { req = util.RequestWithLogging(req) - // Set common headers returned regardless of the outcome of the request + // Set internal headers returned regardless of the outcome of the request util.SetCORSHeaders(w) // Content-Type will be overridden in case of returning file data, else we respond with JSON-formatted errors w.Header().Set("Content-Type", "application/json") - vars, _ := common.URLDecodeMapValues(mux.Vars(req)) + + vars, _ := httputil.URLDecodeMapValues(mux.Vars(req)) + serverName := gomatrixserverlib.ServerName(vars["serverName"]) + + // For the purposes of loop avoidance, we will return a 404 if allow_remote is set to + // false in the query string and the target server name isn't our own. + // https://github.com/matrix-org/matrix-doc/pull/1265 + if allowRemote := req.URL.Query().Get("allow_remote"); strings.ToLower(allowRemote) == "false" { + if serverName != cfg.Matrix.ServerName { + w.WriteHeader(http.StatusNotFound) + return + } + } + Download( w, req, - gomatrixserverlib.ServerName(vars["serverName"]), + serverName, types.MediaID(vars["mediaId"]), cfg, db, @@ -111,6 +122,7 @@ func makeDownloadAPI( activeRemoteRequests, activeThumbnailGeneration, name == "thumbnail", + vars["downloadName"], ) } return promhttp.InstrumentHandlerCounter(counterVec, http.HandlerFunc(httpHandler)) diff --git a/mediaapi/routing/upload.go b/mediaapi/routing/upload.go index 1da551646..0aa335d53 100644 --- a/mediaapi/routing/upload.go +++ b/mediaapi/routing/upload.go @@ -16,6 +16,7 @@ package routing import ( "context" + "encoding/base64" "fmt" "io" "net/http" @@ -26,7 +27,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/mediaapi/fileutils" "github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/thumbnailer" @@ -125,7 +126,9 @@ func (r *uploadRequest) doUpload( r.MediaMetadata.FileSizeBytes = bytesWritten r.MediaMetadata.Base64Hash = hash - r.MediaMetadata.MediaID = types.MediaID(hash) + r.MediaMetadata.MediaID = types.MediaID(base64.RawURLEncoding.EncodeToString( + []byte(string(r.MediaMetadata.UploadName) + string(r.MediaMetadata.Base64Hash)), + )) r.Logger = r.Logger.WithField("MediaID", r.MediaMetadata.MediaID) diff --git a/mediaapi/storage/postgres/prepare.go b/mediaapi/storage/postgres/prepare.go index 090c3d17d..a2e01884e 100644 --- a/mediaapi/storage/postgres/prepare.go +++ b/mediaapi/storage/postgres/prepare.go @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// FIXME: This should be made common! +// FIXME: This should be made internal! package postgres diff --git a/mediaapi/storage/postgres/storage.go b/mediaapi/storage/postgres/storage.go index 18126b151..e45e08416 100644 --- a/mediaapi/storage/postgres/storage.go +++ b/mediaapi/storage/postgres/storage.go @@ -33,10 +33,10 @@ type Database struct { } // Open opens a postgres database. -func Open(dataSourceName string) (*Database, error) { +func Open(dataSourceName string, dbProperties sqlutil.DbProperties) (*Database, error) { var d Database var err error - if d.db, err = sqlutil.Open("postgres", dataSourceName); err != nil { + if d.db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { return nil, err } if err = d.statements.prepare(d.db); err != nil { diff --git a/mediaapi/storage/postgres/thumbnail_table.go b/mediaapi/storage/postgres/thumbnail_table.go index 08bddc36f..3f28cdbbf 100644 --- a/mediaapi/storage/postgres/thumbnail_table.go +++ b/mediaapi/storage/postgres/thumbnail_table.go @@ -20,8 +20,7 @@ import ( "database/sql" "time" - "github.com/matrix-org/dendrite/common" - + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -146,7 +145,7 @@ func (s *thumbnailStatements) selectThumbnails( if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "selectThumbnails: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectThumbnails: rows.close() failed") var thumbnails []*types.ThumbnailMetadata for rows.Next() { diff --git a/mediaapi/storage/sqlite3/prepare.go b/mediaapi/storage/sqlite3/prepare.go index a6bc24c98..8fb3b56f3 100644 --- a/mediaapi/storage/sqlite3/prepare.go +++ b/mediaapi/storage/sqlite3/prepare.go @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// FIXME: This should be made common! +// FIXME: This should be made internal! package sqlite3 diff --git a/mediaapi/storage/sqlite3/storage.go b/mediaapi/storage/sqlite3/storage.go index abafecf20..010c0a66e 100644 --- a/mediaapi/storage/sqlite3/storage.go +++ b/mediaapi/storage/sqlite3/storage.go @@ -20,7 +20,6 @@ import ( "database/sql" // Import the postgres database driver. - "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -37,7 +36,11 @@ type Database struct { func Open(dataSourceName string) (*Database, error) { var d Database var err error - if d.db, err = sqlutil.Open(common.SQLiteDriverName(), dataSourceName); err != nil { + cs, err := sqlutil.ParseFileURI(dataSourceName) + if err != nil { + return nil, err + } + if d.db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil); err != nil { return nil, err } if err = d.statements.prepare(d.db); err != nil { diff --git a/mediaapi/storage/sqlite3/thumbnail_table.go b/mediaapi/storage/sqlite3/thumbnail_table.go index 280fafe8d..432a1590c 100644 --- a/mediaapi/storage/sqlite3/thumbnail_table.go +++ b/mediaapi/storage/sqlite3/thumbnail_table.go @@ -20,8 +20,7 @@ import ( "database/sql" "time" - "github.com/matrix-org/dendrite/common" - + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -136,7 +135,7 @@ func (s *thumbnailStatements) selectThumbnails( if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "selectThumbnails: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectThumbnails: rows.close() failed") var thumbnails []*types.ThumbnailMetadata for rows.Next() { diff --git a/mediaapi/storage/storage.go b/mediaapi/storage/storage.go index c533477cd..5ff114db6 100644 --- a/mediaapi/storage/storage.go +++ b/mediaapi/storage/storage.go @@ -19,22 +19,23 @@ package storage import ( "net/url" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage/postgres" "github.com/matrix-org/dendrite/mediaapi/storage/sqlite3" ) // Open opens a postgres database. -func Open(dataSourceName string) (Database, error) { +func Open(dataSourceName string, dbProperties sqlutil.DbProperties) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { - return postgres.Open(dataSourceName) + return postgres.Open(dataSourceName, dbProperties) } switch uri.Scheme { case "postgres": - return postgres.Open(dataSourceName) + return postgres.Open(dataSourceName, dbProperties) case "file": return sqlite3.Open(dataSourceName) default: - return postgres.Open(dataSourceName) + return postgres.Open(dataSourceName, dbProperties) } } diff --git a/mediaapi/storage/storage_wasm.go b/mediaapi/storage/storage_wasm.go index 92f0ad134..a672271f9 100644 --- a/mediaapi/storage/storage_wasm.go +++ b/mediaapi/storage/storage_wasm.go @@ -18,11 +18,15 @@ import ( "fmt" "net/url" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage/sqlite3" ) // Open opens a postgres database. -func Open(dataSourceName string) (Database, error) { +func Open( + dataSourceName string, + dbProperties sqlutil.DbProperties, // nolint:unparam +) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { return nil, fmt.Errorf("Cannot use postgres implementation") diff --git a/mediaapi/thumbnailer/thumbnailer.go b/mediaapi/thumbnailer/thumbnailer.go index ebf5138c5..9a58b5bc1 100644 --- a/mediaapi/thumbnailer/thumbnailer.go +++ b/mediaapi/thumbnailer/thumbnailer.go @@ -22,7 +22,7 @@ import ( "path/filepath" "sync" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/types" log "github.com/sirupsen/logrus" diff --git a/mediaapi/thumbnailer/thumbnailer_bimg.go b/mediaapi/thumbnailer/thumbnailer_bimg.go index db6f23ace..915d576e3 100644 --- a/mediaapi/thumbnailer/thumbnailer_bimg.go +++ b/mediaapi/thumbnailer/thumbnailer_bimg.go @@ -21,7 +21,7 @@ import ( "os" "time" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/types" log "github.com/sirupsen/logrus" diff --git a/mediaapi/thumbnailer/thumbnailer_nfnt.go b/mediaapi/thumbnailer/thumbnailer_nfnt.go index 4f1e98aa0..b48551e4e 100644 --- a/mediaapi/thumbnailer/thumbnailer_nfnt.go +++ b/mediaapi/thumbnailer/thumbnailer_nfnt.go @@ -30,7 +30,7 @@ import ( "os" "time" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/nfnt/resize" diff --git a/mediaapi/types/types.go b/mediaapi/types/types.go index 855e8fe27..9fa549509 100644 --- a/mediaapi/types/types.go +++ b/mediaapi/types/types.go @@ -17,7 +17,7 @@ package types import ( "sync" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/publicroomsapi/consumers/roomserver.go b/publicroomsapi/consumers/roomserver.go index 2bbd92b72..b9686d56d 100644 --- a/publicroomsapi/consumers/roomserver.go +++ b/publicroomsapi/consumers/roomserver.go @@ -18,20 +18,20 @@ import ( "context" "encoding/json" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" + "github.com/Shopify/sarama" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/publicroomsapi/storage" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" - sarama "gopkg.in/Shopify/sarama.v1" ) // OutputRoomEventConsumer consumes events that originated in the room server. type OutputRoomEventConsumer struct { - roomServerConsumer *common.ContinualConsumer - db storage.Database - query api.RoomserverQueryAPI + rsAPI api.RoomserverInternalAPI + rsConsumer *internal.ContinualConsumer + db storage.Database } // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. @@ -39,17 +39,17 @@ func NewOutputRoomEventConsumer( cfg *config.Dendrite, kafkaConsumer sarama.Consumer, store storage.Database, - queryAPI api.RoomserverQueryAPI, + rsAPI api.RoomserverInternalAPI, ) *OutputRoomEventConsumer { - consumer := common.ContinualConsumer{ + consumer := internal.ContinualConsumer{ Topic: string(cfg.Kafka.Topics.OutputRoomEvent), Consumer: kafkaConsumer, PartitionStore: store, } s := &OutputRoomEventConsumer{ - roomServerConsumer: &consumer, - db: store, - query: queryAPI, + rsConsumer: &consumer, + db: store, + rsAPI: rsAPI, } consumer.ProcessMessage = s.onMessage @@ -58,7 +58,7 @@ func NewOutputRoomEventConsumer( // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { - return s.roomServerConsumer.Start() + return s.rsConsumer.Start() } // onMessage is called when the sync server receives a new event from the room server output log. @@ -78,29 +78,17 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { return nil } - ev := output.NewRoomEvent.Event - log.WithFields(log.Fields{ - "event_id": ev.EventID(), - "room_id": ev.RoomID(), - "type": ev.Type(), - }).Info("received event from roomserver") - - addQueryReq := api.QueryEventsByIDRequest{EventIDs: output.NewRoomEvent.AddsStateEventIDs} - var addQueryRes api.QueryEventsByIDResponse - if err := s.query.QueryEventsByID(context.TODO(), &addQueryReq, &addQueryRes); err != nil { - log.Warn(err) - return err - } - - remQueryReq := api.QueryEventsByIDRequest{EventIDs: output.NewRoomEvent.RemovesStateEventIDs} var remQueryRes api.QueryEventsByIDResponse - if err := s.query.QueryEventsByID(context.TODO(), &remQueryReq, &remQueryRes); err != nil { - log.Warn(err) - return err + if len(output.NewRoomEvent.RemovesStateEventIDs) > 0 { + remQueryReq := api.QueryEventsByIDRequest{EventIDs: output.NewRoomEvent.RemovesStateEventIDs} + if err := s.rsAPI.QueryEventsByID(context.TODO(), &remQueryReq, &remQueryRes); err != nil { + log.Warn(err) + return err + } } var addQueryEvents, remQueryEvents []gomatrixserverlib.Event - for _, headeredEvent := range addQueryRes.Events { + for _, headeredEvent := range output.NewRoomEvent.AddsState() { addQueryEvents = append(addQueryEvents, headeredEvent.Event) } for _, headeredEvent := range remQueryRes.Events { diff --git a/publicroomsapi/directory/directory.go b/publicroomsapi/directory/directory.go index 837018e64..8b68279aa 100644 --- a/publicroomsapi/directory/directory.go +++ b/publicroomsapi/directory/directory.go @@ -17,8 +17,8 @@ package directory import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" @@ -59,7 +59,7 @@ func GetVisibility( // SetVisibility implements PUT /directory/list/room/{roomID} // TODO: Allow admin users to edit the room visibility func SetVisibility( - req *http.Request, publicRoomsDatabase storage.Database, queryAPI api.RoomserverQueryAPI, dev *authtypes.Device, + req *http.Request, publicRoomsDatabase storage.Database, rsAPI api.RoomserverInternalAPI, dev *userapi.Device, roomID string, ) util.JSONResponse { queryMembershipReq := api.QueryMembershipForUserRequest{ @@ -67,7 +67,7 @@ func SetVisibility( UserID: dev.UserID, } var queryMembershipRes api.QueryMembershipForUserResponse - err := queryAPI.QueryMembershipForUser(req.Context(), &queryMembershipReq, &queryMembershipRes) + err := rsAPI.QueryMembershipForUser(req.Context(), &queryMembershipReq, &queryMembershipRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("could not query membership for user") return jsonerror.InternalServerError() @@ -87,7 +87,7 @@ func SetVisibility( }}, } var queryEventsRes api.QueryLatestEventsAndStateResponse - err = queryAPI.QueryLatestEventsAndState(req.Context(), &queryEventsReq, &queryEventsRes) + err = rsAPI.QueryLatestEventsAndState(req.Context(), &queryEventsReq, &queryEventsRes) if err != nil || len(queryEventsRes.StateEvents) == 0 { util.GetLogger(req.Context()).WithError(err).Error("could not query events from room") return jsonerror.InternalServerError() diff --git a/publicroomsapi/directory/public_rooms.go b/publicroomsapi/directory/public_rooms.go index 7bd6740eb..df9df8ff9 100644 --- a/publicroomsapi/directory/public_rooms.go +++ b/publicroomsapi/directory/public_rooms.go @@ -16,7 +16,9 @@ package directory import ( "context" + "math/rand" "net/http" + "sort" "strconv" "sync" "time" @@ -96,6 +98,28 @@ func GetPostPublicRoomsWithExternal( // downcasting `limit` is safe as we know it isn't bigger than request.Limit which is int16 fedRooms := bulkFetchPublicRoomsFromServers(req.Context(), fedClient, extRoomsProvider.Homeservers(), int16(limit)) response.Chunk = append(response.Chunk, fedRooms...) + + // de-duplicate rooms with the same room ID. We can join the room via any of these aliases as we know these servers + // are alive and well, so we arbitrarily pick one (purposefully shuffling them to spread the load a bit) + var publicRooms []gomatrixserverlib.PublicRoom + haveRoomIDs := make(map[string]bool) + rand.Shuffle(len(response.Chunk), func(i, j int) { + response.Chunk[i], response.Chunk[j] = response.Chunk[j], response.Chunk[i] + }) + for _, r := range response.Chunk { + if haveRoomIDs[r.RoomID] { + continue + } + haveRoomIDs[r.RoomID] = true + publicRooms = append(publicRooms, r) + } + // sort by member count + sort.SliceStable(publicRooms, func(i, j int) bool { + return publicRooms[i].JoinedMembersCount > publicRooms[j].JoinedMembersCount + }) + + response.Chunk = publicRooms + return util.JSONResponse{ Code: http.StatusOK, JSON: response, diff --git a/publicroomsapi/publicroomsapi.go b/publicroomsapi/publicroomsapi.go index 6efb54bd9..b9baa1056 100644 --- a/publicroomsapi/publicroomsapi.go +++ b/publicroomsapi/publicroomsapi.go @@ -15,33 +15,37 @@ package publicroomsapi import ( - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" - "github.com/matrix-org/dendrite/common/basecomponent" + "github.com/Shopify/sarama" + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/publicroomsapi/consumers" "github.com/matrix-org/dendrite/publicroomsapi/routing" "github.com/matrix-org/dendrite/publicroomsapi/storage" "github.com/matrix-org/dendrite/publicroomsapi/types" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) -// SetupPublicRoomsAPIComponent sets up and registers HTTP handlers for the PublicRoomsAPI +// AddPublicRoutes sets up and registers HTTP handlers for the PublicRoomsAPI // component. -func SetupPublicRoomsAPIComponent( - base *basecomponent.BaseDendrite, - deviceDB devices.Database, +func AddPublicRoutes( + router *mux.Router, + cfg *config.Dendrite, + consumer sarama.Consumer, + userAPI userapi.UserInternalAPI, publicRoomsDB storage.Database, - rsQueryAPI roomserverAPI.RoomserverQueryAPI, + rsAPI roomserverAPI.RoomserverInternalAPI, fedClient *gomatrixserverlib.FederationClient, extRoomsProvider types.ExternalPublicRoomsProvider, ) { rsConsumer := consumers.NewOutputRoomEventConsumer( - base.Cfg, base.KafkaConsumer, publicRoomsDB, rsQueryAPI, + cfg, consumer, publicRoomsDB, rsAPI, ) if err := rsConsumer.Start(); err != nil { logrus.WithError(err).Panic("failed to start public rooms server consumer") } - routing.Setup(base.APIMux, deviceDB, publicRoomsDB, rsQueryAPI, fedClient, extRoomsProvider) + routing.Setup(router, userAPI, publicRoomsDB, rsAPI, fedClient, extRoomsProvider) } diff --git a/publicroomsapi/routing/routing.go b/publicroomsapi/routing/routing.go index da5ea90d6..9c82d3508 100644 --- a/publicroomsapi/routing/routing.go +++ b/publicroomsapi/routing/routing.go @@ -17,13 +17,11 @@ package routing import ( "net/http" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" - "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/publicroomsapi/directory" "github.com/matrix-org/dendrite/publicroomsapi/storage" "github.com/matrix-org/dendrite/publicroomsapi/types" @@ -31,7 +29,7 @@ import ( "github.com/matrix-org/util" ) -const pathPrefixR0 = "/_matrix/client/r0" +const pathPrefixR0 = "/client/r0" // Setup configures the given mux with publicroomsapi server listeners // @@ -39,20 +37,14 @@ const pathPrefixR0 = "/_matrix/client/r0" // applied: // nolint: gocyclo func Setup( - apiMux *mux.Router, deviceDB devices.Database, publicRoomsDB storage.Database, queryAPI api.RoomserverQueryAPI, + publicAPIMux *mux.Router, userAPI userapi.UserInternalAPI, publicRoomsDB storage.Database, rsAPI api.RoomserverInternalAPI, fedClient *gomatrixserverlib.FederationClient, extRoomsProvider types.ExternalPublicRoomsProvider, ) { - r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() - - authData := auth.Data{ - AccountDB: nil, - DeviceDB: deviceDB, - AppServices: nil, - } + r0mux := publicAPIMux.PathPrefix(pathPrefixR0).Subrouter() r0mux.Handle("/directory/list/room/{roomID}", - common.MakeExternalAPI("directory_list", func(req *http.Request) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeExternalAPI("directory_list", func(req *http.Request) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -61,16 +53,16 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) // TODO: Add AS support r0mux.Handle("/directory/list/room/{roomID}", - common.MakeAuthAPI("directory_list", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("directory_list", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return directory.SetVisibility(req, publicRoomsDB, queryAPI, device, vars["roomID"]) + return directory.SetVisibility(req, publicRoomsDB, rsAPI, device, vars["roomID"]) }), ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/publicRooms", - common.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse { + httputil.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse { if extRoomsProvider != nil { return directory.GetPostPublicRoomsWithExternal(req, publicRoomsDB, fedClient, extRoomsProvider) } @@ -79,8 +71,8 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) // Federation - TODO: should this live here or in federation API? It's sure easier if it's here so here it is. - apiMux.Handle("/_matrix/federation/v1/publicRooms", - common.MakeExternalAPI("federation_public_rooms", func(req *http.Request) util.JSONResponse { + publicAPIMux.Handle("/federation/v1/publicRooms", + httputil.MakeExternalAPI("federation_public_rooms", func(req *http.Request) util.JSONResponse { return directory.GetPostPublicRooms(req, publicRoomsDB) }), ).Methods(http.MethodGet) diff --git a/publicroomsapi/storage/interface.go b/publicroomsapi/storage/interface.go index 0feca0e20..0ca6f455c 100644 --- a/publicroomsapi/storage/interface.go +++ b/publicroomsapi/storage/interface.go @@ -17,12 +17,12 @@ package storage import ( "context" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/gomatrixserverlib" ) type Database interface { - common.PartitionStorer + internal.PartitionStorer GetRoomVisibility(ctx context.Context, roomID string) (bool, error) SetRoomVisibility(ctx context.Context, visible bool, roomID string) error CountPublicRooms(ctx context.Context) (int64, error) diff --git a/publicroomsapi/storage/postgres/public_rooms_table.go b/publicroomsapi/storage/postgres/public_rooms_table.go index 7e31afd2a..39e355368 100644 --- a/publicroomsapi/storage/postgres/public_rooms_table.go +++ b/publicroomsapi/storage/postgres/public_rooms_table.go @@ -21,7 +21,7 @@ import ( "errors" "fmt" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/gomatrixserverlib" "github.com/lib/pq" @@ -205,7 +205,7 @@ func (s *publicRoomsStatements) selectPublicRooms( if err != nil { return []gomatrixserverlib.PublicRoom{}, nil } - defer common.CloseAndLogIfError(ctx, rows, "selectPublicRooms: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectPublicRooms: rows.close() failed") rooms := []gomatrixserverlib.PublicRoom{} for rows.Next() { diff --git a/publicroomsapi/storage/postgres/storage.go b/publicroomsapi/storage/postgres/storage.go index 8c4660cca..36c6aec64 100644 --- a/publicroomsapi/storage/postgres/storage.go +++ b/publicroomsapi/storage/postgres/storage.go @@ -20,7 +20,7 @@ import ( "database/sql" "encoding/json" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" @@ -29,21 +29,23 @@ import ( // PublicRoomsServerDatabase represents a public rooms server database. type PublicRoomsServerDatabase struct { db *sql.DB - common.PartitionOffsetStatements - statements publicRoomsStatements + sqlutil.PartitionOffsetStatements + statements publicRoomsStatements + localServerName gomatrixserverlib.ServerName } type attributeValue interface{} // NewPublicRoomsServerDatabase creates a new public rooms server database. -func NewPublicRoomsServerDatabase(dataSourceName string) (*PublicRoomsServerDatabase, error) { +func NewPublicRoomsServerDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, localServerName gomatrixserverlib.ServerName) (*PublicRoomsServerDatabase, error) { var db *sql.DB var err error - if db, err = sqlutil.Open("postgres", dataSourceName); err != nil { + if db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { return nil, err } storage := PublicRoomsServerDatabase{ - db: db, + db: db, + localServerName: localServerName, } if err = storage.PartitionOffsetStatements.Prepare(db, "publicroomsapi"); err != nil { return nil, err @@ -136,33 +138,33 @@ func (d *PublicRoomsServerDatabase) UpdateRoomFromEvent( case "m.room.aliases": return d.updateRoomAliases(ctx, event) case "m.room.canonical_alias": - var content common.CanonicalAliasContent + var content eventutil.CanonicalAliasContent field := &(content.Alias) attrName := "canonical_alias" return d.updateStringAttribute(ctx, attrName, event, &content, field) case "m.room.name": - var content common.NameContent + var content eventutil.NameContent field := &(content.Name) attrName := "name" return d.updateStringAttribute(ctx, attrName, event, &content, field) case "m.room.topic": - var content common.TopicContent + var content eventutil.TopicContent field := &(content.Topic) attrName := "topic" return d.updateStringAttribute(ctx, attrName, event, &content, field) case "m.room.avatar": - var content common.AvatarContent + var content eventutil.AvatarContent field := &(content.URL) attrName := "avatar_url" return d.updateStringAttribute(ctx, attrName, event, &content, field) case "m.room.history_visibility": - var content common.HistoryVisibilityContent + var content eventutil.HistoryVisibilityContent field := &(content.HistoryVisibility) attrName := "world_readable" strForTrue := "world_readable" return d.updateBooleanAttribute(ctx, attrName, event, &content, field, strForTrue) case "m.room.guest_access": - var content common.GuestAccessContent + var content eventutil.GuestAccessContent field := &(content.GuestAccess) attrName := "guest_can_join" strForTrue := "can_join" @@ -243,7 +245,10 @@ func (d *PublicRoomsServerDatabase) updateBooleanAttribute( func (d *PublicRoomsServerDatabase) updateRoomAliases( ctx context.Context, aliasesEvent gomatrixserverlib.Event, ) error { - var content common.AliasesContent + if aliasesEvent.StateKey() == nil || *aliasesEvent.StateKey() != string(d.localServerName) { + return nil // only store our own aliases + } + var content eventutil.AliasesContent if err := json.Unmarshal(aliasesEvent.Content(), &content); err != nil { return err } diff --git a/publicroomsapi/storage/sqlite3/public_rooms_table.go b/publicroomsapi/storage/sqlite3/public_rooms_table.go index 44679837f..7b332e175 100644 --- a/publicroomsapi/storage/sqlite3/public_rooms_table.go +++ b/publicroomsapi/storage/sqlite3/public_rooms_table.go @@ -22,7 +22,7 @@ import ( "errors" "fmt" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/gomatrixserverlib" ) @@ -193,7 +193,7 @@ func (s *publicRoomsStatements) selectPublicRooms( if err != nil { return []gomatrixserverlib.PublicRoom{}, nil } - defer common.CloseAndLogIfError(ctx, rows, "selectPublicRooms failed to close rows") + defer internal.CloseAndLogIfError(ctx, rows, "selectPublicRooms failed to close rows") rooms := []gomatrixserverlib.PublicRoom{} for rows.Next() { diff --git a/publicroomsapi/storage/sqlite3/storage.go b/publicroomsapi/storage/sqlite3/storage.go index 121601628..5c685d131 100644 --- a/publicroomsapi/storage/sqlite3/storage.go +++ b/publicroomsapi/storage/sqlite3/storage.go @@ -22,7 +22,7 @@ import ( _ "github.com/mattn/go-sqlite3" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" @@ -31,21 +31,27 @@ import ( // PublicRoomsServerDatabase represents a public rooms server database. type PublicRoomsServerDatabase struct { db *sql.DB - common.PartitionOffsetStatements - statements publicRoomsStatements + sqlutil.PartitionOffsetStatements + statements publicRoomsStatements + localServerName gomatrixserverlib.ServerName } type attributeValue interface{} // NewPublicRoomsServerDatabase creates a new public rooms server database. -func NewPublicRoomsServerDatabase(dataSourceName string) (*PublicRoomsServerDatabase, error) { +func NewPublicRoomsServerDatabase(dataSourceName string, localServerName gomatrixserverlib.ServerName) (*PublicRoomsServerDatabase, error) { var db *sql.DB var err error - if db, err = sqlutil.Open(common.SQLiteDriverName(), dataSourceName); err != nil { + cs, err := sqlutil.ParseFileURI(dataSourceName) + if err != nil { + return nil, err + } + if db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil); err != nil { return nil, err } storage := PublicRoomsServerDatabase{ - db: db, + db: db, + localServerName: localServerName, } if err = storage.PartitionOffsetStatements.Prepare(db, "publicroomsapi"); err != nil { return nil, err @@ -138,33 +144,33 @@ func (d *PublicRoomsServerDatabase) UpdateRoomFromEvent( case "m.room.aliases": return d.updateRoomAliases(ctx, event) case "m.room.canonical_alias": - var content common.CanonicalAliasContent + var content eventutil.CanonicalAliasContent field := &(content.Alias) attrName := "canonical_alias" return d.updateStringAttribute(ctx, attrName, event, &content, field) case "m.room.name": - var content common.NameContent + var content eventutil.NameContent field := &(content.Name) attrName := "name" return d.updateStringAttribute(ctx, attrName, event, &content, field) case "m.room.topic": - var content common.TopicContent + var content eventutil.TopicContent field := &(content.Topic) attrName := "topic" return d.updateStringAttribute(ctx, attrName, event, &content, field) case "m.room.avatar": - var content common.AvatarContent + var content eventutil.AvatarContent field := &(content.URL) attrName := "avatar_url" return d.updateStringAttribute(ctx, attrName, event, &content, field) case "m.room.history_visibility": - var content common.HistoryVisibilityContent + var content eventutil.HistoryVisibilityContent field := &(content.HistoryVisibility) attrName := "world_readable" strForTrue := "world_readable" return d.updateBooleanAttribute(ctx, attrName, event, &content, field, strForTrue) case "m.room.guest_access": - var content common.GuestAccessContent + var content eventutil.GuestAccessContent field := &(content.GuestAccess) attrName := "guest_can_join" strForTrue := "can_join" @@ -245,7 +251,10 @@ func (d *PublicRoomsServerDatabase) updateBooleanAttribute( func (d *PublicRoomsServerDatabase) updateRoomAliases( ctx context.Context, aliasesEvent gomatrixserverlib.Event, ) error { - var content common.AliasesContent + if aliasesEvent.StateKey() == nil || *aliasesEvent.StateKey() != string(d.localServerName) { + return nil // only store our own aliases + } + var content eventutil.AliasesContent if err := json.Unmarshal(aliasesEvent.Content(), &content); err != nil { return err } diff --git a/publicroomsapi/storage/storage.go b/publicroomsapi/storage/storage.go index e674514aa..f66188040 100644 --- a/publicroomsapi/storage/storage.go +++ b/publicroomsapi/storage/storage.go @@ -19,25 +19,27 @@ package storage import ( "net/url" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/publicroomsapi/storage/postgres" "github.com/matrix-org/dendrite/publicroomsapi/storage/sqlite3" + "github.com/matrix-org/gomatrixserverlib" ) const schemePostgres = "postgres" const schemeFile = "file" // NewPublicRoomsServerDatabase opens a database connection. -func NewPublicRoomsServerDatabase(dataSourceName string) (Database, error) { +func NewPublicRoomsServerDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, localServerName gomatrixserverlib.ServerName) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { - return postgres.NewPublicRoomsServerDatabase(dataSourceName) + return postgres.NewPublicRoomsServerDatabase(dataSourceName, dbProperties, localServerName) } switch uri.Scheme { case schemePostgres: - return postgres.NewPublicRoomsServerDatabase(dataSourceName) + return postgres.NewPublicRoomsServerDatabase(dataSourceName, dbProperties, localServerName) case schemeFile: - return sqlite3.NewPublicRoomsServerDatabase(dataSourceName) + return sqlite3.NewPublicRoomsServerDatabase(dataSourceName, localServerName) default: - return postgres.NewPublicRoomsServerDatabase(dataSourceName) + return postgres.NewPublicRoomsServerDatabase(dataSourceName, dbProperties, localServerName) } } diff --git a/publicroomsapi/storage/storage_wasm.go b/publicroomsapi/storage/storage_wasm.go index d00c339d8..70ceeaf85 100644 --- a/publicroomsapi/storage/storage_wasm.go +++ b/publicroomsapi/storage/storage_wasm.go @@ -19,10 +19,11 @@ import ( "net/url" "github.com/matrix-org/dendrite/publicroomsapi/storage/sqlite3" + "github.com/matrix-org/gomatrixserverlib" ) // NewPublicRoomsServerDatabase opens a database connection. -func NewPublicRoomsServerDatabase(dataSourceName string) (Database, error) { +func NewPublicRoomsServerDatabase(dataSourceName string, localServerName gomatrixserverlib.ServerName) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { return nil, err @@ -31,7 +32,7 @@ func NewPublicRoomsServerDatabase(dataSourceName string) (Database, error) { case "postgres": return nil, fmt.Errorf("Cannot use postgres implementation") case "file": - return sqlite3.NewPublicRoomsServerDatabase(dataSourceName) + return sqlite3.NewPublicRoomsServerDatabase(dataSourceName, localServerName) default: return nil, fmt.Errorf("Cannot use postgres implementation") } diff --git a/roomserver/alias/alias_test.go b/roomserver/alias/alias_test.go deleted file mode 100644 index 0aefa19d9..000000000 --- a/roomserver/alias/alias_test.go +++ /dev/null @@ -1,209 +0,0 @@ -// Copyright 2019 Serra Allgood -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package alias - -import ( - "context" - "fmt" - "strings" - "testing" - - appserviceAPI "github.com/matrix-org/dendrite/appservice/api" - roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/gomatrixserverlib" -) - -type MockRoomserverAliasAPIDatabase struct { - mode string - attempts int -} - -// These methods can be essentially noop -func (db MockRoomserverAliasAPIDatabase) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { - return nil -} - -func (db MockRoomserverAliasAPIDatabase) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { - aliases := make([]string, 0) - return aliases, nil -} - -func (db MockRoomserverAliasAPIDatabase) RemoveRoomAlias(ctx context.Context, alias string) error { - return nil -} - -func (db *MockRoomserverAliasAPIDatabase) GetCreatorIDForAlias( - ctx context.Context, alias string, -) (string, error) { - return "", nil -} - -func (db *MockRoomserverAliasAPIDatabase) GetRoomVersionForRoom( - ctx context.Context, roomID string, -) (gomatrixserverlib.RoomVersion, error) { - return gomatrixserverlib.RoomVersionV1, nil -} - -// This method needs to change depending on test case -func (db *MockRoomserverAliasAPIDatabase) GetRoomIDForAlias( - ctx context.Context, - alias string, -) (string, error) { - switch db.mode { - case "empty": - return "", nil - case "error": - return "", fmt.Errorf("found an error from GetRoomIDForAlias") - case "found": - return "123", nil - case "emptyFound": - switch db.attempts { - case 0: - db.attempts = 1 - return "", nil - case 1: - db.attempts = 0 - return "123", nil - default: - return "", nil - } - default: - return "", fmt.Errorf("unknown option used") - } -} - -type MockAppServiceQueryAPI struct { - mode string -} - -// This method can be noop -func (q MockAppServiceQueryAPI) UserIDExists( - ctx context.Context, - req *appserviceAPI.UserIDExistsRequest, - resp *appserviceAPI.UserIDExistsResponse, -) error { - return nil -} - -func (q MockAppServiceQueryAPI) RoomAliasExists( - ctx context.Context, - req *appserviceAPI.RoomAliasExistsRequest, - resp *appserviceAPI.RoomAliasExistsResponse, -) error { - switch q.mode { - case "error": - return fmt.Errorf("found an error from RoomAliasExists") - case "found": - resp.AliasExists = true - return nil - case "empty": - resp.AliasExists = false - return nil - default: - return fmt.Errorf("Unknown option used") - } -} - -func TestGetRoomIDForAlias(t *testing.T) { - type arguments struct { - ctx context.Context - request *roomserverAPI.GetRoomIDForAliasRequest - response *roomserverAPI.GetRoomIDForAliasResponse - } - args := arguments{ - context.Background(), - &roomserverAPI.GetRoomIDForAliasRequest{}, - &roomserverAPI.GetRoomIDForAliasResponse{}, - } - type testCase struct { - name string - dbMode string - queryMode string - wantError bool - errorMsg string - want string - } - tt := []testCase{ - { - "found local alias", - "found", - "error", - false, - "", - "123", - }, - { - "found appservice alias", - "emptyFound", - "found", - false, - "", - "123", - }, - { - "error returned from DB", - "error", - "", - true, - "GetRoomIDForAlias", - "", - }, - { - "error returned from appserviceAPI", - "empty", - "error", - true, - "RoomAliasExists", - "", - }, - { - "no errors but no alias", - "empty", - "empty", - false, - "", - "", - }, - } - - setup := func(dbMode, queryMode string) *RoomserverAliasAPI { - mockAliasAPIDB := &MockRoomserverAliasAPIDatabase{dbMode, 0} - mockAppServiceQueryAPI := MockAppServiceQueryAPI{queryMode} - - return &RoomserverAliasAPI{ - DB: mockAliasAPIDB, - AppserviceAPI: mockAppServiceQueryAPI, - } - } - - for _, tc := range tt { - t.Run(tc.name, func(t *testing.T) { - aliasAPI := setup(tc.dbMode, tc.queryMode) - - err := aliasAPI.GetRoomIDForAlias(args.ctx, args.request, args.response) - if tc.wantError { - if err == nil { - t.Fatalf("Got no error; wanted error from %s", tc.errorMsg) - } else if !strings.Contains(err.Error(), tc.errorMsg) { - t.Fatalf("Got %s; wanted error from %s", err, tc.errorMsg) - } - } else if err != nil { - t.Fatalf("Got %s; wanted no error", err) - } else if args.response.RoomID != tc.want { - t.Errorf("Got '%s'; wanted '%s'", args.response.RoomID, tc.want) - } - }) - } -} diff --git a/roomserver/api/alias.go b/roomserver/api/alias.go index ad375a830..61fdc6116 100644 --- a/roomserver/api/alias.go +++ b/roomserver/api/alias.go @@ -14,15 +14,6 @@ package api -import ( - "context" - "errors" - "net/http" - - commonHTTP "github.com/matrix-org/dendrite/common/http" - opentracing "github.com/opentracing/opentracing-go" -) - // SetRoomAliasRequest is a request to SetRoomAlias type SetRoomAliasRequest struct { // ID of the user setting the alias @@ -85,135 +76,3 @@ type RemoveRoomAliasRequest struct { // RemoveRoomAliasResponse is a response to RemoveRoomAlias type RemoveRoomAliasResponse struct{} - -// RoomserverAliasAPI is used to save, lookup or remove a room alias -type RoomserverAliasAPI interface { - // Set a room alias - SetRoomAlias( - ctx context.Context, - req *SetRoomAliasRequest, - response *SetRoomAliasResponse, - ) error - - // Get the room ID for an alias - GetRoomIDForAlias( - ctx context.Context, - req *GetRoomIDForAliasRequest, - response *GetRoomIDForAliasResponse, - ) error - - // Get all known aliases for a room ID - GetAliasesForRoomID( - ctx context.Context, - req *GetAliasesForRoomIDRequest, - response *GetAliasesForRoomIDResponse, - ) error - - // Get the user ID of the creator of an alias - GetCreatorIDForAlias( - ctx context.Context, - req *GetCreatorIDForAliasRequest, - response *GetCreatorIDForAliasResponse, - ) error - - // Remove a room alias - RemoveRoomAlias( - ctx context.Context, - req *RemoveRoomAliasRequest, - response *RemoveRoomAliasResponse, - ) error -} - -// RoomserverSetRoomAliasPath is the HTTP path for the SetRoomAlias API. -const RoomserverSetRoomAliasPath = "/api/roomserver/setRoomAlias" - -// RoomserverGetRoomIDForAliasPath is the HTTP path for the GetRoomIDForAlias API. -const RoomserverGetRoomIDForAliasPath = "/api/roomserver/GetRoomIDForAlias" - -// RoomserverGetAliasesForRoomIDPath is the HTTP path for the GetAliasesForRoomID API. -const RoomserverGetAliasesForRoomIDPath = "/api/roomserver/GetAliasesForRoomID" - -// RoomserverGetCreatorIDForAliasPath is the HTTP path for the GetCreatorIDForAlias API. -const RoomserverGetCreatorIDForAliasPath = "/api/roomserver/GetCreatorIDForAlias" - -// RoomserverRemoveRoomAliasPath is the HTTP path for the RemoveRoomAlias API. -const RoomserverRemoveRoomAliasPath = "/api/roomserver/removeRoomAlias" - -// NewRoomserverAliasAPIHTTP creates a RoomserverAliasAPI implemented by talking to a HTTP POST API. -// If httpClient is nil an error is returned -func NewRoomserverAliasAPIHTTP(roomserverURL string, httpClient *http.Client) (RoomserverAliasAPI, error) { - if httpClient == nil { - return nil, errors.New("NewRoomserverAliasAPIHTTP: httpClient is ") - } - return &httpRoomserverAliasAPI{roomserverURL, httpClient}, nil -} - -type httpRoomserverAliasAPI struct { - roomserverURL string - httpClient *http.Client -} - -// SetRoomAlias implements RoomserverAliasAPI -func (h *httpRoomserverAliasAPI) SetRoomAlias( - ctx context.Context, - request *SetRoomAliasRequest, - response *SetRoomAliasResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "SetRoomAlias") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverSetRoomAliasPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -// GetRoomIDForAlias implements RoomserverAliasAPI -func (h *httpRoomserverAliasAPI) GetRoomIDForAlias( - ctx context.Context, - request *GetRoomIDForAliasRequest, - response *GetRoomIDForAliasResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "GetRoomIDForAlias") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverGetRoomIDForAliasPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -// GetAliasesForRoomID implements RoomserverAliasAPI -func (h *httpRoomserverAliasAPI) GetAliasesForRoomID( - ctx context.Context, - request *GetAliasesForRoomIDRequest, - response *GetAliasesForRoomIDResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "GetAliasesForRoomID") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverGetAliasesForRoomIDPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -// GetCreatorIDForAlias implements RoomserverAliasAPI -func (h *httpRoomserverAliasAPI) GetCreatorIDForAlias( - ctx context.Context, - request *GetCreatorIDForAliasRequest, - response *GetCreatorIDForAliasResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "GetCreatorIDForAlias") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverGetCreatorIDForAliasPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -// RemoveRoomAlias implements RoomserverAliasAPI -func (h *httpRoomserverAliasAPI) RemoveRoomAlias( - ctx context.Context, - request *RemoveRoomAliasRequest, - response *RemoveRoomAliasResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "RemoveRoomAlias") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverRemoveRoomAliasPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} diff --git a/roomserver/api/api.go b/roomserver/api/api.go new file mode 100644 index 000000000..26ec8ca1d --- /dev/null +++ b/roomserver/api/api.go @@ -0,0 +1,152 @@ +package api + +import ( + "context" + + fsAPI "github.com/matrix-org/dendrite/federationsender/api" +) + +// RoomserverInputAPI is used to write events to the room server. +type RoomserverInternalAPI interface { + // needed to avoid chicken and egg scenario when setting up the + // interdependencies between the roomserver and other input APIs + SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) + + InputRoomEvents( + ctx context.Context, + request *InputRoomEventsRequest, + response *InputRoomEventsResponse, + ) error + + PerformInvite( + ctx context.Context, + req *PerformInviteRequest, + res *PerformInviteResponse, + ) + + PerformJoin( + ctx context.Context, + req *PerformJoinRequest, + res *PerformJoinResponse, + ) + + PerformLeave( + ctx context.Context, + req *PerformLeaveRequest, + res *PerformLeaveResponse, + ) error + + // Query the latest events and state for a room from the room server. + QueryLatestEventsAndState( + ctx context.Context, + request *QueryLatestEventsAndStateRequest, + response *QueryLatestEventsAndStateResponse, + ) error + + // Query the state after a list of events in a room from the room server. + QueryStateAfterEvents( + ctx context.Context, + request *QueryStateAfterEventsRequest, + response *QueryStateAfterEventsResponse, + ) error + + // Query a list of events by event ID. + QueryEventsByID( + ctx context.Context, + request *QueryEventsByIDRequest, + response *QueryEventsByIDResponse, + ) error + + // Query the membership event for an user for a room. + QueryMembershipForUser( + ctx context.Context, + request *QueryMembershipForUserRequest, + response *QueryMembershipForUserResponse, + ) error + + // Query a list of membership events for a room + QueryMembershipsForRoom( + ctx context.Context, + request *QueryMembershipsForRoomRequest, + response *QueryMembershipsForRoomResponse, + ) error + + // Query whether a server is allowed to see an event + QueryServerAllowedToSeeEvent( + ctx context.Context, + request *QueryServerAllowedToSeeEventRequest, + response *QueryServerAllowedToSeeEventResponse, + ) error + + // Query missing events for a room from roomserver + QueryMissingEvents( + ctx context.Context, + request *QueryMissingEventsRequest, + response *QueryMissingEventsResponse, + ) error + + // Query to get state and auth chain for a (potentially hypothetical) event. + // Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate + // the state and auth chain to return. + QueryStateAndAuthChain( + ctx context.Context, + request *QueryStateAndAuthChainRequest, + response *QueryStateAndAuthChainResponse, + ) error + + // Query a given amount (or less) of events prior to a given set of events. + PerformBackfill( + ctx context.Context, + request *PerformBackfillRequest, + response *PerformBackfillResponse, + ) error + + // Asks for the default room version as preferred by the server. + QueryRoomVersionCapabilities( + ctx context.Context, + request *QueryRoomVersionCapabilitiesRequest, + response *QueryRoomVersionCapabilitiesResponse, + ) error + + // Asks for the room version for a given room. + QueryRoomVersionForRoom( + ctx context.Context, + request *QueryRoomVersionForRoomRequest, + response *QueryRoomVersionForRoomResponse, + ) error + + // Set a room alias + SetRoomAlias( + ctx context.Context, + req *SetRoomAliasRequest, + response *SetRoomAliasResponse, + ) error + + // Get the room ID for an alias + GetRoomIDForAlias( + ctx context.Context, + req *GetRoomIDForAliasRequest, + response *GetRoomIDForAliasResponse, + ) error + + // Get all known aliases for a room ID + GetAliasesForRoomID( + ctx context.Context, + req *GetAliasesForRoomIDRequest, + response *GetAliasesForRoomIDResponse, + ) error + + // Get the user ID of the creator of an alias + GetCreatorIDForAlias( + ctx context.Context, + req *GetCreatorIDForAliasRequest, + response *GetCreatorIDForAliasResponse, + ) error + + // Remove a room alias + RemoveRoomAlias( + ctx context.Context, + req *RemoveRoomAliasRequest, + response *RemoveRoomAliasResponse, + ) error +} diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go new file mode 100644 index 000000000..8645b6f28 --- /dev/null +++ b/roomserver/api/api_trace.go @@ -0,0 +1,226 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + + fsAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/util" +) + +// RoomserverInternalAPITrace wraps a RoomserverInternalAPI and logs the +// complete request/response/error +type RoomserverInternalAPITrace struct { + Impl RoomserverInternalAPI +} + +func (t *RoomserverInternalAPITrace) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) { + t.Impl.SetFederationSenderAPI(fsAPI) +} + +func (t *RoomserverInternalAPITrace) InputRoomEvents( + ctx context.Context, + req *InputRoomEventsRequest, + res *InputRoomEventsResponse, +) error { + err := t.Impl.InputRoomEvents(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("InputRoomEvents req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) PerformInvite( + ctx context.Context, + req *PerformInviteRequest, + res *PerformInviteResponse, +) { + t.Impl.PerformInvite(ctx, req, res) + util.GetLogger(ctx).Infof("PerformInvite req=%+v res=%+v", js(req), js(res)) +} + +func (t *RoomserverInternalAPITrace) PerformJoin( + ctx context.Context, + req *PerformJoinRequest, + res *PerformJoinResponse, +) { + t.Impl.PerformJoin(ctx, req, res) + util.GetLogger(ctx).Infof("PerformJoin req=%+v res=%+v", js(req), js(res)) +} + +func (t *RoomserverInternalAPITrace) PerformLeave( + ctx context.Context, + req *PerformLeaveRequest, + res *PerformLeaveResponse, +) error { + err := t.Impl.PerformLeave(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformLeave req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) QueryLatestEventsAndState( + ctx context.Context, + req *QueryLatestEventsAndStateRequest, + res *QueryLatestEventsAndStateResponse, +) error { + err := t.Impl.QueryLatestEventsAndState(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryLatestEventsAndState req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) QueryStateAfterEvents( + ctx context.Context, + req *QueryStateAfterEventsRequest, + res *QueryStateAfterEventsResponse, +) error { + err := t.Impl.QueryStateAfterEvents(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryStateAfterEvents req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) QueryEventsByID( + ctx context.Context, + req *QueryEventsByIDRequest, + res *QueryEventsByIDResponse, +) error { + err := t.Impl.QueryEventsByID(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryEventsByID req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) QueryMembershipForUser( + ctx context.Context, + req *QueryMembershipForUserRequest, + res *QueryMembershipForUserResponse, +) error { + err := t.Impl.QueryMembershipForUser(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryMembershipForUser req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) QueryMembershipsForRoom( + ctx context.Context, + req *QueryMembershipsForRoomRequest, + res *QueryMembershipsForRoomResponse, +) error { + err := t.Impl.QueryMembershipsForRoom(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryMembershipsForRoom req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) QueryServerAllowedToSeeEvent( + ctx context.Context, + req *QueryServerAllowedToSeeEventRequest, + res *QueryServerAllowedToSeeEventResponse, +) error { + err := t.Impl.QueryServerAllowedToSeeEvent(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryServerAllowedToSeeEvent req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) QueryMissingEvents( + ctx context.Context, + req *QueryMissingEventsRequest, + res *QueryMissingEventsResponse, +) error { + err := t.Impl.QueryMissingEvents(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryMissingEvents req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) QueryStateAndAuthChain( + ctx context.Context, + req *QueryStateAndAuthChainRequest, + res *QueryStateAndAuthChainResponse, +) error { + err := t.Impl.QueryStateAndAuthChain(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryStateAndAuthChain req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) PerformBackfill( + ctx context.Context, + req *PerformBackfillRequest, + res *PerformBackfillResponse, +) error { + err := t.Impl.PerformBackfill(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformBackfill req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) QueryRoomVersionCapabilities( + ctx context.Context, + req *QueryRoomVersionCapabilitiesRequest, + res *QueryRoomVersionCapabilitiesResponse, +) error { + err := t.Impl.QueryRoomVersionCapabilities(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryRoomVersionCapabilities req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) QueryRoomVersionForRoom( + ctx context.Context, + req *QueryRoomVersionForRoomRequest, + res *QueryRoomVersionForRoomResponse, +) error { + err := t.Impl.QueryRoomVersionForRoom(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryRoomVersionForRoom req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) SetRoomAlias( + ctx context.Context, + req *SetRoomAliasRequest, + res *SetRoomAliasResponse, +) error { + err := t.Impl.SetRoomAlias(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("SetRoomAlias req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) GetRoomIDForAlias( + ctx context.Context, + req *GetRoomIDForAliasRequest, + res *GetRoomIDForAliasResponse, +) error { + err := t.Impl.GetRoomIDForAlias(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("GetRoomIDForAlias req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) GetAliasesForRoomID( + ctx context.Context, + req *GetAliasesForRoomIDRequest, + res *GetAliasesForRoomIDResponse, +) error { + err := t.Impl.GetAliasesForRoomID(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("GetAliasesForRoomID req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) GetCreatorIDForAlias( + ctx context.Context, + req *GetCreatorIDForAliasRequest, + res *GetCreatorIDForAliasResponse, +) error { + err := t.Impl.GetCreatorIDForAlias(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("GetCreatorIDForAlias req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) RemoveRoomAlias( + ctx context.Context, + req *RemoveRoomAliasRequest, + res *RemoveRoomAliasResponse, +) error { + err := t.Impl.RemoveRoomAlias(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("RemoveRoomAlias req=%+v res=%+v", js(req), js(res)) + return err +} + +func js(thing interface{}) string { + b, err := json.Marshal(thing) + if err != nil { + return fmt.Sprintf("Marshal error:%s", err) + } + return string(b) +} diff --git a/roomserver/api/input.go b/roomserver/api/input.go index 87e3983e3..05c981df4 100644 --- a/roomserver/api/input.go +++ b/roomserver/api/input.go @@ -16,13 +16,7 @@ package api import ( - "context" - "errors" - "net/http" - - commonHTTP "github.com/matrix-org/dendrite/common/http" "github.com/matrix-org/gomatrixserverlib" - opentracing "github.com/opentracing/opentracing-go" ) const ( @@ -82,61 +76,12 @@ type TransactionID struct { TransactionID string `json:"id"` } -// InputInviteEvent is a matrix invite event received over federation without -// the usual context a matrix room event would have. We usually do not have -// access to the events needed to check the event auth rules for the invite. -type InputInviteEvent struct { - RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` - Event gomatrixserverlib.HeaderedEvent `json:"event"` - InviteRoomState []gomatrixserverlib.InviteV2StrippedState `json:"invite_room_state"` -} - // InputRoomEventsRequest is a request to InputRoomEvents type InputRoomEventsRequest struct { - InputRoomEvents []InputRoomEvent `json:"input_room_events"` - InputInviteEvents []InputInviteEvent `json:"input_invite_events"` + InputRoomEvents []InputRoomEvent `json:"input_room_events"` } // InputRoomEventsResponse is a response to InputRoomEvents type InputRoomEventsResponse struct { EventID string `json:"event_id"` } - -// RoomserverInputAPI is used to write events to the room server. -type RoomserverInputAPI interface { - InputRoomEvents( - ctx context.Context, - request *InputRoomEventsRequest, - response *InputRoomEventsResponse, - ) error -} - -// RoomserverInputRoomEventsPath is the HTTP path for the InputRoomEvents API. -const RoomserverInputRoomEventsPath = "/api/roomserver/inputRoomEvents" - -// NewRoomserverInputAPIHTTP creates a RoomserverInputAPI implemented by talking to a HTTP POST API. -// If httpClient is nil an error is returned -func NewRoomserverInputAPIHTTP(roomserverURL string, httpClient *http.Client) (RoomserverInputAPI, error) { - if httpClient == nil { - return nil, errors.New("NewRoomserverInputAPIHTTP: httpClient is ") - } - return &httpRoomserverInputAPI{roomserverURL, httpClient}, nil -} - -type httpRoomserverInputAPI struct { - roomserverURL string - httpClient *http.Client -} - -// InputRoomEvents implements RoomserverInputAPI -func (h *httpRoomserverInputAPI) InputRoomEvents( - ctx context.Context, - request *InputRoomEventsRequest, - response *InputRoomEventsResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "InputRoomEvents") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverInputRoomEventsPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} diff --git a/roomserver/api/output.go b/roomserver/api/output.go index 92a468a96..2bbd97af8 100644 --- a/roomserver/api/output.go +++ b/roomserver/api/output.go @@ -63,6 +63,13 @@ type OutputNewRoomEvent struct { // Together with RemovesStateEventIDs this allows the receiver to keep an up to date // view of the current state of the room. AddsStateEventIDs []string `json:"adds_state_event_ids"` + // All extra newly added state events. This is only set if there are *extra* events + // other than `Event`. This can happen when forks get merged because state resolution + // may decide a bunch of state events on one branch are now valid, so they will be + // present in this list. This is useful when trying to maintain the current state of a room + // as to do so you need to include both these events and `Event`. + AddStateEvents []gomatrixserverlib.HeaderedEvent `json:"adds_state_events"` + // The state event IDs that were removed from the state of the room by this event. RemovesStateEventIDs []string `json:"removes_state_event_ids"` // The ID of the event that was output before this event. @@ -112,6 +119,26 @@ type OutputNewRoomEvent struct { TransactionID *TransactionID `json:"transaction_id"` } +// AddsState returns all added state events from this event. +// +// This function is needed because `AddStateEvents` will not include a copy of +// the original event to save space, so you cannot use that slice alone. +// Instead, use this function which will add the original event if it is present +// in `AddsStateEventIDs`. +func (ore *OutputNewRoomEvent) AddsState() []gomatrixserverlib.HeaderedEvent { + includeOutputEvent := false + for _, id := range ore.AddsStateEventIDs { + if id == ore.Event.EventID() { + includeOutputEvent = true + break + } + } + if !includeOutputEvent { + return ore.AddStateEvents + } + return append(ore.AddStateEvents, ore.Event) +} + // An OutputNewInviteEvent is written whenever an invite becomes active. // Invite events can be received outside of an existing room so have to be // tracked separately from the room events themselves. diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go new file mode 100644 index 000000000..0b8e6df25 --- /dev/null +++ b/roomserver/api/perform.go @@ -0,0 +1,118 @@ +package api + +import ( + "fmt" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +type PerformErrorCode int + +type PerformError struct { + Msg string + Code PerformErrorCode +} + +func (p *PerformError) Error() string { + return fmt.Sprintf("%d : %s", p.Code, p.Msg) +} + +// JSONResponse maps error codes to suitable HTTP error codes, defaulting to 500. +func (p *PerformError) JSONResponse() util.JSONResponse { + switch p.Code { + case PerformErrorBadRequest: + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.Unknown(p.Msg), + } + case PerformErrorNoRoom: + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound(p.Msg), + } + case PerformErrorNotAllowed: + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden(p.Msg), + } + default: + return util.ErrorResponse(p) + } +} + +const ( + // PerformErrorNotAllowed means the user is not allowed to invite/join/etc this room (e.g join_rule:invite or banned) + PerformErrorNotAllowed PerformErrorCode = 1 + // PerformErrorBadRequest means the request was wrong in some way (invalid user ID, wrong server, etc) + PerformErrorBadRequest PerformErrorCode = 2 + // PerformErrorNoRoom means that the room being joined doesn't exist. + PerformErrorNoRoom PerformErrorCode = 3 + // PerformErrorNoOperation means that the request resulted in nothing happening e.g invite->invite or leave->leave. + PerformErrorNoOperation PerformErrorCode = 4 +) + +type PerformJoinRequest struct { + RoomIDOrAlias string `json:"room_id_or_alias"` + UserID string `json:"user_id"` + Content map[string]interface{} `json:"content"` + ServerNames []gomatrixserverlib.ServerName `json:"server_names"` +} + +type PerformJoinResponse struct { + // The room ID, populated on success. + RoomID string `json:"room_id"` + // If non-nil, the join request failed. Contains more information why it failed. + Error *PerformError +} + +type PerformLeaveRequest struct { + RoomID string `json:"room_id"` + UserID string `json:"user_id"` +} + +type PerformLeaveResponse struct { +} + +type PerformInviteRequest struct { + RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` + Event gomatrixserverlib.HeaderedEvent `json:"event"` + InviteRoomState []gomatrixserverlib.InviteV2StrippedState `json:"invite_room_state"` + SendAsServer string `json:"send_as_server"` + TransactionID *TransactionID `json:"transaction_id"` +} + +type PerformInviteResponse struct { + // If non-nil, the invite request failed. Contains more information why it failed. + Error *PerformError +} + +// PerformBackfillRequest is a request to PerformBackfill. +type PerformBackfillRequest struct { + // The room to backfill + RoomID string `json:"room_id"` + // A map of backwards extremity event ID to a list of its prev_event IDs. + BackwardsExtremities map[string][]string `json:"backwards_extremities"` + // The maximum number of events to retrieve. + Limit int `json:"limit"` + // The server interested in the events. + ServerName gomatrixserverlib.ServerName `json:"server_name"` +} + +// PrevEventIDs returns the prev_event IDs of all backwards extremities, de-duplicated in a lexicographically sorted order. +func (r *PerformBackfillRequest) PrevEventIDs() []string { + var prevEventIDs []string + for _, pes := range r.BackwardsExtremities { + prevEventIDs = append(prevEventIDs, pes...) + } + prevEventIDs = util.UniqueStrings(prevEventIDs) + return prevEventIDs +} + +// PerformBackfillResponse is a response to PerformBackfill. +type PerformBackfillResponse struct { + // Missing events, arbritrary order. + Events []gomatrixserverlib.HeaderedEvent `json:"events"` +} diff --git a/roomserver/api/query.go b/roomserver/api/query.go index b272b1ebd..6586b1af3 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -17,14 +17,7 @@ package api import ( - "context" - "errors" - "net/http" - - "github.com/matrix-org/dendrite/common/caching" - commonHTTP "github.com/matrix-org/dendrite/common/http" "github.com/matrix-org/gomatrixserverlib" - opentracing "github.com/opentracing/opentracing-go" ) // QueryLatestEventsAndStateRequest is a request to QueryLatestEventsAndState @@ -32,7 +25,7 @@ type QueryLatestEventsAndStateRequest struct { // The room ID to query the latest events for. RoomID string `json:"room_id"` // The state key tuples to fetch from the room current state. - // If this list is empty or nil then no state events are returned. + // If this list is empty or nil then *ALL* current state events are returned. StateToFetch []gomatrixserverlib.StateKeyTuple `json:"state_to_fetch"` } @@ -40,8 +33,6 @@ type QueryLatestEventsAndStateRequest struct { // This is used when sending events to set the prev_events, auth_events and depth. // It is also used to tell whether the event is allowed by the event auth rules. type QueryLatestEventsAndStateResponse struct { - // Copy of the request for debugging. - QueryLatestEventsAndStateRequest // Does the room exist? // If the room doesn't exist this will be false and LatestEvents will be empty. RoomExists bool `json:"room_exists"` @@ -73,8 +64,6 @@ type QueryStateAfterEventsRequest struct { // QueryStateAfterEventsResponse is a response to QueryStateAfterEvents type QueryStateAfterEventsResponse struct { - // Copy of the request for debugging. - QueryStateAfterEventsRequest // Does the room exist on this roomserver? // If the room doesn't exist this will be false and StateEvents will be empty. RoomExists bool `json:"room_exists"` @@ -96,8 +85,6 @@ type QueryEventsByIDRequest struct { // QueryEventsByIDResponse is a response to QueryEventsByID type QueryEventsByIDResponse struct { - // Copy of the request for debugging. - QueryEventsByIDRequest // A list of events with the requested IDs. // If the roomserver does not have a copy of a requested event // then it will omit that event from the list. @@ -146,23 +133,6 @@ type QueryMembershipsForRoomResponse struct { HasBeenInRoom bool `json:"has_been_in_room"` } -// QueryInvitesForUserRequest is a request to QueryInvitesForUser -type QueryInvitesForUserRequest struct { - // The room ID to look up invites in. - RoomID string `json:"room_id"` - // The User ID to look up invites for. - TargetUserID string `json:"target_user_id"` -} - -// QueryInvitesForUserResponse is a response to QueryInvitesForUser -// This is used when accepting an invite or rejecting a invite to tell which -// remote matrix servers to contact. -type QueryInvitesForUserResponse struct { - // A list of matrix user IDs for each sender of an active invite targeting - // the requested user ID. - InviteSenderUserIDs []string `json:"invite_sender_user_ids"` -} - // QueryServerAllowedToSeeEventRequest is a request to QueryServerAllowedToSeeEvent type QueryServerAllowedToSeeEventRequest struct { // The event ID to look up invites in. @@ -211,8 +181,6 @@ type QueryStateAndAuthChainRequest struct { // QueryStateAndAuthChainResponse is a response to QueryStateAndAuthChain type QueryStateAndAuthChainResponse struct { - // Copy of the request for debugging. - QueryStateAndAuthChainRequest // Does the room exist on this roomserver? // If the room doesn't exist this will be false and StateEvents will be empty. RoomExists bool `json:"room_exists"` @@ -227,37 +195,7 @@ type QueryStateAndAuthChainResponse struct { AuthChainEvents []gomatrixserverlib.HeaderedEvent `json:"auth_chain_events"` } -// QueryBackfillRequest is a request to QueryBackfill. -type QueryBackfillRequest struct { - // Events to start paginating from. - EarliestEventsIDs []string `json:"earliest_event_ids"` - // The maximum number of events to retrieve. - Limit int `json:"limit"` - // The server interested in the events. - ServerName gomatrixserverlib.ServerName `json:"server_name"` -} - -// QueryBackfillResponse is a response to QueryBackfill. -type QueryBackfillResponse struct { - // Missing events, arbritrary order. - Events []gomatrixserverlib.HeaderedEvent `json:"events"` -} - -// QueryServersInRoomAtEventRequest is a request to QueryServersInRoomAtEvent -type QueryServersInRoomAtEventRequest struct { - // ID of the room to retrieve member servers for. - RoomID string `json:"room_id"` - // ID of the event for which to retrieve member servers. - EventID string `json:"event_id"` -} - -// QueryServersInRoomAtEventResponse is a response to QueryServersInRoomAtEvent -type QueryServersInRoomAtEventResponse struct { - // Servers present in the room for these events. - Servers []gomatrixserverlib.ServerName `json:"servers"` -} - -// QueryRoomVersionCapabilities asks for the default room version +// QueryRoomVersionCapabilitiesRequest asks for the default room version type QueryRoomVersionCapabilitiesRequest struct{} // QueryRoomVersionCapabilitiesResponse is a response to QueryRoomVersionCapabilitiesRequest @@ -266,339 +204,12 @@ type QueryRoomVersionCapabilitiesResponse struct { AvailableRoomVersions map[gomatrixserverlib.RoomVersion]string `json:"available"` } -// QueryRoomVersionForRoom asks for the room version for a given room. +// QueryRoomVersionForRoomRequest asks for the room version for a given room. type QueryRoomVersionForRoomRequest struct { RoomID string `json:"room_id"` } -// QueryRoomVersionCapabilitiesResponse is a response to QueryServersInRoomAtEventResponse +// QueryRoomVersionForRoomResponse is a response to QueryRoomVersionForRoomRequest type QueryRoomVersionForRoomResponse struct { RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` } - -// RoomserverQueryAPI is used to query information from the room server. -type RoomserverQueryAPI interface { - // Query the latest events and state for a room from the room server. - QueryLatestEventsAndState( - ctx context.Context, - request *QueryLatestEventsAndStateRequest, - response *QueryLatestEventsAndStateResponse, - ) error - - // Query the state after a list of events in a room from the room server. - QueryStateAfterEvents( - ctx context.Context, - request *QueryStateAfterEventsRequest, - response *QueryStateAfterEventsResponse, - ) error - - // Query a list of events by event ID. - QueryEventsByID( - ctx context.Context, - request *QueryEventsByIDRequest, - response *QueryEventsByIDResponse, - ) error - - // Query the membership event for an user for a room. - QueryMembershipForUser( - ctx context.Context, - request *QueryMembershipForUserRequest, - response *QueryMembershipForUserResponse, - ) error - - // Query a list of membership events for a room - QueryMembershipsForRoom( - ctx context.Context, - request *QueryMembershipsForRoomRequest, - response *QueryMembershipsForRoomResponse, - ) error - - // Query a list of invite event senders for a user in a room. - QueryInvitesForUser( - ctx context.Context, - request *QueryInvitesForUserRequest, - response *QueryInvitesForUserResponse, - ) error - - // Query whether a server is allowed to see an event - QueryServerAllowedToSeeEvent( - ctx context.Context, - request *QueryServerAllowedToSeeEventRequest, - response *QueryServerAllowedToSeeEventResponse, - ) error - - // Query missing events for a room from roomserver - QueryMissingEvents( - ctx context.Context, - request *QueryMissingEventsRequest, - response *QueryMissingEventsResponse, - ) error - - // Query to get state and auth chain for a (potentially hypothetical) event. - // Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate - // the state and auth chain to return. - QueryStateAndAuthChain( - ctx context.Context, - request *QueryStateAndAuthChainRequest, - response *QueryStateAndAuthChainResponse, - ) error - - // Query a given amount (or less) of events prior to a given set of events. - QueryBackfill( - ctx context.Context, - request *QueryBackfillRequest, - response *QueryBackfillResponse, - ) error - - QueryServersInRoomAtEvent( - ctx context.Context, - request *QueryServersInRoomAtEventRequest, - response *QueryServersInRoomAtEventResponse, - ) error - - // Asks for the default room version as preferred by the server. - QueryRoomVersionCapabilities( - ctx context.Context, - request *QueryRoomVersionCapabilitiesRequest, - response *QueryRoomVersionCapabilitiesResponse, - ) error - - // Asks for the room version for a given room. - QueryRoomVersionForRoom( - ctx context.Context, - request *QueryRoomVersionForRoomRequest, - response *QueryRoomVersionForRoomResponse, - ) error -} - -// RoomserverQueryLatestEventsAndStatePath is the HTTP path for the QueryLatestEventsAndState API. -const RoomserverQueryLatestEventsAndStatePath = "/api/roomserver/queryLatestEventsAndState" - -// RoomserverQueryStateAfterEventsPath is the HTTP path for the QueryStateAfterEvents API. -const RoomserverQueryStateAfterEventsPath = "/api/roomserver/queryStateAfterEvents" - -// RoomserverQueryEventsByIDPath is the HTTP path for the QueryEventsByID API. -const RoomserverQueryEventsByIDPath = "/api/roomserver/queryEventsByID" - -// RoomserverQueryMembershipForUserPath is the HTTP path for the QueryMembershipForUser API. -const RoomserverQueryMembershipForUserPath = "/api/roomserver/queryMembershipForUser" - -// RoomserverQueryMembershipsForRoomPath is the HTTP path for the QueryMembershipsForRoom API -const RoomserverQueryMembershipsForRoomPath = "/api/roomserver/queryMembershipsForRoom" - -// RoomserverQueryInvitesForUserPath is the HTTP path for the QueryInvitesForUser API -const RoomserverQueryInvitesForUserPath = "/api/roomserver/queryInvitesForUser" - -// RoomserverQueryServerAllowedToSeeEventPath is the HTTP path for the QueryServerAllowedToSeeEvent API -const RoomserverQueryServerAllowedToSeeEventPath = "/api/roomserver/queryServerAllowedToSeeEvent" - -// RoomserverQueryMissingEventsPath is the HTTP path for the QueryMissingEvents API -const RoomserverQueryMissingEventsPath = "/api/roomserver/queryMissingEvents" - -// RoomserverQueryStateAndAuthChainPath is the HTTP path for the QueryStateAndAuthChain API -const RoomserverQueryStateAndAuthChainPath = "/api/roomserver/queryStateAndAuthChain" - -// RoomserverQueryBackfillPath is the HTTP path for the QueryBackfillPath API -const RoomserverQueryBackfillPath = "/api/roomserver/queryBackfill" - -// RoomserverQueryServersInRoomAtEventPath is the HTTP path for the QueryServersInRoomAtEvent API -const RoomserverQueryServersInRoomAtEventPath = "/api/roomserver/queryServersInRoomAtEvents" - -// RoomserverQueryRoomVersionCapabilitiesPath is the HTTP path for the QueryRoomVersionCapabilities API -const RoomserverQueryRoomVersionCapabilitiesPath = "/api/roomserver/queryRoomVersionCapabilities" - -// RoomserverQueryRoomVersionCapabilitiesPath is the HTTP path for the QueryRoomVersionCapabilities API -const RoomserverQueryRoomVersionForRoomPath = "/api/roomserver/queryRoomVersionForRoom" - -// NewRoomserverQueryAPIHTTP creates a RoomserverQueryAPI implemented by talking to a HTTP POST API. -// If httpClient is nil an error is returned -func NewRoomserverQueryAPIHTTP(roomserverURL string, httpClient *http.Client, cache caching.ImmutableCache) (RoomserverQueryAPI, error) { - if httpClient == nil { - return nil, errors.New("NewRoomserverQueryAPIHTTP: httpClient is ") - } - return &httpRoomserverQueryAPI{roomserverURL, httpClient, cache}, nil -} - -type httpRoomserverQueryAPI struct { - roomserverURL string - httpClient *http.Client - immutableCache caching.ImmutableCache -} - -// QueryLatestEventsAndState implements RoomserverQueryAPI -func (h *httpRoomserverQueryAPI) QueryLatestEventsAndState( - ctx context.Context, - request *QueryLatestEventsAndStateRequest, - response *QueryLatestEventsAndStateResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryLatestEventsAndState") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryLatestEventsAndStatePath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -// QueryStateAfterEvents implements RoomserverQueryAPI -func (h *httpRoomserverQueryAPI) QueryStateAfterEvents( - ctx context.Context, - request *QueryStateAfterEventsRequest, - response *QueryStateAfterEventsResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryStateAfterEvents") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryStateAfterEventsPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -// QueryEventsByID implements RoomserverQueryAPI -func (h *httpRoomserverQueryAPI) QueryEventsByID( - ctx context.Context, - request *QueryEventsByIDRequest, - response *QueryEventsByIDResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryEventsByID") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryEventsByIDPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -// QueryMembershipForUser implements RoomserverQueryAPI -func (h *httpRoomserverQueryAPI) QueryMembershipForUser( - ctx context.Context, - request *QueryMembershipForUserRequest, - response *QueryMembershipForUserResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMembershipForUser") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryMembershipForUserPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -// QueryMembershipsForRoom implements RoomserverQueryAPI -func (h *httpRoomserverQueryAPI) QueryMembershipsForRoom( - ctx context.Context, - request *QueryMembershipsForRoomRequest, - response *QueryMembershipsForRoomResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMembershipsForRoom") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryMembershipsForRoomPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -// QueryInvitesForUser implements RoomserverQueryAPI -func (h *httpRoomserverQueryAPI) QueryInvitesForUser( - ctx context.Context, - request *QueryInvitesForUserRequest, - response *QueryInvitesForUserResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryInvitesForUser") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryInvitesForUserPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -// QueryServerAllowedToSeeEvent implements RoomserverQueryAPI -func (h *httpRoomserverQueryAPI) QueryServerAllowedToSeeEvent( - ctx context.Context, - request *QueryServerAllowedToSeeEventRequest, - response *QueryServerAllowedToSeeEventResponse, -) (err error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerAllowedToSeeEvent") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryServerAllowedToSeeEventPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -// QueryMissingEvents implements RoomServerQueryAPI -func (h *httpRoomserverQueryAPI) QueryMissingEvents( - ctx context.Context, - request *QueryMissingEventsRequest, - response *QueryMissingEventsResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMissingEvents") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryMissingEventsPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -// QueryStateAndAuthChain implements RoomserverQueryAPI -func (h *httpRoomserverQueryAPI) QueryStateAndAuthChain( - ctx context.Context, - request *QueryStateAndAuthChainRequest, - response *QueryStateAndAuthChainResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryStateAndAuthChain") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryStateAndAuthChainPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -// QueryBackfill implements RoomServerQueryAPI -func (h *httpRoomserverQueryAPI) QueryBackfill( - ctx context.Context, - request *QueryBackfillRequest, - response *QueryBackfillResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryBackfill") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryBackfillPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -// QueryServersInRoomAtEvent implements RoomServerQueryAPI -func (h *httpRoomserverQueryAPI) QueryServersInRoomAtEvent( - ctx context.Context, - request *QueryServersInRoomAtEventRequest, - response *QueryServersInRoomAtEventResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServersInRoomAtEvent") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryServersInRoomAtEventPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -// QueryRoomVersionCapabilities implements RoomServerQueryAPI -func (h *httpRoomserverQueryAPI) QueryRoomVersionCapabilities( - ctx context.Context, - request *QueryRoomVersionCapabilitiesRequest, - response *QueryRoomVersionCapabilitiesResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomVersionCapabilities") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryRoomVersionCapabilitiesPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -// QueryRoomVersionForRoom implements RoomServerQueryAPI -func (h *httpRoomserverQueryAPI) QueryRoomVersionForRoom( - ctx context.Context, - request *QueryRoomVersionForRoomRequest, - response *QueryRoomVersionForRoomResponse, -) error { - if roomVersion, ok := h.immutableCache.GetRoomVersion(request.RoomID); ok { - response.RoomVersion = roomVersion - return nil - } - - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomVersionForRoom") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryRoomVersionForRoomPath - err := commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err == nil { - h.immutableCache.StoreRoomVersion(request.RoomID, response.RoomVersion) - } - return err -} diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go new file mode 100644 index 000000000..b73cd1902 --- /dev/null +++ b/roomserver/api/wrapper.go @@ -0,0 +1,117 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + + "github.com/matrix-org/gomatrixserverlib" +) + +// SendEvents to the roomserver The events are written with KindNew. +func SendEvents( + ctx context.Context, rsAPI RoomserverInternalAPI, events []gomatrixserverlib.HeaderedEvent, + sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID, +) (string, error) { + ires := make([]InputRoomEvent, len(events)) + for i, event := range events { + ires[i] = InputRoomEvent{ + Kind: KindNew, + Event: event, + AuthEventIDs: event.AuthEventIDs(), + SendAsServer: string(sendAsServer), + TransactionID: txnID, + } + } + return SendInputRoomEvents(ctx, rsAPI, ires) +} + +// SendEventWithState writes an event with KindNew to the roomserver +// with the state at the event as KindOutlier before it. Will not send any event that is +// marked as `true` in haveEventIDs +func SendEventWithState( + ctx context.Context, rsAPI RoomserverInternalAPI, state *gomatrixserverlib.RespState, + event gomatrixserverlib.HeaderedEvent, haveEventIDs map[string]bool, +) error { + outliers, err := state.Events() + if err != nil { + return err + } + + var ires []InputRoomEvent + for _, outlier := range outliers { + if haveEventIDs[outlier.EventID()] { + continue + } + ires = append(ires, InputRoomEvent{ + Kind: KindOutlier, + Event: outlier.Headered(event.RoomVersion), + AuthEventIDs: outlier.AuthEventIDs(), + }) + } + + stateEventIDs := make([]string, len(state.StateEvents)) + for i := range state.StateEvents { + stateEventIDs[i] = state.StateEvents[i].EventID() + } + + ires = append(ires, InputRoomEvent{ + Kind: KindNew, + Event: event, + AuthEventIDs: event.AuthEventIDs(), + HasState: true, + StateEventIDs: stateEventIDs, + }) + + _, err = SendInputRoomEvents(ctx, rsAPI, ires) + return err +} + +// SendInputRoomEvents to the roomserver. +func SendInputRoomEvents( + ctx context.Context, rsAPI RoomserverInternalAPI, ires []InputRoomEvent, +) (eventID string, err error) { + request := InputRoomEventsRequest{InputRoomEvents: ires} + var response InputRoomEventsResponse + err = rsAPI.InputRoomEvents(ctx, &request, &response) + eventID = response.EventID + return +} + +// SendInvite event to the roomserver. +// This should only be needed for invite events that occur outside of a known room. +// If we are in the room then the event should be sent using the SendEvents method. +func SendInvite( + ctx context.Context, + rsAPI RoomserverInternalAPI, inviteEvent gomatrixserverlib.HeaderedEvent, + inviteRoomState []gomatrixserverlib.InviteV2StrippedState, + sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID, +) *PerformError { + request := PerformInviteRequest{ + Event: inviteEvent, + InviteRoomState: inviteRoomState, + RoomVersion: inviteEvent.RoomVersion, + SendAsServer: string(sendAsServer), + TransactionID: txnID, + } + var response PerformInviteResponse + rsAPI.PerformInvite(ctx, &request, &response) + // we need to do this because many places people will use `var err error` as the return + // arg and a nil interface != nil pointer to a concrete interface (in this case PerformError) + if response.Error != nil && response.Error.Msg != "" { + return response.Error + } + return nil +} diff --git a/roomserver/auth/auth.go b/roomserver/auth/auth.go index 615a94b3c..fdcf9f062 100644 --- a/roomserver/auth/auth.go +++ b/roomserver/auth/auth.go @@ -27,7 +27,7 @@ func IsServerAllowed( serverCurrentlyInRoom bool, authEvents []gomatrixserverlib.Event, ) bool { - historyVisibility := historyVisibilityForRoom(authEvents) + historyVisibility := HistoryVisibilityForRoom(authEvents) // 1. If the history_visibility was set to world_readable, allow. if historyVisibility == "world_readable" { @@ -52,7 +52,7 @@ func IsServerAllowed( return false } -func historyVisibilityForRoom(authEvents []gomatrixserverlib.Event) string { +func HistoryVisibilityForRoom(authEvents []gomatrixserverlib.Event) string { // https://matrix.org/docs/spec/client_server/r0.6.0#id87 // By default if no history_visibility is set, or if the value is not understood, the visibility is assumed to be shared. visibility := "shared" diff --git a/roomserver/input/events.go b/roomserver/input/events.go deleted file mode 100644 index 2bb0d0a05..000000000 --- a/roomserver/input/events.go +++ /dev/null @@ -1,265 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// Copyright 2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package input - -import ( - "context" - "fmt" - - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/roomserver/state" - "github.com/matrix-org/dendrite/roomserver/state/database" - "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" - log "github.com/sirupsen/logrus" -) - -// A RoomEventDatabase has the storage APIs needed to store a room event. -type RoomEventDatabase interface { - database.RoomStateDatabase - // Stores a matrix room event in the database - StoreEvent( - ctx context.Context, - event gomatrixserverlib.Event, - txnAndSessionID *api.TransactionID, - authEventNIDs []types.EventNID, - ) (types.RoomNID, types.StateAtEvent, error) - // Look up the state entries for a list of string event IDs - // Returns an error if the there is an error talking to the database - // Returns a types.MissingEventError if the event IDs aren't in the database. - StateEntriesForEventIDs( - ctx context.Context, eventIDs []string, - ) ([]types.StateEntry, error) - // Set the state at an event. - SetState( - ctx context.Context, - eventNID types.EventNID, - stateNID types.StateSnapshotNID, - ) error - // Look up the latest events in a room in preparation for an update. - // The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error. - // Returns the latest events in the room and the last eventID sent to the log along with an updater. - // If this returns an error then no further action is required. - GetLatestEventsForUpdate( - ctx context.Context, roomNID types.RoomNID, - ) (updater types.RoomRecentEventsUpdater, err error) - // Look up the string event IDs for a list of numeric event IDs - EventIDs( - ctx context.Context, eventNIDs []types.EventNID, - ) (map[types.EventNID]string, error) - // Build a membership updater for the target user in a room. - MembershipUpdater( - ctx context.Context, roomID, targerUserID string, - roomVersion gomatrixserverlib.RoomVersion, - ) (types.MembershipUpdater, error) - // Look up event ID by transaction's info. - // This is used to determine if the room event is processed/processing already. - // Returns an empty string if no such event exists. - GetTransactionEventID( - ctx context.Context, transactionID string, - sessionID int64, userID string, - ) (string, error) - // Look up the room version for a given room. - GetRoomVersionForRoom( - ctx context.Context, roomID string, - ) (gomatrixserverlib.RoomVersion, error) -} - -// OutputRoomEventWriter has the APIs needed to write an event to the output logs. -type OutputRoomEventWriter interface { - // Write a list of events for a room - WriteOutputEvents(roomID string, updates []api.OutputEvent) error -} - -// processRoomEvent can only be called once at a time -// -// TODO(#375): This should be rewritten to allow concurrent calls. The -// difficulty is in ensuring that we correctly annotate events with the correct -// state deltas when sending to kafka streams -func processRoomEvent( - ctx context.Context, - db RoomEventDatabase, - ow OutputRoomEventWriter, - input api.InputRoomEvent, -) (eventID string, err error) { - // Parse and validate the event JSON - headered := input.Event - event := headered.Unwrap() - - // Check that the event passes authentication checks and work out the numeric IDs for the auth events. - authEventNIDs, err := checkAuthEvents(ctx, db, headered, input.AuthEventIDs) - if err != nil { - return - } - - if input.TransactionID != nil { - tdID := input.TransactionID - eventID, err = db.GetTransactionEventID( - ctx, tdID.TransactionID, tdID.SessionID, event.Sender(), - ) - // On error OR event with the transaction already processed/processesing - if err != nil || eventID != "" { - return - } - } - - // Store the event - roomNID, stateAtEvent, err := db.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) - if err != nil { - return - } - - if input.Kind == api.KindOutlier { - // For outliers we can stop after we've stored the event itself as it - // doesn't have any associated state to store and we don't need to - // notify anyone about it. - return event.EventID(), nil - } - - if stateAtEvent.BeforeStateSnapshotNID == 0 { - // We haven't calculated a state for this event yet. - // Lets calculate one. - err = calculateAndSetState(ctx, db, input, roomNID, &stateAtEvent, event) - if err != nil { - return - } - } - - if input.Kind == api.KindBackfill { - // Backfill is not implemented. - panic("Not implemented") - } - - // Update the extremities of the event graph for the room - return event.EventID(), updateLatestEvents( - ctx, db, ow, roomNID, stateAtEvent, event, input.SendAsServer, input.TransactionID, - ) -} - -func calculateAndSetState( - ctx context.Context, - db RoomEventDatabase, - input api.InputRoomEvent, - roomNID types.RoomNID, - stateAtEvent *types.StateAtEvent, - event gomatrixserverlib.Event, -) error { - var err error - roomState := state.NewStateResolution(db) - - if input.HasState { - // We've been told what the state at the event is so we don't need to calculate it. - // Check that those state events are in the database and store the state. - var entries []types.StateEntry - if entries, err = db.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { - return err - } - - if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(ctx, roomNID, nil, entries); err != nil { - return err - } - } else { - // We haven't been told what the state at the event is so we need to calculate it from the prev_events - if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event, roomNID); err != nil { - return err - } - } - return db.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) -} - -func processInviteEvent( - ctx context.Context, - db RoomEventDatabase, - ow OutputRoomEventWriter, - input api.InputInviteEvent, -) (err error) { - if input.Event.StateKey() == nil { - return fmt.Errorf("invite must be a state event") - } - - roomID := input.Event.RoomID() - targetUserID := *input.Event.StateKey() - - log.WithFields(log.Fields{ - "event_id": input.Event.EventID(), - "room_id": roomID, - "room_version": input.RoomVersion, - "target_user_id": targetUserID, - }).Info("processing invite event") - - updater, err := db.MembershipUpdater(ctx, roomID, targetUserID, input.RoomVersion) - if err != nil { - return err - } - succeeded := false - defer func() { - txerr := common.EndTransaction(updater, &succeeded) - if err == nil && txerr != nil { - err = txerr - } - }() - - if updater.IsJoin() { - // If the user is joined to the room then that takes precedence over this - // invite event. It makes little sense to move a user that is already - // joined to the room into the invite state. - // This could plausibly happen if an invite request raced with a join - // request for a user. For example if a user was invited to a public - // room and they joined the room at the same time as the invite was sent. - // The other way this could plausibly happen is if an invite raced with - // a kick. For example if a user was kicked from a room in error and in - // response someone else in the room re-invited them then it is possible - // for the invite request to race with the leave event so that the - // target receives invite before it learns that it has been kicked. - // There are a few ways this could be plausibly handled in the roomserver. - // 1) Store the invite, but mark it as retired. That will result in the - // permanent rejection of that invite event. So even if the target - // user leaves the room and the invite is retransmitted it will be - // ignored. However a new invite with a new event ID would still be - // accepted. - // 2) Silently discard the invite event. This means that if the event - // was retransmitted at a later date after the target user had left - // the room we would accept the invite. However since we hadn't told - // the sending server that the invite had been discarded it would - // have no reason to attempt to retry. - // 3) Signal the sending server that the user is already joined to the - // room. - // For now we will implement option 2. Since in the abesence of a retry - // mechanism it will be equivalent to option 1, and we don't have a - // signalling mechanism to implement option 3. - return nil - } - - event := input.Event.Unwrap() - - if err = event.SetUnsignedField("invite_room_state", input.InviteRoomState); err != nil { - return err - } - - outputUpdates, err := updateToInviteMembership(updater, &event, nil, input.Event.RoomVersion) - if err != nil { - return err - } - - if err = ow.WriteOutputEvents(roomID, outputUpdates); err != nil { - return err - } - - succeeded = true - return nil -} diff --git a/roomserver/input/input.go b/roomserver/input/input.go deleted file mode 100644 index bd029d8df..000000000 --- a/roomserver/input/input.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package input contains the code processes new room events -package input - -import ( - "context" - "encoding/json" - "net/http" - "sync" - - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/util" - sarama "gopkg.in/Shopify/sarama.v1" -) - -// RoomserverInputAPI implements api.RoomserverInputAPI -type RoomserverInputAPI struct { - DB RoomEventDatabase - Producer sarama.SyncProducer - // The kafkaesque topic to output new room events to. - // This is the name used in kafka to identify the stream to write events to. - OutputRoomEventTopic string - // Protects calls to processRoomEvent - mutex sync.Mutex -} - -// WriteOutputEvents implements OutputRoomEventWriter -func (r *RoomserverInputAPI) WriteOutputEvents(roomID string, updates []api.OutputEvent) error { - messages := make([]*sarama.ProducerMessage, len(updates)) - for i := range updates { - value, err := json.Marshal(updates[i]) - if err != nil { - return err - } - messages[i] = &sarama.ProducerMessage{ - Topic: r.OutputRoomEventTopic, - Key: sarama.StringEncoder(roomID), - Value: sarama.ByteEncoder(value), - } - } - return r.Producer.SendMessages(messages) -} - -// InputRoomEvents implements api.RoomserverInputAPI -func (r *RoomserverInputAPI) InputRoomEvents( - ctx context.Context, - request *api.InputRoomEventsRequest, - response *api.InputRoomEventsResponse, -) (err error) { - // We lock as processRoomEvent can only be called once at a time - r.mutex.Lock() - defer r.mutex.Unlock() - for i := range request.InputRoomEvents { - if response.EventID, err = processRoomEvent(ctx, r.DB, r, request.InputRoomEvents[i]); err != nil { - return err - } - } - for i := range request.InputInviteEvents { - if err = processInviteEvent(ctx, r.DB, r, request.InputInviteEvents[i]); err != nil { - return err - } - } - return nil -} - -// SetupHTTP adds the RoomserverInputAPI handlers to the http.ServeMux. -func (r *RoomserverInputAPI) SetupHTTP(servMux *http.ServeMux) { - servMux.Handle(api.RoomserverInputRoomEventsPath, - common.MakeInternalAPI("inputRoomEvents", func(req *http.Request) util.JSONResponse { - var request api.InputRoomEventsRequest - var response api.InputRoomEventsResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.InputRoomEvents(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) -} diff --git a/roomserver/alias/alias.go b/roomserver/internal/alias.go similarity index 50% rename from roomserver/alias/alias.go rename to roomserver/internal/alias.go index eb606e5cd..4139582b6 100644 --- a/roomserver/alias/alias.go +++ b/roomserver/internal/alias.go @@ -12,25 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -package alias +package internal import ( "context" "encoding/json" "errors" - "net/http" "time" - appserviceAPI "github.com/matrix-org/dendrite/appservice/api" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" - roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" ) -// RoomserverAliasAPIDatabase has the storage APIs needed to implement the alias API. -type RoomserverAliasAPIDatabase interface { +// RoomserverInternalAPIDatabase has the storage APIs needed to implement the alias API. +type RoomserverInternalAPIDatabase interface { // Save a given room alias with the room ID it refers to. // Returns an error if there was a problem talking to the database. SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error @@ -52,20 +47,11 @@ type RoomserverAliasAPIDatabase interface { ) (gomatrixserverlib.RoomVersion, error) } -// RoomserverAliasAPI is an implementation of alias.RoomserverAliasAPI -type RoomserverAliasAPI struct { - DB RoomserverAliasAPIDatabase - Cfg *config.Dendrite - InputAPI roomserverAPI.RoomserverInputAPI - QueryAPI roomserverAPI.RoomserverQueryAPI - AppserviceAPI appserviceAPI.AppServiceQueryAPI -} - -// SetRoomAlias implements alias.RoomserverAliasAPI -func (r *RoomserverAliasAPI) SetRoomAlias( +// SetRoomAlias implements alias.RoomserverInternalAPI +func (r *RoomserverInternalAPI) SetRoomAlias( ctx context.Context, - request *roomserverAPI.SetRoomAliasRequest, - response *roomserverAPI.SetRoomAliasResponse, + request *api.SetRoomAliasRequest, + response *api.SetRoomAliasResponse, ) error { // Check if the alias isn't already referring to a room roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias) @@ -91,11 +77,11 @@ func (r *RoomserverAliasAPI) SetRoomAlias( return r.sendUpdatedAliasesEvent(context.TODO(), request.UserID, request.RoomID) } -// GetRoomIDForAlias implements alias.RoomserverAliasAPI -func (r *RoomserverAliasAPI) GetRoomIDForAlias( +// GetRoomIDForAlias implements alias.RoomserverInternalAPI +func (r *RoomserverInternalAPI) GetRoomIDForAlias( ctx context.Context, - request *roomserverAPI.GetRoomIDForAliasRequest, - response *roomserverAPI.GetRoomIDForAliasResponse, + request *api.GetRoomIDForAliasRequest, + response *api.GetRoomIDForAliasResponse, ) error { // Look up the room ID in the database roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias) @@ -103,32 +89,38 @@ func (r *RoomserverAliasAPI) GetRoomIDForAlias( return err } - if roomID == "" { - // No room found locally, try our application services by making a call to - // the appservice component - aliasReq := appserviceAPI.RoomAliasExistsRequest{Alias: request.Alias} - var aliasResp appserviceAPI.RoomAliasExistsResponse - if err = r.AppserviceAPI.RoomAliasExists(ctx, &aliasReq, &aliasResp); err != nil { - return err - } + /* + TODO: Why is this here? It creates an unnecessary dependency + from the roomserver to the appservice component, which should be + altogether optional. - if aliasResp.AliasExists { - roomID, err = r.DB.GetRoomIDForAlias(ctx, request.Alias) - if err != nil { + if roomID == "" { + // No room found locally, try our application services by making a call to + // the appservice component + aliasReq := appserviceAPI.RoomAliasExistsRequest{Alias: request.Alias} + var aliasResp appserviceAPI.RoomAliasExistsResponse + if err = r.AppserviceAPI.RoomAliasExists(ctx, &aliasReq, &aliasResp); err != nil { return err } + + if aliasResp.AliasExists { + roomID, err = r.DB.GetRoomIDForAlias(ctx, request.Alias) + if err != nil { + return err + } + } } - } + */ response.RoomID = roomID return nil } -// GetAliasesForRoomID implements alias.RoomserverAliasAPI -func (r *RoomserverAliasAPI) GetAliasesForRoomID( +// GetAliasesForRoomID implements alias.RoomserverInternalAPI +func (r *RoomserverInternalAPI) GetAliasesForRoomID( ctx context.Context, - request *roomserverAPI.GetAliasesForRoomIDRequest, - response *roomserverAPI.GetAliasesForRoomIDResponse, + request *api.GetAliasesForRoomIDRequest, + response *api.GetAliasesForRoomIDResponse, ) error { // Look up the aliases in the database for the given RoomID aliases, err := r.DB.GetAliasesForRoomID(ctx, request.RoomID) @@ -140,11 +132,11 @@ func (r *RoomserverAliasAPI) GetAliasesForRoomID( return nil } -// GetCreatorIDForAlias implements alias.RoomserverAliasAPI -func (r *RoomserverAliasAPI) GetCreatorIDForAlias( +// GetCreatorIDForAlias implements alias.RoomserverInternalAPI +func (r *RoomserverInternalAPI) GetCreatorIDForAlias( ctx context.Context, - request *roomserverAPI.GetCreatorIDForAliasRequest, - response *roomserverAPI.GetCreatorIDForAliasResponse, + request *api.GetCreatorIDForAliasRequest, + response *api.GetCreatorIDForAliasResponse, ) error { // Look up the aliases in the database for the given RoomID creatorID, err := r.DB.GetCreatorIDForAlias(ctx, request.Alias) @@ -156,11 +148,11 @@ func (r *RoomserverAliasAPI) GetCreatorIDForAlias( return nil } -// RemoveRoomAlias implements alias.RoomserverAliasAPI -func (r *RoomserverAliasAPI) RemoveRoomAlias( +// RemoveRoomAlias implements alias.RoomserverInternalAPI +func (r *RoomserverInternalAPI) RemoveRoomAlias( ctx context.Context, - request *roomserverAPI.RemoveRoomAliasRequest, - response *roomserverAPI.RemoveRoomAliasResponse, + request *api.RemoveRoomAliasRequest, + response *api.RemoveRoomAliasResponse, ) error { // Look up the room ID in the database roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias) @@ -186,7 +178,7 @@ type roomAliasesContent struct { // Build the updated m.room.aliases event to send to the room after addition or // removal of an alias -func (r *RoomserverAliasAPI) sendUpdatedAliasesEvent( +func (r *RoomserverInternalAPI) sendUpdatedAliasesEvent( ctx context.Context, userID string, roomID string, ) error { serverName := string(r.Cfg.Matrix.ServerName) @@ -222,12 +214,12 @@ func (r *RoomserverAliasAPI) sendUpdatedAliasesEvent( if len(eventsNeeded.Tuples()) == 0 { return errors.New("expecting state tuples for event builder, got none") } - req := roomserverAPI.QueryLatestEventsAndStateRequest{ + req := api.QueryLatestEventsAndStateRequest{ RoomID: roomID, StateToFetch: eventsNeeded.Tuples(), } - var res roomserverAPI.QueryLatestEventsAndStateResponse - if err = r.QueryAPI.QueryLatestEventsAndState(ctx, &req, &res); err != nil { + var res api.QueryLatestEventsAndStateResponse + if err = r.QueryLatestEventsAndState(ctx, &req, &res); err != nil { return err } builder.Depth = res.Depth @@ -263,91 +255,17 @@ func (r *RoomserverAliasAPI) sendUpdatedAliasesEvent( } // Create the request - ire := roomserverAPI.InputRoomEvent{ - Kind: roomserverAPI.KindNew, + ire := api.InputRoomEvent{ + Kind: api.KindNew, Event: event.Headered(roomVersion), AuthEventIDs: event.AuthEventIDs(), SendAsServer: serverName, } - inputReq := roomserverAPI.InputRoomEventsRequest{ - InputRoomEvents: []roomserverAPI.InputRoomEvent{ire}, + inputReq := api.InputRoomEventsRequest{ + InputRoomEvents: []api.InputRoomEvent{ire}, } - var inputRes roomserverAPI.InputRoomEventsResponse + var inputRes api.InputRoomEventsResponse // Send the request - return r.InputAPI.InputRoomEvents(ctx, &inputReq, &inputRes) -} - -// SetupHTTP adds the RoomserverAliasAPI handlers to the http.ServeMux. -func (r *RoomserverAliasAPI) SetupHTTP(servMux *http.ServeMux) { - servMux.Handle( - roomserverAPI.RoomserverSetRoomAliasPath, - common.MakeInternalAPI("setRoomAlias", func(req *http.Request) util.JSONResponse { - var request roomserverAPI.SetRoomAliasRequest - var response roomserverAPI.SetRoomAliasResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.SetRoomAlias(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - servMux.Handle( - roomserverAPI.RoomserverGetRoomIDForAliasPath, - common.MakeInternalAPI("GetRoomIDForAlias", func(req *http.Request) util.JSONResponse { - var request roomserverAPI.GetRoomIDForAliasRequest - var response roomserverAPI.GetRoomIDForAliasResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.GetRoomIDForAlias(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - servMux.Handle( - roomserverAPI.RoomserverGetCreatorIDForAliasPath, - common.MakeInternalAPI("GetCreatorIDForAlias", func(req *http.Request) util.JSONResponse { - var request roomserverAPI.GetCreatorIDForAliasRequest - var response roomserverAPI.GetCreatorIDForAliasResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.GetCreatorIDForAlias(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - servMux.Handle( - roomserverAPI.RoomserverGetAliasesForRoomIDPath, - common.MakeInternalAPI("getAliasesForRoomID", func(req *http.Request) util.JSONResponse { - var request roomserverAPI.GetAliasesForRoomIDRequest - var response roomserverAPI.GetAliasesForRoomIDResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.GetAliasesForRoomID(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - servMux.Handle( - roomserverAPI.RoomserverRemoveRoomAliasPath, - common.MakeInternalAPI("removeRoomAlias", func(req *http.Request) util.JSONResponse { - var request roomserverAPI.RemoveRoomAliasRequest - var response roomserverAPI.RemoveRoomAliasResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.RemoveRoomAlias(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) + return r.InputRoomEvents(ctx, &inputReq, &inputRes) } diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go new file mode 100644 index 000000000..37a8a39bf --- /dev/null +++ b/roomserver/internal/api.go @@ -0,0 +1,26 @@ +package internal + +import ( + "sync" + + "github.com/Shopify/sarama" + fsAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/gomatrixserverlib" +) + +// RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI +type RoomserverInternalAPI struct { + DB storage.Database + Cfg *config.Dendrite + Producer sarama.SyncProducer + Cache caching.RoomVersionCache + ServerName gomatrixserverlib.ServerName + KeyRing gomatrixserverlib.JSONVerifier + FedClient *gomatrixserverlib.FederationClient + OutputRoomEventTopic string // Kafka topic for new output room events + mutex sync.Mutex // Protects calls to processRoomEvent + fsAPI fsAPI.FederationSenderInternalAPI +} diff --git a/roomserver/internal/input.go b/roomserver/internal/input.go new file mode 100644 index 000000000..2af3e62d8 --- /dev/null +++ b/roomserver/internal/input.go @@ -0,0 +1,83 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package input contains the code processes new room events +package internal + +import ( + "context" + "encoding/json" + + "github.com/Shopify/sarama" + "github.com/matrix-org/dendrite/roomserver/api" + log "github.com/sirupsen/logrus" + + fsAPI "github.com/matrix-org/dendrite/federationsender/api" +) + +// SetFederationSenderInputAPI passes in a federation sender input API reference +// so that we can avoid the chicken-and-egg problem of both the roomserver input API +// and the federation sender input API being interdependent. +func (r *RoomserverInternalAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) { + r.fsAPI = fsAPI +} + +// WriteOutputEvents implements OutputRoomEventWriter +func (r *RoomserverInternalAPI) WriteOutputEvents(roomID string, updates []api.OutputEvent) error { + messages := make([]*sarama.ProducerMessage, len(updates)) + for i := range updates { + value, err := json.Marshal(updates[i]) + if err != nil { + return err + } + logger := log.WithFields(log.Fields{ + "room_id": roomID, + "type": updates[i].Type, + }) + if updates[i].NewRoomEvent != nil { + logger = logger.WithFields(log.Fields{ + "event_type": updates[i].NewRoomEvent.Event.Type(), + "event_id": updates[i].NewRoomEvent.Event.EventID(), + "adds_state": len(updates[i].NewRoomEvent.AddsStateEventIDs), + "removes_state": len(updates[i].NewRoomEvent.RemovesStateEventIDs), + "send_as_server": updates[i].NewRoomEvent.SendAsServer, + "sender": updates[i].NewRoomEvent.Event.Sender(), + }) + } + logger.Infof("Producing to topic '%s'", r.OutputRoomEventTopic) + messages[i] = &sarama.ProducerMessage{ + Topic: r.OutputRoomEventTopic, + Key: sarama.StringEncoder(roomID), + Value: sarama.ByteEncoder(value), + } + } + return r.Producer.SendMessages(messages) +} + +// InputRoomEvents implements api.RoomserverInternalAPI +func (r *RoomserverInternalAPI) InputRoomEvents( + ctx context.Context, + request *api.InputRoomEventsRequest, + response *api.InputRoomEventsResponse, +) (err error) { + // We lock as processRoomEvent can only be called once at a time + r.mutex.Lock() + defer r.mutex.Unlock() + for i := range request.InputRoomEvents { + if response.EventID, err = r.processRoomEvent(ctx, request.InputRoomEvents[i]); err != nil { + return err + } + } + return nil +} diff --git a/roomserver/input/authevents.go b/roomserver/internal/input_authevents.go similarity index 98% rename from roomserver/input/authevents.go rename to roomserver/internal/input_authevents.go index 456a01c79..e3828f566 100644 --- a/roomserver/input/authevents.go +++ b/roomserver/internal/input_authevents.go @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -package input +package internal import ( "context" "sort" + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -26,7 +27,7 @@ import ( // Returns the numeric IDs for the auth events. func checkAuthEvents( ctx context.Context, - db RoomEventDatabase, + db storage.Database, event gomatrixserverlib.HeaderedEvent, authEventIDs []string, ) ([]types.EventNID, error) { @@ -127,7 +128,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) * // loadAuthEvents loads the events needed for authentication from the supplied room state. func loadAuthEvents( ctx context.Context, - db RoomEventDatabase, + db storage.Database, needed gomatrixserverlib.StateNeeded, state []types.StateEntry, ) (result authEvents, err error) { diff --git a/roomserver/input/authevents_test.go b/roomserver/internal/input_authevents_test.go similarity index 99% rename from roomserver/input/authevents_test.go rename to roomserver/internal/input_authevents_test.go index 0621a0842..6b981571b 100644 --- a/roomserver/input/authevents_test.go +++ b/roomserver/internal/input_authevents_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package input +package internal import ( "testing" diff --git a/roomserver/internal/input_events.go b/roomserver/internal/input_events.go new file mode 100644 index 000000000..ae57f2e77 --- /dev/null +++ b/roomserver/internal/input_events.go @@ -0,0 +1,145 @@ +// Copyright 2017 Vector Creations Ltd +// Copyright 2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// processRoomEvent can only be called once at a time +// +// TODO(#375): This should be rewritten to allow concurrent calls. The +// difficulty is in ensuring that we correctly annotate events with the correct +// state deltas when sending to kafka streams +func (r *RoomserverInternalAPI) processRoomEvent( + ctx context.Context, + input api.InputRoomEvent, +) (eventID string, err error) { + // Parse and validate the event JSON + headered := input.Event + event := headered.Unwrap() + + // Check that the event passes authentication checks and work out + // the numeric IDs for the auth events. + authEventNIDs, err := checkAuthEvents(ctx, r.DB, headered, input.AuthEventIDs) + if err != nil { + logrus.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event") + return + } + + // If we don't have a transaction ID then get one. + if input.TransactionID != nil { + tdID := input.TransactionID + eventID, err = r.DB.GetTransactionEventID( + ctx, tdID.TransactionID, tdID.SessionID, event.Sender(), + ) + // On error OR event with the transaction already processed/processesing + if err != nil || eventID != "" { + return + } + } + + // Store the event. + roomNID, stateAtEvent, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) + if err != nil { + return + } + + // For outliers we can stop after we've stored the event itself as it + // doesn't have any associated state to store and we don't need to + // notify anyone about it. + if input.Kind == api.KindOutlier { + logrus.WithFields(logrus.Fields{ + "event_id": event.EventID(), + "type": event.Type(), + "room": event.RoomID(), + }).Info("Stored outlier") + return event.EventID(), nil + } + + if stateAtEvent.BeforeStateSnapshotNID == 0 { + // We haven't calculated a state for this event yet. + // Lets calculate one. + err = r.calculateAndSetState(ctx, input, roomNID, &stateAtEvent, event) + if err != nil { + return + } + } + + if err = r.updateLatestEvents( + ctx, // context + roomNID, // room NID to update + stateAtEvent, // state at event (below) + event, // event + input.SendAsServer, // send as server + input.TransactionID, // transaction ID + ); err != nil { + return + } + + // Update the extremities of the event graph for the room + return event.EventID(), nil +} + +func (r *RoomserverInternalAPI) calculateAndSetState( + ctx context.Context, + input api.InputRoomEvent, + roomNID types.RoomNID, + stateAtEvent *types.StateAtEvent, + event gomatrixserverlib.Event, +) error { + var err error + roomState := state.NewStateResolution(r.DB) + + if input.HasState { + // Check here if we think we're in the room already. + stateAtEvent.Overwrite = true + var joinEventNIDs []types.EventNID + // Request join memberships only for local users only. + if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true, true); err == nil { + // If we have no local users that are joined to the room then any state about + // the room that we have is quite possibly out of date. Therefore in that case + // we should overwrite it rather than merge it. + stateAtEvent.Overwrite = len(joinEventNIDs) == 0 + } + + // We've been told what the state at the event is so we don't need to calculate it. + // Check that those state events are in the database and store the state. + var entries []types.StateEntry + if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { + return err + } + + if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil { + return err + } + } else { + stateAtEvent.Overwrite = false + + // We haven't been told what the state at the event is so we need to calculate it from the prev_events + if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event, roomNID); err != nil { + return err + } + } + return r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) +} diff --git a/roomserver/input/latest_events.go b/roomserver/internal/input_latest_events.go similarity index 65% rename from roomserver/input/latest_events.go rename to roomserver/internal/input_latest_events.go index 525a6f518..66316ac4f 100644 --- a/roomserver/input/latest_events.go +++ b/roomserver/internal/input_latest_events.go @@ -14,13 +14,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package input +package internal import ( "bytes" "context" + "fmt" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/types" @@ -45,33 +46,37 @@ import ( // 7 <----- latest // // Can only be called once at a time -func updateLatestEvents( +func (r *RoomserverInternalAPI) updateLatestEvents( ctx context.Context, - db RoomEventDatabase, - ow OutputRoomEventWriter, roomNID types.RoomNID, stateAtEvent types.StateAtEvent, event gomatrixserverlib.Event, sendAsServer string, transactionID *api.TransactionID, ) (err error) { - updater, err := db.GetLatestEventsForUpdate(ctx, roomNID) + updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomNID) if err != nil { return } succeeded := false defer func() { - txerr := common.EndTransaction(updater, &succeeded) + txerr := sqlutil.EndTransaction(updater, &succeeded) if err == nil && txerr != nil { err = txerr } }() u := latestEventsUpdater{ - ctx: ctx, db: db, updater: updater, ow: ow, roomNID: roomNID, - stateAtEvent: stateAtEvent, event: event, sendAsServer: sendAsServer, + ctx: ctx, + api: r, + updater: updater, + roomNID: roomNID, + stateAtEvent: stateAtEvent, + event: event, + sendAsServer: sendAsServer, transactionID: transactionID, } + if err = u.doUpdateLatestEvents(); err != nil { return err } @@ -86,9 +91,8 @@ func updateLatestEvents( // when there are so many variables to pass around. type latestEventsUpdater struct { ctx context.Context - db RoomEventDatabase + api *RoomserverInternalAPI updater types.RoomRecentEventsUpdater - ow OutputRoomEventWriter roomNID types.RoomNID stateAtEvent types.StateAtEvent event gomatrixserverlib.Event @@ -114,39 +118,66 @@ type latestEventsUpdater struct { func (u *latestEventsUpdater) doUpdateLatestEvents() error { prevEvents := u.event.PrevEvents() - oldLatest := u.updater.LatestEvents() u.lastEventIDSent = u.updater.LastEventIDSent() u.oldStateNID = u.updater.CurrentStateSnapshotNID() + // If we are doing a regular event update then we will get the + // previous latest events to use as a part of the calculation. If + // we are overwriting the latest events because we have a complete + // state snapshot from somewhere else, e.g. a federated room join, + // then start with an empty set - none of the forward extremities + // that we knew about before matter anymore. + oldLatest := []types.StateAtEventAndReference{} + if !u.stateAtEvent.Overwrite { + oldLatest = u.updater.LatestEvents() + } + + // If the event has already been written to the output log then we + // don't need to do anything, as we've handled it already. hasBeenSent, err := u.updater.HasEventBeenSent(u.stateAtEvent.EventNID) if err != nil { return err } else if hasBeenSent { - // Already sent this event so we can stop processing return nil } + // Update the roomserver_previous_events table with references. This + // is effectively tracking the structure of the DAG. if err = u.updater.StorePreviousEvents(u.stateAtEvent.EventNID, prevEvents); err != nil { return err } + // Get the event reference for our new event. This will be used when + // determining if the event is referenced by an existing event. eventReference := u.event.EventReference() - // Check if this event is already referenced by another event in the room. + + // Check if our new event is already referenced by an existing event + // in the room. If it is then it isn't a latest event. alreadyReferenced, err := u.updater.IsReferenced(eventReference) if err != nil { return err } - u.latest = calculateLatest(oldLatest, alreadyReferenced, prevEvents, types.StateAtEventAndReference{ - EventReference: eventReference, - StateAtEvent: u.stateAtEvent, - }) + // Work out what the latest events are. + u.latest = calculateLatest( + oldLatest, + alreadyReferenced, + prevEvents, + types.StateAtEventAndReference{ + EventReference: eventReference, + StateAtEvent: u.stateAtEvent, + }, + ) + // Now that we know what the latest events are, it's time to get the + // latest state. if err = u.latestState(); err != nil { return err } - updates, err := updateMemberships(u.ctx, u.db, u.updater, u.removed, u.added) + // If we need to generate any output events then here's where we do it. + // TODO: Move this! + updates, err := u.api.updateMemberships(u.ctx, u.updater, u.removed, u.added) if err != nil { return err } @@ -165,7 +196,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // send the event asynchronously but we would need to ensure that 1) the events are written to the log in // the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the // necessary bookkeeping we'll keep the event sending synchronous for now. - if err = u.ow.WriteOutputEvents(u.event.RoomID(), updates); err != nil { + if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil { return err } @@ -178,12 +209,17 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { func (u *latestEventsUpdater) latestState() error { var err error - roomState := state.NewStateResolution(u.db) + roomState := state.NewStateResolution(u.api.DB) + // Get a list of the current latest events. latestStateAtEvents := make([]types.StateAtEvent, len(u.latest)) for i := range u.latest { latestStateAtEvents[i] = u.latest[i].StateAtEvent } + + // Takes the NIDs of the latest events and creates a state snapshot + // of the state after the events. The snapshot state will be resolved + // using the correct state resolution algorithm for the room. u.newStateNID, err = roomState.CalculateAndStoreStateAfterEvents( u.ctx, u.roomNID, latestStateAtEvents, ) @@ -191,6 +227,18 @@ func (u *latestEventsUpdater) latestState() error { return err } + // If we are overwriting the state then we should make sure that we + // don't send anything out over federation again, it will very likely + // be a repeat. + if u.stateAtEvent.Overwrite { + u.sendAsServer = "" + } + + // Now that we have a new state snapshot based on the latest events, + // we can compare that new snapshot to the previous one and see what + // has changed. This gives us one list of removed state events and + // another list of added ones. Replacing a value for a state-key tuple + // will result one removed (the old event) and one added (the new event). u.removed, u.added, err = roomState.DifferenceBetweeenStateSnapshots( u.ctx, u.oldStateNID, u.newStateNID, ) @@ -198,6 +246,8 @@ func (u *latestEventsUpdater) latestState() error { return err } + // Also work out the state before the event removes and the event + // adds. u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = roomState.DifferenceBetweeenStateSnapshots( u.ctx, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID, ) @@ -249,7 +299,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) latestEventIDs[i] = u.latest[i].EventID } - roomVersion, err := u.db.GetRoomVersionForRoom(u.ctx, u.event.RoomID()) + roomVersion, err := u.api.DB.GetRoomVersionForRoom(u.ctx, u.event.RoomID()) if err != nil { return nil, err } @@ -261,24 +311,11 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) TransactionID: u.transactionID, } - var stateEventNIDs []types.EventNID - for _, entry := range u.added { - stateEventNIDs = append(stateEventNIDs, entry.EventNID) - } - for _, entry := range u.removed { - stateEventNIDs = append(stateEventNIDs, entry.EventNID) - } - for _, entry := range u.stateBeforeEventRemoves { - stateEventNIDs = append(stateEventNIDs, entry.EventNID) - } - for _, entry := range u.stateBeforeEventAdds { - stateEventNIDs = append(stateEventNIDs, entry.EventNID) - } - stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))] - eventIDMap, err := u.db.EventIDs(u.ctx, stateEventNIDs) + eventIDMap, err := u.stateEventMap() if err != nil { return nil, err } + for _, entry := range u.added { ore.AddsStateEventIDs = append(ore.AddsStateEventIDs, eventIDMap[entry.EventNID]) } @@ -293,12 +330,60 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) } ore.SendAsServer = u.sendAsServer + // include extra state events if they were added as nearly every downstream component will care about it + // and we'd rather not have them all hit QueryEventsByID at the same time! + if len(ore.AddsStateEventIDs) > 0 { + ore.AddStateEvents, err = u.extraEventsForIDs(roomVersion, ore.AddsStateEventIDs) + if err != nil { + return nil, fmt.Errorf("failed to load add_state_events from db: %w", err) + } + } + return &api.OutputEvent{ Type: api.OutputTypeNewRoomEvent, NewRoomEvent: &ore, }, nil } +// extraEventsForIDs returns the full events for the event IDs given, but does not include the current event being +// updated. +func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.RoomVersion, eventIDs []string) ([]gomatrixserverlib.HeaderedEvent, error) { + var extraEventIDs []string + for _, e := range eventIDs { + if e == u.event.EventID() { + continue + } + extraEventIDs = append(extraEventIDs, e) + } + if len(extraEventIDs) == 0 { + return nil, nil + } + extraEvents, err := u.api.DB.EventsFromIDs(u.ctx, extraEventIDs) + if err != nil { + return nil, err + } + var h []gomatrixserverlib.HeaderedEvent + for _, e := range extraEvents { + h = append(h, e.Headered(roomVersion)) + } + return h, nil +} + +// retrieve an event nid -> event ID map for all events that need updating +func (u *latestEventsUpdater) stateEventMap() (map[types.EventNID]string, error) { + var stateEventNIDs []types.EventNID + var allStateEntries []types.StateEntry + allStateEntries = append(allStateEntries, u.added...) + allStateEntries = append(allStateEntries, u.removed...) + allStateEntries = append(allStateEntries, u.stateBeforeEventRemoves...) + allStateEntries = append(allStateEntries, u.stateBeforeEventAdds...) + for _, entry := range allStateEntries { + stateEventNIDs = append(stateEventNIDs, entry.EventNID) + } + stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))] + return u.api.DB.EventIDs(u.ctx, stateEventNIDs) +} + type eventNIDSorter []types.EventNID func (s eventNIDSorter) Len() int { return len(s) } diff --git a/roomserver/input/membership.go b/roomserver/internal/input_membership.go similarity index 74% rename from roomserver/input/membership.go rename to roomserver/internal/input_membership.go index ee39ff5eb..af0c7f8b3 100644 --- a/roomserver/input/membership.go +++ b/roomserver/internal/input_membership.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package input +package internal import ( "context" @@ -27,9 +27,8 @@ import ( // user affected by a change in the current state of the room. // Returns a list of output events to write to the kafka log to inform the // consumers about the invites added or retired by the change in current state. -func updateMemberships( +func (r *RoomserverInternalAPI) updateMemberships( ctx context.Context, - db RoomEventDatabase, updater types.RoomRecentEventsUpdater, removed, added []types.StateEntry, ) ([]api.OutputEvent, error) { @@ -47,7 +46,7 @@ func updateMemberships( // Load the event JSON so we can look up the "membership" key. // TODO: Maybe add a membership key to the events table so we can load that // key without having to load the entire event JSON? - events, err := db.Events(ctx, eventNIDs) + events, err := r.DB.Events(ctx, eventNIDs) if err != nil { return nil, err } @@ -70,15 +69,16 @@ func updateMemberships( ae = &ev.Event } } - if updates, err = updateMembership(updater, targetUserNID, re, ae, updates); err != nil { + if updates, err = r.updateMembership(updater, targetUserNID, re, ae, updates); err != nil { return nil, err } } return updates, nil } -func updateMembership( - updater types.RoomRecentEventsUpdater, targetUserNID types.EventStateKeyNID, +func (r *RoomserverInternalAPI) updateMembership( + updater types.RoomRecentEventsUpdater, + targetUserNID types.EventStateKeyNID, remove, add *gomatrixserverlib.Event, updates []api.OutputEvent, ) ([]api.OutputEvent, error) { @@ -105,7 +105,14 @@ func updateMembership( return updates, nil } - mu, err := updater.MembershipUpdater(targetUserNID) + if add == nil { + // This can happen when we have rejoined a room and suddenly we have a + // divergence between the former state and the new one. We don't want to + // act on removals and apparently there are no adds, so stop here. + return updates, nil + } + + mu, err := updater.MembershipUpdater(targetUserNID, r.isLocalTarget(add)) if err != nil { return nil, err } @@ -124,6 +131,15 @@ func updateMembership( } } +func (r *RoomserverInternalAPI) isLocalTarget(event *gomatrixserverlib.Event) bool { + isTargetLocalUser := false + if statekey := event.StateKey(); statekey != nil { + _, domain, _ := gomatrixserverlib.SplitID('@', *statekey) + isTargetLocalUser = domain == r.Cfg.Matrix.ServerName + } + return isTargetLocalUser +} + func updateToInviteMembership( mu types.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, roomVersion gomatrixserverlib.RoomVersion, @@ -143,7 +159,7 @@ func updateToInviteMembership( // consider a single stream of events when determining whether a user // is invited, rather than having to combine multiple streams themselves. onie := api.OutputNewInviteEvent{ - Event: (*add).Headered(roomVersion), + Event: add.Headered(roomVersion), RoomVersion: roomVersion, } updates = append(updates, api.OutputEvent{ @@ -222,8 +238,7 @@ func updateToLeaveMembership( return updates, nil } -// membershipChanges pairs up the membership state changes from a sorted list -// of state removed and a sorted list of state added. +// membershipChanges pairs up the membership state changes. func membershipChanges(removed, added []types.StateEntry) []stateChange { changes := pairUpChanges(removed, added) var result []stateChange @@ -242,64 +257,39 @@ type stateChange struct { } // pairUpChanges pairs up the state events added and removed for each type, -// state key tuple. Assumes that removed and added are sorted. +// state key tuple. func pairUpChanges(removed, added []types.StateEntry) []stateChange { - var ai int - var ri int - var result []stateChange - for { - switch { - case ai == len(added): - // We've reached the end of the added entries. - // The rest of the removed list are events that were removed without - // an event with the same state key being added. - for _, s := range removed[ri:] { - result = append(result, stateChange{ - StateKeyTuple: s.StateKeyTuple, - removedEventNID: s.EventNID, - }) - } - return result - case ri == len(removed): - // We've reached the end of the removed entries. - // The rest of the added list are events that were added without - // an event with the same state key being removed. - for _, s := range added[ai:] { - result = append(result, stateChange{ - StateKeyTuple: s.StateKeyTuple, - addedEventNID: s.EventNID, - }) - } - return result - case added[ai].StateKeyTuple == removed[ri].StateKeyTuple: - // The tuple is in both lists so an event with that key is being - // removed and another event with the same key is being added. - result = append(result, stateChange{ - StateKeyTuple: added[ai].StateKeyTuple, - removedEventNID: removed[ri].EventNID, - addedEventNID: added[ai].EventNID, - }) - ai++ - ri++ - case added[ai].StateKeyTuple.LessThan(removed[ri].StateKeyTuple): - // The lists are sorted so the added entry being less than the - // removed entry means that the added event was added without an - // event with the same key being removed. - result = append(result, stateChange{ - StateKeyTuple: added[ai].StateKeyTuple, - addedEventNID: added[ai].EventNID, - }) - ai++ - default: - // Reaching the default case implies that the removed entry is less - // than the added entry. Since the lists are sorted this means that - // the removed event was removed without an event with the same - // key being added. - result = append(result, stateChange{ - StateKeyTuple: removed[ai].StateKeyTuple, - removedEventNID: removed[ri].EventNID, - }) - ri++ + tuples := make(map[types.StateKeyTuple]stateChange) + changes := []stateChange{} + + // First, go through the newly added state entries. + for _, add := range added { + if change, ok := tuples[add.StateKeyTuple]; ok { + // If we already have an entry, update it. + change.addedEventNID = add.EventNID + tuples[add.StateKeyTuple] = change + } else { + // Otherwise, create a new entry. + tuples[add.StateKeyTuple] = stateChange{add.StateKeyTuple, 0, add.EventNID} } } + + // Now go through the removed state entries. + for _, remove := range removed { + if change, ok := tuples[remove.StateKeyTuple]; ok { + // If we already have an entry, update it. + change.removedEventNID = remove.EventNID + tuples[remove.StateKeyTuple] = change + } else { + // Otherwise, create a new entry. + tuples[remove.StateKeyTuple] = stateChange{remove.StateKeyTuple, remove.EventNID, 0} + } + } + + // Now return the changes as an array. + for _, change := range tuples { + changes = append(changes, change) + } + + return changes } diff --git a/roomserver/internal/perform_backfill.go b/roomserver/internal/perform_backfill.go new file mode 100644 index 000000000..23ae9455a --- /dev/null +++ b/roomserver/internal/perform_backfill.go @@ -0,0 +1,305 @@ +package internal + +import ( + "context" + + "github.com/matrix-org/dendrite/roomserver/auth" + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +// backfillRequester implements gomatrixserverlib.BackfillRequester +type backfillRequester struct { + db storage.Database + fedClient *gomatrixserverlib.FederationClient + thisServer gomatrixserverlib.ServerName + bwExtrems map[string][]string + + // per-request state + servers []gomatrixserverlib.ServerName + eventIDToBeforeStateIDs map[string][]string + eventIDMap map[string]gomatrixserverlib.Event +} + +func newBackfillRequester(db storage.Database, fedClient *gomatrixserverlib.FederationClient, thisServer gomatrixserverlib.ServerName, bwExtrems map[string][]string) *backfillRequester { + return &backfillRequester{ + db: db, + fedClient: fedClient, + thisServer: thisServer, + eventIDToBeforeStateIDs: make(map[string][]string), + eventIDMap: make(map[string]gomatrixserverlib.Event), + bwExtrems: bwExtrems, + } +} + +func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent gomatrixserverlib.HeaderedEvent) ([]string, error) { + b.eventIDMap[targetEvent.EventID()] = targetEvent.Unwrap() + if ids, ok := b.eventIDToBeforeStateIDs[targetEvent.EventID()]; ok { + return ids, nil + } + if len(targetEvent.PrevEventIDs()) == 0 && targetEvent.Type() == "m.room.create" && targetEvent.StateKeyEquals("") { + util.GetLogger(ctx).WithField("room_id", targetEvent.RoomID()).Info("Backfilled to the beginning of the room") + b.eventIDToBeforeStateIDs[targetEvent.EventID()] = []string{} + return nil, nil + } + // if we have exactly 1 prev event and we know the state of the room at that prev event, then just roll forward the prev event. + // Else, we have to hit /state_ids because either we don't know the state at all at this event (new backwards extremity) or + // we don't know the result of state res to merge forks (2 or more prev_events) + if len(targetEvent.PrevEventIDs()) == 1 { + prevEventID := targetEvent.PrevEventIDs()[0] + prevEvent, ok := b.eventIDMap[prevEventID] + if !ok { + goto FederationHit + } + prevEventStateIDs, ok := b.eventIDToBeforeStateIDs[prevEventID] + if !ok { + goto FederationHit + } + newStateIDs := b.calculateNewStateIDs(targetEvent.Unwrap(), prevEvent, prevEventStateIDs) + if newStateIDs != nil { + b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs + return newStateIDs, nil + } + // else we failed to calculate the new state, so fallthrough + } + +FederationHit: + var lastErr error + logrus.WithField("event_id", targetEvent.EventID()).Info("Requesting /state_ids at event") + for _, srv := range b.servers { // hit any valid server + c := gomatrixserverlib.FederatedStateProvider{ + FedClient: b.fedClient, + RememberAuthEvents: false, + Server: srv, + } + res, err := c.StateIDsBeforeEvent(ctx, targetEvent) + if err != nil { + lastErr = err + continue + } + b.eventIDToBeforeStateIDs[targetEvent.EventID()] = res + return res, nil + } + return nil, lastErr +} + +func (b *backfillRequester) calculateNewStateIDs(targetEvent, prevEvent gomatrixserverlib.Event, prevEventStateIDs []string) []string { + newStateIDs := prevEventStateIDs[:] + if prevEvent.StateKey() == nil { + // state is the same as the previous event + b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs + return newStateIDs + } + + missingState := false // true if we are missing the info for a state event ID + foundEvent := false // true if we found a (type, state_key) match + // find which state ID to replace, if any + for i, id := range newStateIDs { + ev, ok := b.eventIDMap[id] + if !ok { + missingState = true + continue + } + // The state IDs BEFORE the target event are the state IDs BEFORE the prev_event PLUS the prev_event itself + if ev.Type() == prevEvent.Type() && ev.StateKey() != nil && *ev.StateKey() == *prevEvent.StateKey() { + newStateIDs[i] = prevEvent.EventID() + foundEvent = true + break + } + } + if !foundEvent && !missingState { + // we can be certain that this is new state + newStateIDs = append(newStateIDs, prevEvent.EventID()) + foundEvent = true + } + + if foundEvent { + b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs + return newStateIDs + } + return nil +} + +func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, + event gomatrixserverlib.HeaderedEvent, eventIDs []string) (map[string]*gomatrixserverlib.Event, error) { + + // try to fetch the events from the database first + events, err := b.ProvideEvents(roomVer, eventIDs) + if err != nil { + // non-fatal, fallthrough + logrus.WithError(err).Info("Failed to fetch events") + } else { + logrus.Infof("Fetched %d/%d events from the database", len(events), len(eventIDs)) + if len(events) == len(eventIDs) { + result := make(map[string]*gomatrixserverlib.Event) + for i := range events { + result[events[i].EventID()] = &events[i] + b.eventIDMap[events[i].EventID()] = events[i] + } + return result, nil + } + } + + c := gomatrixserverlib.FederatedStateProvider{ + FedClient: b.fedClient, + RememberAuthEvents: false, + Server: b.servers[0], + } + result, err := c.StateBeforeEvent(ctx, roomVer, event, eventIDs) + if err != nil { + return nil, err + } + for eventID, ev := range result { + b.eventIDMap[eventID] = *ev + } + return result, nil +} + +// ServersAtEvent is called when trying to determine which server to request from. +// It returns a list of servers which can be queried for backfill requests. These servers +// will be servers that are in the room already. The entries at the beginning are preferred servers +// and will be tried first. An empty list will fail the request. +func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID string) []gomatrixserverlib.ServerName { + // eventID will be a prev_event ID of a backwards extremity, meaning we will not have a database entry for it. Instead, use + // its successor, so look it up. + successor := "" +FindSuccessor: + for sucID, prevEventIDs := range b.bwExtrems { + for _, pe := range prevEventIDs { + if pe == eventID { + successor = sucID + break FindSuccessor + } + } + } + if successor == "" { + logrus.WithField("event_id", eventID).Error("ServersAtEvent: failed to find successor of this event to determine room state") + return nil + } + eventID = successor + + // getMembershipsBeforeEventNID requires a NID, so retrieving the NID for + // the event is necessary. + NIDs, err := b.db.EventNIDs(ctx, []string{eventID}) + if err != nil { + logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get event NID for event") + return nil + } + + stateEntries, err := stateBeforeEvent(ctx, b.db, NIDs[eventID]) + if err != nil { + logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event") + return nil + } + + // possibly return all joined servers depending on history visiblity + memberEventsFromVis, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries) + if err != nil { + logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules") + return nil + } + logrus.Infof("ServersAtEvent including %d current events from history visibility", len(memberEventsFromVis)) + + // Retrieve all "m.room.member" state events of "join" membership, which + // contains the list of users in the room before the event, therefore all + // the servers in it at that moment. + memberEvents, err := getMembershipsAtState(ctx, b.db, stateEntries, true) + if err != nil { + logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event") + return nil + } + memberEvents = append(memberEvents, memberEventsFromVis...) + + // Store the server names in a temporary map to avoid duplicates. + serverSet := make(map[gomatrixserverlib.ServerName]bool) + for _, event := range memberEvents { + serverSet[event.Origin()] = true + } + var servers []gomatrixserverlib.ServerName + for server := range serverSet { + if server == b.thisServer { + continue + } + servers = append(servers, server) + } + b.servers = servers + return servers +} + +// Backfill performs a backfill request to the given server. +// https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid +func (b *backfillRequester) Backfill(ctx context.Context, server gomatrixserverlib.ServerName, roomID string, + fromEventIDs []string, limit int) (*gomatrixserverlib.Transaction, error) { + + tx, err := b.fedClient.Backfill(ctx, server, roomID, limit, fromEventIDs) + return &tx, err +} + +func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, eventIDs []string) ([]gomatrixserverlib.Event, error) { + ctx := context.Background() + nidMap, err := b.db.EventNIDs(ctx, eventIDs) + if err != nil { + logrus.WithError(err).WithField("event_ids", eventIDs).Error("Failed to find events") + return nil, err + } + eventNIDs := make([]types.EventNID, len(nidMap)) + i := 0 + for _, nid := range nidMap { + eventNIDs[i] = nid + i++ + } + eventsWithNids, err := b.db.Events(ctx, eventNIDs) + if err != nil { + logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events") + return nil, err + } + events := make([]gomatrixserverlib.Event, len(eventsWithNids)) + for i := range eventsWithNids { + events[i] = eventsWithNids[i].Event + } + return events, nil +} + +// joinEventsFromHistoryVisibility returns all CURRENTLY joined members if the provided state indicated a 'shared' history visibility. +// TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just +// pull all events and then filter by that table. +func joinEventsFromHistoryVisibility( + ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry) ([]types.Event, error) { + + var eventNIDs []types.EventNID + for _, entry := range stateEntries { + // Filter the events to retrieve to only keep the membership events + if entry.EventTypeNID == types.MRoomHistoryVisibilityNID && entry.EventStateKeyNID == types.EmptyStateKeyNID { + eventNIDs = append(eventNIDs, entry.EventNID) + break + } + } + + // Get all of the events in this state + stateEvents, err := db.Events(ctx, eventNIDs) + if err != nil { + return nil, err + } + events := make([]gomatrixserverlib.Event, len(stateEvents)) + for i := range stateEvents { + events[i] = stateEvents[i].Event + } + visibility := auth.HistoryVisibilityForRoom(events) + if visibility != "shared" { + logrus.Infof("ServersAtEvent history visibility not shared: %s", visibility) + return nil, nil + } + // get joined members + roomNID, err := db.RoomNID(ctx, roomID) + if err != nil { + return nil, err + } + joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomNID, true, false) + if err != nil { + return nil, err + } + return db.Events(ctx, joinEventNIDs) +} diff --git a/roomserver/internal/perform_invite.go b/roomserver/internal/perform_invite.go new file mode 100644 index 000000000..c65c87f91 --- /dev/null +++ b/roomserver/internal/perform_invite.go @@ -0,0 +1,249 @@ +package internal + +import ( + "context" + "errors" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" +) + +// PerformInvite handles inviting to matrix rooms, including over federation by talking to the federationsender. +func (r *RoomserverInternalAPI) PerformInvite( + ctx context.Context, + req *api.PerformInviteRequest, + res *api.PerformInviteResponse, +) { + err := r.performInvite(ctx, req) + if err != nil { + perr, ok := err.(*api.PerformError) + if ok { + res.Error = perr + } else { + res.Error = &api.PerformError{ + Msg: err.Error(), + } + } + } +} + +func (r *RoomserverInternalAPI) performInvite(ctx context.Context, + req *api.PerformInviteRequest, +) error { + loopback, err := r.processInviteEvent(ctx, r, req) + if err != nil { + return err + } + // The processInviteEvent function can optionally return a + // loopback room event containing the invite, for local invites. + // If it does, we should process it with the room events below. + if loopback != nil { + var loopbackRes api.InputRoomEventsResponse + err := r.InputRoomEvents(ctx, &api.InputRoomEventsRequest{ + InputRoomEvents: []api.InputRoomEvent{*loopback}, + }, &loopbackRes) + if err != nil { + return err + } + } + return nil +} + +func (r *RoomserverInternalAPI) processInviteEvent( + ctx context.Context, + ow *RoomserverInternalAPI, + input *api.PerformInviteRequest, +) (*api.InputRoomEvent, error) { + if input.Event.StateKey() == nil { + return nil, fmt.Errorf("invite must be a state event") + } + + roomID := input.Event.RoomID() + targetUserID := *input.Event.StateKey() + + log.WithFields(log.Fields{ + "event_id": input.Event.EventID(), + "room_id": roomID, + "room_version": input.RoomVersion, + "target_user_id": targetUserID, + }).Info("processing invite event") + + _, domain, _ := gomatrixserverlib.SplitID('@', targetUserID) + isTargetLocalUser := domain == r.Cfg.Matrix.ServerName + + updater, err := r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocalUser, input.RoomVersion) + if err != nil { + return nil, err + } + succeeded := false + defer func() { + txerr := sqlutil.EndTransaction(updater, &succeeded) + if err == nil && txerr != nil { + err = txerr + } + }() + + if updater.IsJoin() { + // If the user is joined to the room then that takes precedence over this + // invite event. It makes little sense to move a user that is already + // joined to the room into the invite state. + // This could plausibly happen if an invite request raced with a join + // request for a user. For example if a user was invited to a public + // room and they joined the room at the same time as the invite was sent. + // The other way this could plausibly happen is if an invite raced with + // a kick. For example if a user was kicked from a room in error and in + // response someone else in the room re-invited them then it is possible + // for the invite request to race with the leave event so that the + // target receives invite before it learns that it has been kicked. + // There are a few ways this could be plausibly handled in the roomserver. + // 1) Store the invite, but mark it as retired. That will result in the + // permanent rejection of that invite event. So even if the target + // user leaves the room and the invite is retransmitted it will be + // ignored. However a new invite with a new event ID would still be + // accepted. + // 2) Silently discard the invite event. This means that if the event + // was retransmitted at a later date after the target user had left + // the room we would accept the invite. However since we hadn't told + // the sending server that the invite had been discarded it would + // have no reason to attempt to retry. + // 3) Signal the sending server that the user is already joined to the + // room. + // For now we will implement option 2. Since in the abesence of a retry + // mechanism it will be equivalent to option 1, and we don't have a + // signalling mechanism to implement option 3. + return nil, &api.PerformError{ + Code: api.PerformErrorNoOperation, + Msg: "user is already joined to room", + } + } + + // Normally, with a federated invite, the federation sender would do + // the /v2/invite request (in which the remote server signs the invite) + // and then the signed event gets sent back to the roomserver as an input + // event. When the invite is local, we don't interact with the federation + // sender therefore we need to generate the loopback invite event for + // the room ourselves. + loopback, err := localInviteLoopback(ow, input) + if err != nil { + return nil, err + } + + event := input.Event.Unwrap() + if len(input.InviteRoomState) > 0 { + // If we were supplied with some invite room state already (which is + // most likely to be if the event came in over federation) then use + // that. + if err = event.SetUnsignedField("invite_room_state", input.InviteRoomState); err != nil { + return nil, err + } + } else { + // There's no invite room state, so let's have a go at building it + // up from local data (which is most likely to be if the event came + // from the CS API). If we know about the room then we can insert + // the invite room state, if we don't then we just fail quietly. + if irs, ierr := buildInviteStrippedState(ctx, r.DB, input); ierr == nil { + if err = event.SetUnsignedField("invite_room_state", irs); err != nil { + return nil, err + } + } + } + + outputUpdates, err := updateToInviteMembership(updater, &event, nil, input.Event.RoomVersion) + if err != nil { + return nil, err + } + + if err = ow.WriteOutputEvents(roomID, outputUpdates); err != nil { + return nil, err + } + + succeeded = true + return loopback, nil +} + +func localInviteLoopback( + ow *RoomserverInternalAPI, + input *api.PerformInviteRequest, +) (ire *api.InputRoomEvent, err error) { + if input.Event.StateKey() == nil { + return nil, errors.New("no state key on invite event") + } + ourServerName := string(ow.Cfg.Matrix.ServerName) + _, theirServerName, err := gomatrixserverlib.SplitID('@', *input.Event.StateKey()) + if err != nil { + return nil, err + } + // Check if the invite originated locally and is destined locally. + if input.Event.Origin() == ow.Cfg.Matrix.ServerName && string(theirServerName) == ourServerName { + rsEvent := input.Event.Sign( + ourServerName, + ow.Cfg.Matrix.KeyID, + ow.Cfg.Matrix.PrivateKey, + ).Headered(input.RoomVersion) + ire = &api.InputRoomEvent{ + Kind: api.KindNew, + Event: rsEvent, + AuthEventIDs: rsEvent.AuthEventIDs(), + SendAsServer: ourServerName, + TransactionID: nil, + } + } + return ire, nil +} + +func buildInviteStrippedState( + ctx context.Context, + db storage.Database, + input *api.PerformInviteRequest, +) ([]gomatrixserverlib.InviteV2StrippedState, error) { + roomNID, err := db.RoomNID(ctx, input.Event.RoomID()) + if err != nil || roomNID == 0 { + return nil, fmt.Errorf("room %q unknown", input.Event.RoomID()) + } + stateWanted := []gomatrixserverlib.StateKeyTuple{} + // "If they are set on the room, at least the state for m.room.avatar, m.room.canonical_alias, m.room.join_rules, and m.room.name SHOULD be included." + // https://matrix.org/docs/spec/client_server/r0.6.0#m-room-member + for _, t := range []string{ + gomatrixserverlib.MRoomName, gomatrixserverlib.MRoomCanonicalAlias, + gomatrixserverlib.MRoomAliases, gomatrixserverlib.MRoomJoinRules, + "m.room.avatar", + } { + stateWanted = append(stateWanted, gomatrixserverlib.StateKeyTuple{ + EventType: t, + StateKey: "", + }) + } + _, currentStateSnapshotNID, _, err := db.LatestEventIDs(ctx, roomNID) + if err != nil { + return nil, err + } + roomState := state.NewStateResolution(db) + stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples( + ctx, currentStateSnapshotNID, stateWanted, + ) + if err != nil { + return nil, err + } + stateNIDs := []types.EventNID{} + for _, stateNID := range stateEntries { + stateNIDs = append(stateNIDs, stateNID.EventNID) + } + stateEvents, err := db.Events(ctx, stateNIDs) + if err != nil { + return nil, err + } + inviteState := []gomatrixserverlib.InviteV2StrippedState{ + gomatrixserverlib.NewInviteV2StrippedState(&input.Event.Event), + } + stateEvents = append(stateEvents, types.Event{Event: input.Event.Unwrap()}) + for _, event := range stateEvents { + inviteState = append(inviteState, gomatrixserverlib.NewInviteV2StrippedState(&event.Event)) + } + return inviteState, nil +} diff --git a/roomserver/internal/perform_join.go b/roomserver/internal/perform_join.go new file mode 100644 index 000000000..d409b6849 --- /dev/null +++ b/roomserver/internal/perform_join.go @@ -0,0 +1,278 @@ +package internal + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + fsAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// PerformJoin handles joining matrix rooms, including over federation by talking to the federationsender. +func (r *RoomserverInternalAPI) PerformJoin( + ctx context.Context, + req *api.PerformJoinRequest, + res *api.PerformJoinResponse, +) { + roomID, err := r.performJoin(ctx, req) + if err != nil { + perr, ok := err.(*api.PerformError) + if ok { + res.Error = perr + } else { + res.Error = &api.PerformError{ + Msg: err.Error(), + } + } + } + res.RoomID = roomID +} + +func (r *RoomserverInternalAPI) performJoin( + ctx context.Context, + req *api.PerformJoinRequest, +) (string, error) { + _, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return "", &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID), + } + } + if domain != r.Cfg.Matrix.ServerName { + return "", &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID), + } + } + if strings.HasPrefix(req.RoomIDOrAlias, "!") { + return r.performJoinRoomByID(ctx, req) + } + if strings.HasPrefix(req.RoomIDOrAlias, "#") { + return r.performJoinRoomByAlias(ctx, req) + } + return "", &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("Room ID or alias %q is invalid", req.RoomIDOrAlias), + } +} + +func (r *RoomserverInternalAPI) performJoinRoomByAlias( + ctx context.Context, + req *api.PerformJoinRequest, +) (string, error) { + // Get the domain part of the room alias. + _, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias) + if err != nil { + return "", fmt.Errorf("Alias %q is not in the correct format", req.RoomIDOrAlias) + } + req.ServerNames = append(req.ServerNames, domain) + + // Check if this alias matches our own server configuration. If it + // doesn't then we'll need to try a federated join. + var roomID string + if domain != r.Cfg.Matrix.ServerName { + // The alias isn't owned by us, so we will need to try joining using + // a remote server. + dirReq := fsAPI.PerformDirectoryLookupRequest{ + RoomAlias: req.RoomIDOrAlias, // the room alias to lookup + ServerName: domain, // the server to ask + } + dirRes := fsAPI.PerformDirectoryLookupResponse{} + err = r.fsAPI.PerformDirectoryLookup(ctx, &dirReq, &dirRes) + if err != nil { + logrus.WithError(err).Errorf("error looking up alias %q", req.RoomIDOrAlias) + return "", fmt.Errorf("Looking up alias %q over federation failed: %w", req.RoomIDOrAlias, err) + } + roomID = dirRes.RoomID + req.ServerNames = append(req.ServerNames, dirRes.ServerNames...) + } else { + // Otherwise, look up if we know this room alias locally. + roomID, err = r.DB.GetRoomIDForAlias(ctx, req.RoomIDOrAlias) + if err != nil { + return "", fmt.Errorf("Lookup room alias %q failed: %w", req.RoomIDOrAlias, err) + } + } + + // If the room ID is empty then we failed to look up the alias. + if roomID == "" { + return "", fmt.Errorf("Alias %q not found", req.RoomIDOrAlias) + } + + // If we do, then pluck out the room ID and continue the join. + req.RoomIDOrAlias = roomID + return r.performJoinRoomByID(ctx, req) +} + +// TODO: Break this function up a bit +// nolint:gocyclo +func (r *RoomserverInternalAPI) performJoinRoomByID( + ctx context.Context, + req *api.PerformJoinRequest, +) (string, error) { + // Get the domain part of the room ID. + _, domain, err := gomatrixserverlib.SplitID('!', req.RoomIDOrAlias) + if err != nil { + return "", &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("Room ID %q is invalid: %s", req.RoomIDOrAlias, err), + } + } + req.ServerNames = append(req.ServerNames, domain) + + // Prepare the template for the join event. + userID := req.UserID + eb := gomatrixserverlib.EventBuilder{ + Type: gomatrixserverlib.MRoomMember, + Sender: userID, + StateKey: &userID, + RoomID: req.RoomIDOrAlias, + Redacts: "", + } + if err = eb.SetUnsigned(struct{}{}); err != nil { + return "", fmt.Errorf("eb.SetUnsigned: %w", err) + } + + // It is possible for the request to include some "content" for the + // event. We'll always overwrite the "membership" key, but the rest, + // like "display_name" or "avatar_url", will be kept if supplied. + if req.Content == nil { + req.Content = map[string]interface{}{} + } + req.Content["membership"] = gomatrixserverlib.Join + if err = eb.SetContent(req.Content); err != nil { + return "", fmt.Errorf("eb.SetContent: %w", err) + } + + // First work out if this is in response to an existing invite + // from a federated server. If it is then we avoid the situation + // where we might think we know about a room in the following + // section but don't know the latest state as all of our users + // have left. + isInvitePending, inviteSender, err := r.isInvitePending(ctx, req.RoomIDOrAlias, req.UserID) + if err == nil && isInvitePending { + // Check if there's an invite pending. + _, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender) + if ierr != nil { + return "", fmt.Errorf("gomatrixserverlib.SplitID: %w", err) + } + + // Check that the domain isn't ours. If it's local then we don't + // need to do anything as our own copy of the room state will be + // up-to-date. + if inviterDomain != r.Cfg.Matrix.ServerName { + // Add the server of the person who invited us to the server list, + // as they should be a fairly good bet. + req.ServerNames = append(req.ServerNames, inviterDomain) + + // Perform a federated room join. + return req.RoomIDOrAlias, r.performFederatedJoinRoomByID(ctx, req) + } + } + + // Try to construct an actual join event from the template. + // If this succeeds then it is a sign that the room already exists + // locally on the homeserver. + // TODO: Check what happens if the room exists on the server + // but everyone has since left. I suspect it does the wrong thing. + buildRes := api.QueryLatestEventsAndStateResponse{} + event, err := eventutil.BuildEvent( + ctx, // the request context + &eb, // the template join event + r.Cfg, // the server configuration + time.Now(), // the event timestamp to use + r, // the roomserver API to use + &buildRes, // the query response + ) + + switch err { + case nil: + // The room join is local. Send the new join event into the + // roomserver. First of all check that the user isn't already + // a member of the room. + alreadyJoined := false + for _, se := range buildRes.StateEvents { + if membership, merr := se.Membership(); merr == nil { + if se.StateKey() != nil && *se.StateKey() == *event.StateKey() { + alreadyJoined = (membership == gomatrixserverlib.Join) + break + } + } + } + + // If we haven't already joined the room then send an event + // into the room changing our membership status. + if !alreadyJoined { + inputReq := api.InputRoomEventsRequest{ + InputRoomEvents: []api.InputRoomEvent{ + { + Kind: api.KindNew, + Event: event.Headered(buildRes.RoomVersion), + AuthEventIDs: event.AuthEventIDs(), + SendAsServer: string(r.Cfg.Matrix.ServerName), + }, + }, + } + inputRes := api.InputRoomEventsResponse{} + if err = r.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { + var notAllowed *gomatrixserverlib.NotAllowed + if errors.As(err, ¬Allowed) { + return "", &api.PerformError{ + Code: api.PerformErrorNotAllowed, + Msg: fmt.Sprintf("InputRoomEvents auth failed: %s", err), + } + } + return "", fmt.Errorf("r.InputRoomEvents: %w", err) + } + } + + case eventutil.ErrRoomNoExists: + // The room doesn't exist. First of all check if the room is a local + // room. If it is then there's nothing more to do - the room just + // hasn't been created yet. + if domain == r.Cfg.Matrix.ServerName { + return "", &api.PerformError{ + Code: api.PerformErrorNoRoom, + Msg: fmt.Sprintf("Room ID %q does not exist", req.RoomIDOrAlias), + } + } + + // Perform a federated room join. + return req.RoomIDOrAlias, r.performFederatedJoinRoomByID(ctx, req) + + default: + // Something else went wrong. + return "", fmt.Errorf("Error joining local room: %q", err) + } + + // By this point, if req.RoomIDOrAlias contained an alias, then + // it will have been overwritten with a room ID by performJoinRoomByAlias. + // We should now include this in the response so that the CS API can + // return the right room ID. + return req.RoomIDOrAlias, nil +} + +func (r *RoomserverInternalAPI) performFederatedJoinRoomByID( + ctx context.Context, + req *api.PerformJoinRequest, +) error { + // Try joining by all of the supplied server names. + fedReq := fsAPI.PerformJoinRequest{ + RoomID: req.RoomIDOrAlias, // the room ID to try and join + UserID: req.UserID, // the user ID joining the room + ServerNames: req.ServerNames, // the server to try joining with + Content: req.Content, // the membership event content + } + fedRes := fsAPI.PerformJoinResponse{} + if err := r.fsAPI.PerformJoin(ctx, &fedReq, &fedRes); err != nil { + return fmt.Errorf("Error joining federated room: %q", err) + } + + return nil +} diff --git a/roomserver/internal/perform_leave.go b/roomserver/internal/perform_leave.go new file mode 100644 index 000000000..880c8b203 --- /dev/null +++ b/roomserver/internal/perform_leave.go @@ -0,0 +1,207 @@ +package internal + +import ( + "context" + "fmt" + "strings" + "time" + + fsAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrixserverlib" +) + +// WriteOutputEvents implements OutputRoomEventWriter +func (r *RoomserverInternalAPI) PerformLeave( + ctx context.Context, + req *api.PerformLeaveRequest, + res *api.PerformLeaveResponse, +) error { + _, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return fmt.Errorf("Supplied user ID %q in incorrect format", req.UserID) + } + if domain != r.Cfg.Matrix.ServerName { + return fmt.Errorf("User %q does not belong to this homeserver", req.UserID) + } + if strings.HasPrefix(req.RoomID, "!") { + return r.performLeaveRoomByID(ctx, req, res) + } + return fmt.Errorf("Room ID %q is invalid", req.RoomID) +} + +func (r *RoomserverInternalAPI) performLeaveRoomByID( + ctx context.Context, + req *api.PerformLeaveRequest, + res *api.PerformLeaveResponse, // nolint:unparam +) error { + // If there's an invite outstanding for the room then respond to + // that. + isInvitePending, senderUser, err := r.isInvitePending(ctx, req.RoomID, req.UserID) + if err == nil && isInvitePending { + return r.performRejectInvite(ctx, req, res, senderUser) + } + + // There's no invite pending, so first of all we want to find out + // if the room exists and if the user is actually in it. + latestReq := api.QueryLatestEventsAndStateRequest{ + RoomID: req.RoomID, + StateToFetch: []gomatrixserverlib.StateKeyTuple{ + { + EventType: gomatrixserverlib.MRoomMember, + StateKey: req.UserID, + }, + }, + } + latestRes := api.QueryLatestEventsAndStateResponse{} + if err = r.QueryLatestEventsAndState(ctx, &latestReq, &latestRes); err != nil { + return err + } + if !latestRes.RoomExists { + return fmt.Errorf("Room %q does not exist", req.RoomID) + } + + // Now let's see if the user is in the room. + if len(latestRes.StateEvents) == 0 { + return fmt.Errorf("User %q is not a member of room %q", req.UserID, req.RoomID) + } + membership, err := latestRes.StateEvents[0].Membership() + if err != nil { + return fmt.Errorf("Error getting membership: %w", err) + } + if membership != gomatrixserverlib.Join { + // TODO: should be able to handle "invite" in this case too, if + // it's a case of kicking or banning or such + return fmt.Errorf("User %q is not joined to the room (membership is %q)", req.UserID, membership) + } + + // Prepare the template for the leave event. + userID := req.UserID + eb := gomatrixserverlib.EventBuilder{ + Type: gomatrixserverlib.MRoomMember, + Sender: userID, + StateKey: &userID, + RoomID: req.RoomID, + Redacts: "", + } + if err = eb.SetContent(map[string]interface{}{"membership": "leave"}); err != nil { + return fmt.Errorf("eb.SetContent: %w", err) + } + if err = eb.SetUnsigned(struct{}{}); err != nil { + return fmt.Errorf("eb.SetUnsigned: %w", err) + } + + // We know that the user is in the room at this point so let's build + // a leave event. + // TODO: Check what happens if the room exists on the server + // but everyone has since left. I suspect it does the wrong thing. + buildRes := api.QueryLatestEventsAndStateResponse{} + event, err := eventutil.BuildEvent( + ctx, // the request context + &eb, // the template leave event + r.Cfg, // the server configuration + time.Now(), // the event timestamp to use + r, // the roomserver API to use + &buildRes, // the query response + ) + if err != nil { + return fmt.Errorf("eventutil.BuildEvent: %w", err) + } + + // Give our leave event to the roomserver input stream. The + // roomserver will process the membership change and notify + // downstream automatically. + inputReq := api.InputRoomEventsRequest{ + InputRoomEvents: []api.InputRoomEvent{ + { + Kind: api.KindNew, + Event: event.Headered(buildRes.RoomVersion), + AuthEventIDs: event.AuthEventIDs(), + SendAsServer: string(r.Cfg.Matrix.ServerName), + }, + }, + } + inputRes := api.InputRoomEventsResponse{} + if err = r.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { + return fmt.Errorf("r.InputRoomEvents: %w", err) + } + + return nil +} + +func (r *RoomserverInternalAPI) performRejectInvite( + ctx context.Context, + req *api.PerformLeaveRequest, + res *api.PerformLeaveResponse, // nolint:unparam + senderUser string, +) error { + _, domain, err := gomatrixserverlib.SplitID('@', senderUser) + if err != nil { + return fmt.Errorf("User ID %q invalid: %w", senderUser, err) + } + + // Ask the federation sender to perform a federated leave for us. + leaveReq := fsAPI.PerformLeaveRequest{ + RoomID: req.RoomID, + UserID: req.UserID, + ServerNames: []gomatrixserverlib.ServerName{domain}, + } + leaveRes := fsAPI.PerformLeaveResponse{} + if err := r.fsAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil { + return err + } + + // TODO: Withdraw the invite, so that the sync API etc are + // notified that we rejected it. + + return nil +} + +func (r *RoomserverInternalAPI) isInvitePending( + ctx context.Context, + roomID, userID string, +) (bool, string, error) { + // Look up the room NID for the supplied room ID. + roomNID, err := r.DB.RoomNID(ctx, roomID) + if err != nil { + return false, "", fmt.Errorf("r.DB.RoomNID: %w", err) + } + + // Look up the state key NID for the supplied user ID. + targetUserNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{userID}) + if err != nil { + return false, "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err) + } + targetUserNID, targetUserFound := targetUserNIDs[userID] + if !targetUserFound { + return false, "", fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs) + } + + // Let's see if we have an event active for the user in the room. If + // we do then it will contain a server name that we can direct the + // send_leave to. + senderUserNIDs, err := r.DB.GetInvitesForUser(ctx, roomNID, targetUserNID) + if err != nil { + return false, "", fmt.Errorf("r.DB.GetInvitesForUser: %w", err) + } + if len(senderUserNIDs) == 0 { + return false, "", nil + } + + // Look up the user ID from the NID. + senderUsers, err := r.DB.EventStateKeys(ctx, senderUserNIDs) + if err != nil { + return false, "", fmt.Errorf("r.DB.EventStateKeys: %w", err) + } + if len(senderUsers) == 0 { + return false, "", fmt.Errorf("no senderUsers") + } + + senderUser, senderUserFound := senderUsers[senderUserNIDs[0]] + if !senderUserFound { + return false, "", fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers) + } + + return true, senderUser, nil +} diff --git a/roomserver/query/query.go b/roomserver/internal/query.go similarity index 53% rename from roomserver/query/query.go rename to roomserver/internal/query.go index 224d9fa22..4fc8e4c25 100644 --- a/roomserver/query/query.go +++ b/roomserver/internal/query.go @@ -14,96 +14,25 @@ // See the License for the specific language governing permissions and // limitations under the License. -package query +package internal import ( "context" - "encoding/json" - "net/http" + "fmt" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/caching" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/auth" "github.com/matrix-org/dendrite/roomserver/state" - "github.com/matrix-org/dendrite/roomserver/state/database" + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" ) -// RoomserverQueryAPIEventDB has a convenience API to fetch events directly by -// EventIDs. -type RoomserverQueryAPIEventDB interface { - // Look up the Events for a list of event IDs. Does not error if event was - // not found. - // Returns an error if the retrieval went wrong. - EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) -} - -// RoomserverQueryAPIDatabase has the storage APIs needed to implement the query API. -type RoomserverQueryAPIDatabase interface { - database.RoomStateDatabase - RoomserverQueryAPIEventDB - // Look up the numeric ID for the room. - // Returns 0 if the room doesn't exists. - // Returns an error if there was a problem talking to the database. - RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) - // Look up event references for the latest events in the room and the current state snapshot. - // Returns the latest events, the current state and the maximum depth of the latest events plus 1. - // Returns an error if there was a problem talking to the database. - LatestEventIDs( - ctx context.Context, roomNID types.RoomNID, - ) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) - // Look up the numeric IDs for a list of events. - // Returns an error if there was a problem talking to the database. - EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) - // Lookup the event IDs for a batch of event numeric IDs. - // Returns an error if the retrieval went wrong. - EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) - // Lookup the membership of a given user in a given room. - // Returns the numeric ID of the latest membership event sent from this user - // in this room, along a boolean set to true if the user is still in this room, - // false if not. - // Returns an error if there was a problem talking to the database. - GetMembership( - ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, - ) (membershipEventNID types.EventNID, stillInRoom bool, err error) - // Lookup the membership event numeric IDs for all user that are or have - // been members of a given room. Only lookup events of "join" membership if - // joinOnly is set to true. - // Returns an error if there was a problem talking to the database. - GetMembershipEventNIDsForRoom( - ctx context.Context, roomNID types.RoomNID, joinOnly bool, - ) ([]types.EventNID, error) - // Look up the active invites targeting a user in a room and return the - // numeric state key IDs for the user IDs who sent them. - // Returns an error if there was a problem talking to the database. - GetInvitesForUser( - ctx context.Context, - roomNID types.RoomNID, - targetUserNID types.EventStateKeyNID, - ) (senderUserNIDs []types.EventStateKeyNID, err error) - // Look up the string event state keys for a list of numeric event state keys - // Returns an error if there was a problem talking to the database. - EventStateKeys( - context.Context, []types.EventStateKeyNID, - ) (map[types.EventStateKeyNID]string, error) - // Look up the room version for a given room. - GetRoomVersionForRoom( - ctx context.Context, roomID string, - ) (gomatrixserverlib.RoomVersion, error) -} - -// RoomserverQueryAPI is an implementation of api.RoomserverQueryAPI -type RoomserverQueryAPI struct { - DB RoomserverQueryAPIDatabase - ImmutableCache caching.ImmutableCache -} - -// QueryLatestEventsAndState implements api.RoomserverQueryAPI -func (r *RoomserverQueryAPI) QueryLatestEventsAndState( +// QueryLatestEventsAndState implements api.RoomserverInternalAPI +func (r *RoomserverInternalAPI) QueryLatestEventsAndState( ctx context.Context, request *api.QueryLatestEventsAndStateRequest, response *api.QueryLatestEventsAndStateResponse, @@ -116,8 +45,7 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState( roomState := state.NewStateResolution(r.DB) - response.QueryLatestEventsAndStateRequest = *request - roomNID, err := r.DB.RoomNID(ctx, request.RoomID) + roomNID, err := r.DB.RoomNIDExcludingStubs(ctx, request.RoomID) if err != nil { return err } @@ -162,8 +90,8 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState( return nil } -// QueryStateAfterEvents implements api.RoomserverQueryAPI -func (r *RoomserverQueryAPI) QueryStateAfterEvents( +// QueryStateAfterEvents implements api.RoomserverInternalAPI +func (r *RoomserverInternalAPI) QueryStateAfterEvents( ctx context.Context, request *api.QueryStateAfterEventsRequest, response *api.QueryStateAfterEventsResponse, @@ -176,8 +104,7 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents( roomState := state.NewStateResolution(r.DB) - response.QueryStateAfterEventsRequest = *request - roomNID, err := r.DB.RoomNID(ctx, request.RoomID) + roomNID, err := r.DB.RoomNIDExcludingStubs(ctx, request.RoomID) if err != nil { return err } @@ -218,14 +145,12 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents( return nil } -// QueryEventsByID implements api.RoomserverQueryAPI -func (r *RoomserverQueryAPI) QueryEventsByID( +// QueryEventsByID implements api.RoomserverInternalAPI +func (r *RoomserverInternalAPI) QueryEventsByID( ctx context.Context, request *api.QueryEventsByIDRequest, response *api.QueryEventsByIDResponse, ) error { - response.QueryEventsByIDRequest = *request - eventNIDMap, err := r.DB.EventNIDs(ctx, request.EventIDs) if err != nil { return err @@ -253,7 +178,7 @@ func (r *RoomserverQueryAPI) QueryEventsByID( return nil } -func (r *RoomserverQueryAPI) loadStateEvents( +func (r *RoomserverInternalAPI) loadStateEvents( ctx context.Context, stateEntries []types.StateEntry, ) ([]gomatrixserverlib.Event, error) { eventNIDs := make([]types.EventNID, len(stateEntries)) @@ -263,7 +188,7 @@ func (r *RoomserverQueryAPI) loadStateEvents( return r.loadEvents(ctx, eventNIDs) } -func (r *RoomserverQueryAPI) loadEvents( +func (r *RoomserverInternalAPI) loadEvents( ctx context.Context, eventNIDs []types.EventNID, ) ([]gomatrixserverlib.Event, error) { stateEvents, err := r.DB.Events(ctx, eventNIDs) @@ -278,8 +203,8 @@ func (r *RoomserverQueryAPI) loadEvents( return result, nil } -// QueryMembershipForUser implements api.RoomserverQueryAPI -func (r *RoomserverQueryAPI) QueryMembershipForUser( +// QueryMembershipForUser implements api.RoomserverInternalAPI +func (r *RoomserverInternalAPI) QueryMembershipForUser( ctx context.Context, request *api.QueryMembershipForUserRequest, response *api.QueryMembershipForUserResponse, @@ -309,8 +234,8 @@ func (r *RoomserverQueryAPI) QueryMembershipForUser( return nil } -// QueryMembershipsForRoom implements api.RoomserverQueryAPI -func (r *RoomserverQueryAPI) QueryMembershipsForRoom( +// QueryMembershipsForRoom implements api.RoomserverInternalAPI +func (r *RoomserverInternalAPI) QueryMembershipsForRoom( ctx context.Context, request *api.QueryMembershipsForRoomRequest, response *api.QueryMembershipsForRoomResponse, @@ -335,16 +260,22 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom( response.JoinEvents = []gomatrixserverlib.ClientEvent{} var events []types.Event + var stateEntries []types.StateEntry if stillInRoom { var eventNIDs []types.EventNID - eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly) + eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly, false) if err != nil { return err } events, err = r.DB.Events(ctx, eventNIDs) } else { - events, err = r.getMembershipsBeforeEventNID(ctx, membershipEventNID, request.JoinedOnly) + stateEntries, err = stateBeforeEvent(ctx, r.DB, membershipEventNID) + if err != nil { + logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") + return err + } + events, err = getMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly) } if err != nil { @@ -359,32 +290,30 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom( return nil } -// getMembershipsBeforeEventNID takes the numeric ID of an event and fetches the state -// of the event's room as it was when this event was fired, then filters the state events to -// only keep the "m.room.member" events with a "join" membership. These events are returned. -// Returns an error if there was an issue fetching the events. -func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID( - ctx context.Context, eventNID types.EventNID, joinedOnly bool, -) ([]types.Event, error) { - roomState := state.NewStateResolution(r.DB) - events := []types.Event{} +func stateBeforeEvent(ctx context.Context, db storage.Database, eventNID types.EventNID) ([]types.StateEntry, error) { + roomState := state.NewStateResolution(db) // Lookup the event NID - eIDs, err := r.DB.EventIDs(ctx, []types.EventNID{eventNID}) + eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID}) if err != nil { return nil, err } eventIDs := []string{eIDs[eventNID]} - prevState, err := r.DB.StateAtEventIDs(ctx, eventIDs) + prevState, err := db.StateAtEventIDs(ctx, eventIDs) if err != nil { return nil, err } // Fetch the state as it was when this event was fired - stateEntries, err := roomState.LoadCombinedStateAfterEvents(ctx, prevState) - if err != nil { - return nil, err - } + return roomState.LoadCombinedStateAfterEvents(ctx, prevState) +} + +// getMembershipsAtState filters the state events to +// only keep the "m.room.member" events with a "join" membership. These events are returned. +// Returns an error if there was an issue fetching the events. +func getMembershipsAtState( + ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool, +) ([]types.Event, error) { var eventNIDs []types.EventNID for _, entry := range stateEntries { @@ -395,7 +324,7 @@ func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID( } // Get all of the events in this state - stateEvents, err := r.DB.Events(ctx, eventNIDs) + stateEvents, err := db.Events(ctx, eventNIDs) if err != nil { return nil, err } @@ -405,6 +334,7 @@ func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID( } // Filter the events to only keep the "join" membership events + var events []types.Event for _, event := range stateEvents { membership, err := event.Membership() if err != nil { @@ -419,42 +349,8 @@ func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID( return events, nil } -// QueryInvitesForUser implements api.RoomserverQueryAPI -func (r *RoomserverQueryAPI) QueryInvitesForUser( - ctx context.Context, - request *api.QueryInvitesForUserRequest, - response *api.QueryInvitesForUserResponse, -) error { - roomNID, err := r.DB.RoomNID(ctx, request.RoomID) - if err != nil { - return err - } - - targetUserNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.TargetUserID}) - if err != nil { - return err - } - targetUserNID := targetUserNIDs[request.TargetUserID] - - senderUserNIDs, err := r.DB.GetInvitesForUser(ctx, roomNID, targetUserNID) - if err != nil { - return err - } - - senderUserIDs, err := r.DB.EventStateKeys(ctx, senderUserNIDs) - if err != nil { - return err - } - - for _, senderUserID := range senderUserIDs { - response.InviteSenderUserIDs = append(response.InviteSenderUserIDs, senderUserID) - } - - return nil -} - -// QueryServerAllowedToSeeEvent implements api.RoomserverQueryAPI -func (r *RoomserverQueryAPI) QueryServerAllowedToSeeEvent( +// QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI +func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent( ctx context.Context, request *api.QueryServerAllowedToSeeEventRequest, response *api.QueryServerAllowedToSeeEventResponse, @@ -477,7 +373,7 @@ func (r *RoomserverQueryAPI) QueryServerAllowedToSeeEvent( return } -func (r *RoomserverQueryAPI) checkServerAllowedToSeeEvent( +func (r *RoomserverInternalAPI) checkServerAllowedToSeeEvent( ctx context.Context, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, ) (bool, error) { roomState := state.NewStateResolution(r.DB) @@ -496,8 +392,8 @@ func (r *RoomserverQueryAPI) checkServerAllowedToSeeEvent( return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil } -// QueryMissingEvents implements api.RoomserverQueryAPI -func (r *RoomserverQueryAPI) QueryMissingEvents( +// QueryMissingEvents implements api.RoomserverInternalAPI +func (r *RoomserverInternalAPI) QueryMissingEvents( ctx context.Context, request *api.QueryMissingEventsRequest, response *api.QueryMissingEventsResponse, @@ -541,12 +437,19 @@ func (r *RoomserverQueryAPI) QueryMissingEvents( return err } -// QueryBackfill implements api.RoomServerQueryAPI -func (r *RoomserverQueryAPI) QueryBackfill( +// PerformBackfill implements api.RoomServerQueryAPI +func (r *RoomserverInternalAPI) PerformBackfill( ctx context.Context, - request *api.QueryBackfillRequest, - response *api.QueryBackfillResponse, + request *api.PerformBackfillRequest, + response *api.PerformBackfillResponse, ) error { + // if we are requesting the backfill then we need to do a federation hit + // TODO: we could be more sensible and fetch as many events we already have then request the rest + // which is what the syncapi does already. + if request.ServerName == r.ServerName { + return r.backfillViaFederation(ctx, request, response) + } + // someone else is requesting the backfill, try to service their request. var err error var front []string @@ -554,14 +457,8 @@ func (r *RoomserverQueryAPI) QueryBackfill( // defines the highest number of elements in the map below. visited := make(map[string]bool, request.Limit) - // The provided event IDs have already been seen by the request's emitter, - // and will be retrieved anyway, so there's no need to care about them if - // they appear in our exploration of the event tree. - for _, id := range request.EarliestEventsIDs { - visited[id] = true - } - - front = request.EarliestEventsIDs + // this will include these events which is what we want + front = request.PrevEventIDs() // Scan the event tree for events to send back. resultNIDs, err := r.scanEventTree(ctx, front, visited, request.Limit, request.ServerName) @@ -588,13 +485,75 @@ func (r *RoomserverQueryAPI) QueryBackfill( return err } -func (r *RoomserverQueryAPI) isServerCurrentlyInRoom(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) { +func (r *RoomserverInternalAPI) backfillViaFederation(ctx context.Context, req *api.PerformBackfillRequest, res *api.PerformBackfillResponse) error { + roomVer, err := r.DB.GetRoomVersionForRoom(ctx, req.RoomID) + if err != nil { + return fmt.Errorf("backfillViaFederation: unknown room version for room %s : %w", req.RoomID, err) + } + requester := newBackfillRequester(r.DB, r.FedClient, r.ServerName, req.BackwardsExtremities) + // Request 100 items regardless of what the query asks for. + // We don't want to go much higher than this. + // We can't honour exactly the limit as some sytests rely on requesting more for tests to pass + // (so we don't need to hit /state_ids which the test has no listener for) + // Specifically the test "Outbound federation can backfill events" + events, err := gomatrixserverlib.RequestBackfill( + ctx, requester, + r.KeyRing, req.RoomID, roomVer, req.PrevEventIDs(), 100) + if err != nil { + return err + } + logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) + + // persist these new events - auth checks have already been done + roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events) + if err != nil { + return err + } + + for _, ev := range backfilledEventMap { + // now add state for these events + stateIDs, ok := requester.eventIDToBeforeStateIDs[ev.EventID()] + if !ok { + // this should be impossible as all events returned must have pass Step 5 of the PDU checks + // which requires a list of state IDs. + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to find state IDs for event which passed auth checks") + continue + } + var entries []types.StateEntry + if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs); err != nil { + // attempt to fetch the missing events + r.fetchAndStoreMissingEvents(ctx, roomVer, requester, stateIDs) + // try again + entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs) + if err != nil { + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to get state entries for event") + return err + } + } + + var beforeStateSnapshotNID types.StateSnapshotNID + if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil { + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist state entries to get snapshot nid") + return err + } + if err = r.DB.SetState(ctx, ev.EventNID, beforeStateSnapshotNID); err != nil { + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist snapshot nid") + } + } + + // TODO: update backwards extremities, as that should be moved from syncapi to roomserver at some point. + + res.Events = events + return nil +} + +func (r *RoomserverInternalAPI) isServerCurrentlyInRoom(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) { roomNID, err := r.DB.RoomNID(ctx, roomID) if err != nil { return false, err } - eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true) + eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true, false) if err != nil { return false, err } @@ -610,9 +569,69 @@ func (r *RoomserverQueryAPI) isServerCurrentlyInRoom(ctx context.Context, server return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil } +// fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just +// best effort. +func (r *RoomserverInternalAPI) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, + backfillRequester *backfillRequester, stateIDs []string) { + + servers := backfillRequester.servers + + // work out which are missing + nidMap, err := r.DB.EventNIDs(ctx, stateIDs) + if err != nil { + util.GetLogger(ctx).WithError(err).Warn("cannot query missing events") + return + } + missingMap := make(map[string]*gomatrixserverlib.HeaderedEvent) // id -> event + for _, id := range stateIDs { + if _, ok := nidMap[id]; !ok { + missingMap[id] = nil + } + } + util.GetLogger(ctx).Infof("Fetching %d missing state events (from %d possible servers)", len(missingMap), len(servers)) + + // fetch the events from federation. Loop the servers first so if we find one that works we stick with them + for _, srv := range servers { + for id, ev := range missingMap { + if ev != nil { + continue // already found + } + logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id) + res, err := r.FedClient.GetEvent(ctx, srv, id) + if err != nil { + logger.WithError(err).Warn("failed to get event from server") + continue + } + loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false) + result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents) + if err != nil { + logger.WithError(err).Warn("failed to load and verify event") + continue + } + logger.Infof("returned %d PDUs which made events %+v", len(res.PDUs), result) + for _, res := range result { + if res.Error != nil { + logger.WithError(err).Warn("event failed PDU checks") + continue + } + missingMap[id] = res.Event + } + } + } + + var newEvents []gomatrixserverlib.HeaderedEvent + for _, ev := range missingMap { + if ev != nil { + newEvents = append(newEvents, *ev) + } + } + util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents)) + persistEvents(ctx, r.DB, newEvents) +} + // TODO: Remove this when we have tests to assert correctness of this function // nolint:gocyclo -func (r *RoomserverQueryAPI) scanEventTree( +func (r *RoomserverInternalAPI) scanEventTree( ctx context.Context, front []string, visited map[string]bool, limit int, serverName gomatrixserverlib.ServerName, ) ([]types.EventNID, error) { @@ -624,7 +643,7 @@ func (r *RoomserverQueryAPI) scanEventTree( var pre string // TODO: add tests for this function to ensure it meets the contract that callers expect (and doc what that is supposed to be) - // Currently, callers like QueryBackfill will call scanEventTree with a pre-populated `visited` map, assuming that by doing + // Currently, callers like PerformBackfill will call scanEventTree with a pre-populated `visited` map, assuming that by doing // so means that the events in that map will NOT be returned from this function. That is not currently true, resulting in // duplicate events being sent in response to /backfill requests. initialIgnoreList := make(map[string]bool, len(visited)) @@ -705,14 +724,13 @@ BFSLoop: return resultNIDs, err } -// QueryStateAndAuthChain implements api.RoomserverQueryAPI -func (r *RoomserverQueryAPI) QueryStateAndAuthChain( +// QueryStateAndAuthChain implements api.RoomserverInternalAPI +func (r *RoomserverInternalAPI) QueryStateAndAuthChain( ctx context.Context, request *api.QueryStateAndAuthChainRequest, response *api.QueryStateAndAuthChainResponse, ) error { - response.QueryStateAndAuthChainRequest = *request - roomNID, err := r.DB.RoomNID(ctx, request.RoomID) + roomNID, err := r.DB.RoomNIDExcludingStubs(ctx, request.RoomID) if err != nil { return err } @@ -741,7 +759,7 @@ func (r *RoomserverQueryAPI) QueryStateAndAuthChain( } authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe - authEvents, err := getAuthChain(ctx, r.DB, authEventIDs) + authEvents, err := getAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs) if err != nil { return err } @@ -765,7 +783,7 @@ func (r *RoomserverQueryAPI) QueryStateAndAuthChain( return err } -func (r *RoomserverQueryAPI) loadStateAtEventIDs(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) { +func (r *RoomserverInternalAPI) loadStateAtEventIDs(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) { roomState := state.NewStateResolution(r.DB) prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) if err != nil { @@ -788,12 +806,14 @@ func (r *RoomserverQueryAPI) loadStateAtEventIDs(ctx context.Context, eventIDs [ return r.loadStateEvents(ctx, stateEntries) } +type eventsFromIDs func(context.Context, []string) ([]types.Event, error) + // getAuthChain fetches the auth chain for the given auth events. An auth chain // is the list of all events that are referenced in the auth_events section, and // all their auth_events, recursively. The returned set of events contain the // given events. Will *not* error if we don't have all auth events. func getAuthChain( - ctx context.Context, dB RoomserverQueryAPIEventDB, authEventIDs []string, + ctx context.Context, fn eventsFromIDs, authEventIDs []string, ) ([]gomatrixserverlib.Event, error) { // List of event IDs to fetch. On each pass, these events will be requested // from the database and the `eventsToFetch` will be updated with any new @@ -804,7 +824,7 @@ func getAuthChain( for len(eventsToFetch) > 0 { // Try to retrieve the events from the database. - events, err := dB.EventsFromIDs(ctx, eventsToFetch) + events, err := fn(ctx, eventsToFetch) if err != nil { return nil, err } @@ -839,43 +859,37 @@ func getAuthChain( return authEvents, nil } -// QueryServersInRoomAtEvent implements api.RoomserverQueryAPI -func (r *RoomserverQueryAPI) QueryServersInRoomAtEvent( - ctx context.Context, - request *api.QueryServersInRoomAtEventRequest, - response *api.QueryServersInRoomAtEventResponse, -) error { - // getMembershipsBeforeEventNID requires a NID, so retrieving the NID for - // the event is necessary. - NIDs, err := r.DB.EventNIDs(ctx, []string{request.EventID}) - if err != nil { - return err +func persistEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) { + var roomNID types.RoomNID + backfilledEventMap := make(map[string]types.Event) + for _, ev := range events { + nidMap, err := db.EventNIDs(ctx, ev.AuthEventIDs()) + if err != nil { // this shouldn't happen as RequestBackfill already found them + logrus.WithError(err).WithField("auth_events", ev.AuthEventIDs()).Error("Failed to find one or more auth events") + continue + } + authNids := make([]types.EventNID, len(nidMap)) + i := 0 + for _, nid := range nidMap { + authNids[i] = nid + i++ + } + var stateAtEvent types.StateAtEvent + roomNID, stateAtEvent, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids) + if err != nil { + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event") + continue + } + backfilledEventMap[ev.EventID()] = types.Event{ + EventNID: stateAtEvent.StateEntry.EventNID, + Event: ev.Unwrap(), + } } - - // Retrieve all "m.room.member" state events of "join" membership, which - // contains the list of users in the room before the event, therefore all - // the servers in it at that moment. - events, err := r.getMembershipsBeforeEventNID(ctx, NIDs[request.EventID], true) - if err != nil { - return err - } - - // Store the server names in a temporary map to avoid duplicates. - servers := make(map[gomatrixserverlib.ServerName]bool) - for _, event := range events { - servers[event.Origin()] = true - } - - // Populate the response. - for server := range servers { - response.Servers = append(response.Servers, server) - } - - return nil + return roomNID, backfilledEventMap } -// QueryRoomVersionCapabilities implements api.RoomserverQueryAPI -func (r *RoomserverQueryAPI) QueryRoomVersionCapabilities( +// QueryRoomVersionCapabilities implements api.RoomserverInternalAPI +func (r *RoomserverInternalAPI) QueryRoomVersionCapabilities( ctx context.Context, request *api.QueryRoomVersionCapabilitiesRequest, response *api.QueryRoomVersionCapabilitiesResponse, @@ -892,13 +906,13 @@ func (r *RoomserverQueryAPI) QueryRoomVersionCapabilities( return nil } -// QueryRoomVersionCapabilities implements api.RoomserverQueryAPI -func (r *RoomserverQueryAPI) QueryRoomVersionForRoom( +// QueryRoomVersionCapabilities implements api.RoomserverInternalAPI +func (r *RoomserverInternalAPI) QueryRoomVersionForRoom( ctx context.Context, request *api.QueryRoomVersionForRoomRequest, response *api.QueryRoomVersionForRoomResponse, ) error { - if roomVersion, ok := r.ImmutableCache.GetRoomVersion(request.RoomID); ok { + if roomVersion, ok := r.Cache.GetRoomVersion(request.RoomID); ok { response.RoomVersion = roomVersion return nil } @@ -908,193 +922,6 @@ func (r *RoomserverQueryAPI) QueryRoomVersionForRoom( return err } response.RoomVersion = roomVersion - r.ImmutableCache.StoreRoomVersion(request.RoomID, response.RoomVersion) + r.Cache.StoreRoomVersion(request.RoomID, response.RoomVersion) return nil } - -// SetupHTTP adds the RoomserverQueryAPI handlers to the http.ServeMux. -// nolint: gocyclo -func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) { - servMux.Handle( - api.RoomserverQueryLatestEventsAndStatePath, - common.MakeInternalAPI("queryLatestEventsAndState", func(req *http.Request) util.JSONResponse { - var request api.QueryLatestEventsAndStateRequest - var response api.QueryLatestEventsAndStateResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryLatestEventsAndState(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - servMux.Handle( - api.RoomserverQueryStateAfterEventsPath, - common.MakeInternalAPI("queryStateAfterEvents", func(req *http.Request) util.JSONResponse { - var request api.QueryStateAfterEventsRequest - var response api.QueryStateAfterEventsResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryStateAfterEvents(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - servMux.Handle( - api.RoomserverQueryEventsByIDPath, - common.MakeInternalAPI("queryEventsByID", func(req *http.Request) util.JSONResponse { - var request api.QueryEventsByIDRequest - var response api.QueryEventsByIDResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryEventsByID(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - servMux.Handle( - api.RoomserverQueryMembershipForUserPath, - common.MakeInternalAPI("QueryMembershipForUser", func(req *http.Request) util.JSONResponse { - var request api.QueryMembershipForUserRequest - var response api.QueryMembershipForUserResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryMembershipForUser(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - servMux.Handle( - api.RoomserverQueryMembershipsForRoomPath, - common.MakeInternalAPI("queryMembershipsForRoom", func(req *http.Request) util.JSONResponse { - var request api.QueryMembershipsForRoomRequest - var response api.QueryMembershipsForRoomResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryMembershipsForRoom(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - servMux.Handle( - api.RoomserverQueryInvitesForUserPath, - common.MakeInternalAPI("queryInvitesForUser", func(req *http.Request) util.JSONResponse { - var request api.QueryInvitesForUserRequest - var response api.QueryInvitesForUserResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryInvitesForUser(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - servMux.Handle( - api.RoomserverQueryServerAllowedToSeeEventPath, - common.MakeInternalAPI("queryServerAllowedToSeeEvent", func(req *http.Request) util.JSONResponse { - var request api.QueryServerAllowedToSeeEventRequest - var response api.QueryServerAllowedToSeeEventResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryServerAllowedToSeeEvent(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - servMux.Handle( - api.RoomserverQueryMissingEventsPath, - common.MakeInternalAPI("queryMissingEvents", func(req *http.Request) util.JSONResponse { - var request api.QueryMissingEventsRequest - var response api.QueryMissingEventsResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryMissingEvents(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - servMux.Handle( - api.RoomserverQueryStateAndAuthChainPath, - common.MakeInternalAPI("queryStateAndAuthChain", func(req *http.Request) util.JSONResponse { - var request api.QueryStateAndAuthChainRequest - var response api.QueryStateAndAuthChainResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryStateAndAuthChain(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - servMux.Handle( - api.RoomserverQueryBackfillPath, - common.MakeInternalAPI("QueryBackfill", func(req *http.Request) util.JSONResponse { - var request api.QueryBackfillRequest - var response api.QueryBackfillResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryBackfill(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - servMux.Handle( - api.RoomserverQueryServersInRoomAtEventPath, - common.MakeInternalAPI("QueryServersInRoomAtEvent", func(req *http.Request) util.JSONResponse { - var request api.QueryServersInRoomAtEventRequest - var response api.QueryServersInRoomAtEventResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryServersInRoomAtEvent(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - servMux.Handle( - api.RoomserverQueryRoomVersionCapabilitiesPath, - common.MakeInternalAPI("QueryRoomVersionCapabilities", func(req *http.Request) util.JSONResponse { - var request api.QueryRoomVersionCapabilitiesRequest - var response api.QueryRoomVersionCapabilitiesResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryRoomVersionCapabilities(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - servMux.Handle( - api.RoomserverQueryRoomVersionForRoomPath, - common.MakeInternalAPI("QueryRoomVersionForRoom", func(req *http.Request) util.JSONResponse { - var request api.QueryRoomVersionForRoomRequest - var response api.QueryRoomVersionForRoomResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryRoomVersionForRoom(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) -} diff --git a/roomserver/query/query_test.go b/roomserver/internal/query_test.go similarity index 90% rename from roomserver/query/query_test.go rename to roomserver/internal/query_test.go index 7e040c6fb..92e008324 100644 --- a/roomserver/query/query_test.go +++ b/roomserver/internal/query_test.go @@ -12,19 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -package query +package internal import ( "context" "encoding/json" "testing" - "github.com/matrix-org/dendrite/common/test" + "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) -// used to implement RoomserverQueryAPIEventDB to test getAuthChain +// used to implement RoomserverInternalAPIEventDB to test getAuthChain type getEventDB struct { eventMap map[string]gomatrixserverlib.Event } @@ -79,7 +79,7 @@ func (db *getEventDB) addFakeEvents(graph map[string][]string) error { return nil } -// EventsFromIDs implements RoomserverQueryAPIEventDB +// EventsFromIDs implements RoomserverInternalAPIEventDB func (db *getEventDB) EventsFromIDs(ctx context.Context, eventIDs []string) (res []types.Event, err error) { for _, evID := range eventIDs { res = append(res, types.Event{ @@ -106,7 +106,7 @@ func TestGetAuthChainSingle(t *testing.T) { t.Fatalf("Failed to add events to db: %v", err) } - result, err := getAuthChain(context.TODO(), db, []string{"e"}) + result, err := getAuthChain(context.TODO(), db.EventsFromIDs, []string{"e"}) if err != nil { t.Fatalf("getAuthChain failed: %v", err) } @@ -139,7 +139,7 @@ func TestGetAuthChainMultiple(t *testing.T) { t.Fatalf("Failed to add events to db: %v", err) } - result, err := getAuthChain(context.TODO(), db, []string{"e", "f"}) + result, err := getAuthChain(context.TODO(), db.EventsFromIDs, []string{"e", "f"}) if err != nil { t.Fatalf("getAuthChain failed: %v", err) } diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go new file mode 100644 index 000000000..8a2b1204c --- /dev/null +++ b/roomserver/inthttp/client.go @@ -0,0 +1,347 @@ +package inthttp + +import ( + "context" + "errors" + "fmt" + "net/http" + + fsInputAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/opentracing/opentracing-go" +) + +const ( + // Alias operations + RoomserverSetRoomAliasPath = "/roomserver/setRoomAlias" + RoomserverGetRoomIDForAliasPath = "/roomserver/GetRoomIDForAlias" + RoomserverGetAliasesForRoomIDPath = "/roomserver/GetAliasesForRoomID" + RoomserverGetCreatorIDForAliasPath = "/roomserver/GetCreatorIDForAlias" + RoomserverRemoveRoomAliasPath = "/roomserver/removeRoomAlias" + + // Input operations + RoomserverInputRoomEventsPath = "/roomserver/inputRoomEvents" + + // Perform operations + RoomserverPerformInvitePath = "/roomserver/performInvite" + RoomserverPerformJoinPath = "/roomserver/performJoin" + RoomserverPerformLeavePath = "/roomserver/performLeave" + RoomserverPerformBackfillPath = "/roomserver/performBackfill" + + // Query operations + RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState" + RoomserverQueryStateAfterEventsPath = "/roomserver/queryStateAfterEvents" + RoomserverQueryEventsByIDPath = "/roomserver/queryEventsByID" + RoomserverQueryMembershipForUserPath = "/roomserver/queryMembershipForUser" + RoomserverQueryMembershipsForRoomPath = "/roomserver/queryMembershipsForRoom" + RoomserverQueryServerAllowedToSeeEventPath = "/roomserver/queryServerAllowedToSeeEvent" + RoomserverQueryMissingEventsPath = "/roomserver/queryMissingEvents" + RoomserverQueryStateAndAuthChainPath = "/roomserver/queryStateAndAuthChain" + RoomserverQueryRoomVersionCapabilitiesPath = "/roomserver/queryRoomVersionCapabilities" + RoomserverQueryRoomVersionForRoomPath = "/roomserver/queryRoomVersionForRoom" +) + +type httpRoomserverInternalAPI struct { + roomserverURL string + httpClient *http.Client + cache caching.RoomVersionCache +} + +// NewRoomserverClient creates a RoomserverInputAPI implemented by talking to a HTTP POST API. +// If httpClient is nil an error is returned +func NewRoomserverClient( + roomserverURL string, + httpClient *http.Client, + cache caching.RoomVersionCache, +) (api.RoomserverInternalAPI, error) { + if httpClient == nil { + return nil, errors.New("NewRoomserverInternalAPIHTTP: httpClient is ") + } + return &httpRoomserverInternalAPI{ + roomserverURL: roomserverURL, + httpClient: httpClient, + cache: cache, + }, nil +} + +// SetFederationSenderInputAPI no-ops in HTTP client mode as there is no chicken/egg scenario +func (h *httpRoomserverInternalAPI) SetFederationSenderAPI(fsAPI fsInputAPI.FederationSenderInternalAPI) { +} + +// SetRoomAlias implements RoomserverAliasAPI +func (h *httpRoomserverInternalAPI) SetRoomAlias( + ctx context.Context, + request *api.SetRoomAliasRequest, + response *api.SetRoomAliasResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "SetRoomAlias") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverSetRoomAliasPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// GetRoomIDForAlias implements RoomserverAliasAPI +func (h *httpRoomserverInternalAPI) GetRoomIDForAlias( + ctx context.Context, + request *api.GetRoomIDForAliasRequest, + response *api.GetRoomIDForAliasResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "GetRoomIDForAlias") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverGetRoomIDForAliasPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// GetAliasesForRoomID implements RoomserverAliasAPI +func (h *httpRoomserverInternalAPI) GetAliasesForRoomID( + ctx context.Context, + request *api.GetAliasesForRoomIDRequest, + response *api.GetAliasesForRoomIDResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "GetAliasesForRoomID") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverGetAliasesForRoomIDPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// GetCreatorIDForAlias implements RoomserverAliasAPI +func (h *httpRoomserverInternalAPI) GetCreatorIDForAlias( + ctx context.Context, + request *api.GetCreatorIDForAliasRequest, + response *api.GetCreatorIDForAliasResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "GetCreatorIDForAlias") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverGetCreatorIDForAliasPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// RemoveRoomAlias implements RoomserverAliasAPI +func (h *httpRoomserverInternalAPI) RemoveRoomAlias( + ctx context.Context, + request *api.RemoveRoomAliasRequest, + response *api.RemoveRoomAliasResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "RemoveRoomAlias") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverRemoveRoomAliasPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// InputRoomEvents implements RoomserverInputAPI +func (h *httpRoomserverInternalAPI) InputRoomEvents( + ctx context.Context, + request *api.InputRoomEventsRequest, + response *api.InputRoomEventsResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "InputRoomEvents") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverInputRoomEventsPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpRoomserverInternalAPI) PerformInvite( + ctx context.Context, + request *api.PerformInviteRequest, + response *api.PerformInviteResponse, +) { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformInvite") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverPerformInvitePath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + if err != nil { + response.Error = &api.PerformError{ + Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err), + } + } +} + +func (h *httpRoomserverInternalAPI) PerformJoin( + ctx context.Context, + request *api.PerformJoinRequest, + response *api.PerformJoinResponse, +) { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformJoin") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverPerformJoinPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + if err != nil { + response.Error = &api.PerformError{ + Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err), + } + } +} + +func (h *httpRoomserverInternalAPI) PerformLeave( + ctx context.Context, + request *api.PerformLeaveRequest, + response *api.PerformLeaveResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLeave") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverPerformLeavePath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryLatestEventsAndState implements RoomserverQueryAPI +func (h *httpRoomserverInternalAPI) QueryLatestEventsAndState( + ctx context.Context, + request *api.QueryLatestEventsAndStateRequest, + response *api.QueryLatestEventsAndStateResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryLatestEventsAndState") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryLatestEventsAndStatePath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryStateAfterEvents implements RoomserverQueryAPI +func (h *httpRoomserverInternalAPI) QueryStateAfterEvents( + ctx context.Context, + request *api.QueryStateAfterEventsRequest, + response *api.QueryStateAfterEventsResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryStateAfterEvents") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryStateAfterEventsPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryEventsByID implements RoomserverQueryAPI +func (h *httpRoomserverInternalAPI) QueryEventsByID( + ctx context.Context, + request *api.QueryEventsByIDRequest, + response *api.QueryEventsByIDResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryEventsByID") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryEventsByIDPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryMembershipForUser implements RoomserverQueryAPI +func (h *httpRoomserverInternalAPI) QueryMembershipForUser( + ctx context.Context, + request *api.QueryMembershipForUserRequest, + response *api.QueryMembershipForUserResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMembershipForUser") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryMembershipForUserPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryMembershipsForRoom implements RoomserverQueryAPI +func (h *httpRoomserverInternalAPI) QueryMembershipsForRoom( + ctx context.Context, + request *api.QueryMembershipsForRoomRequest, + response *api.QueryMembershipsForRoomResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMembershipsForRoom") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryMembershipsForRoomPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryServerAllowedToSeeEvent implements RoomserverQueryAPI +func (h *httpRoomserverInternalAPI) QueryServerAllowedToSeeEvent( + ctx context.Context, + request *api.QueryServerAllowedToSeeEventRequest, + response *api.QueryServerAllowedToSeeEventResponse, +) (err error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerAllowedToSeeEvent") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryServerAllowedToSeeEventPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryMissingEvents implements RoomServerQueryAPI +func (h *httpRoomserverInternalAPI) QueryMissingEvents( + ctx context.Context, + request *api.QueryMissingEventsRequest, + response *api.QueryMissingEventsResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMissingEvents") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryMissingEventsPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryStateAndAuthChain implements RoomserverQueryAPI +func (h *httpRoomserverInternalAPI) QueryStateAndAuthChain( + ctx context.Context, + request *api.QueryStateAndAuthChainRequest, + response *api.QueryStateAndAuthChainResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryStateAndAuthChain") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryStateAndAuthChainPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// PerformBackfill implements RoomServerQueryAPI +func (h *httpRoomserverInternalAPI) PerformBackfill( + ctx context.Context, + request *api.PerformBackfillRequest, + response *api.PerformBackfillResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformBackfill") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverPerformBackfillPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryRoomVersionCapabilities implements RoomServerQueryAPI +func (h *httpRoomserverInternalAPI) QueryRoomVersionCapabilities( + ctx context.Context, + request *api.QueryRoomVersionCapabilitiesRequest, + response *api.QueryRoomVersionCapabilitiesResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomVersionCapabilities") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryRoomVersionCapabilitiesPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryRoomVersionForRoom implements RoomServerQueryAPI +func (h *httpRoomserverInternalAPI) QueryRoomVersionForRoom( + ctx context.Context, + request *api.QueryRoomVersionForRoomRequest, + response *api.QueryRoomVersionForRoomResponse, +) error { + if roomVersion, ok := h.cache.GetRoomVersion(request.RoomID); ok { + response.RoomVersion = roomVersion + return nil + } + + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomVersionForRoom") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryRoomVersionForRoomPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + if err == nil { + h.cache.StoreRoomVersion(request.RoomID, response.RoomVersion) + } + return err +} diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go new file mode 100644 index 000000000..1c47e87e2 --- /dev/null +++ b/roomserver/inthttp/server.go @@ -0,0 +1,288 @@ +package inthttp + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/util" +) + +// AddRoutes adds the RoomserverInternalAPI handlers to the http.ServeMux. +// nolint: gocyclo +func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { + internalAPIMux.Handle(RoomserverInputRoomEventsPath, + httputil.MakeInternalAPI("inputRoomEvents", func(req *http.Request) util.JSONResponse { + var request api.InputRoomEventsRequest + var response api.InputRoomEventsResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.InputRoomEvents(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(RoomserverPerformInvitePath, + httputil.MakeInternalAPI("performInvite", func(req *http.Request) util.JSONResponse { + var request api.PerformInviteRequest + var response api.PerformInviteResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + r.PerformInvite(req.Context(), &request, &response) + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(RoomserverPerformJoinPath, + httputil.MakeInternalAPI("performJoin", func(req *http.Request) util.JSONResponse { + var request api.PerformJoinRequest + var response api.PerformJoinResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + r.PerformJoin(req.Context(), &request, &response) + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(RoomserverPerformLeavePath, + httputil.MakeInternalAPI("performLeave", func(req *http.Request) util.JSONResponse { + var request api.PerformLeaveRequest + var response api.PerformLeaveResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.PerformLeave(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle( + RoomserverQueryLatestEventsAndStatePath, + httputil.MakeInternalAPI("queryLatestEventsAndState", func(req *http.Request) util.JSONResponse { + var request api.QueryLatestEventsAndStateRequest + var response api.QueryLatestEventsAndStateResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryLatestEventsAndState(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle( + RoomserverQueryStateAfterEventsPath, + httputil.MakeInternalAPI("queryStateAfterEvents", func(req *http.Request) util.JSONResponse { + var request api.QueryStateAfterEventsRequest + var response api.QueryStateAfterEventsResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryStateAfterEvents(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle( + RoomserverQueryEventsByIDPath, + httputil.MakeInternalAPI("queryEventsByID", func(req *http.Request) util.JSONResponse { + var request api.QueryEventsByIDRequest + var response api.QueryEventsByIDResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryEventsByID(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle( + RoomserverQueryMembershipForUserPath, + httputil.MakeInternalAPI("QueryMembershipForUser", func(req *http.Request) util.JSONResponse { + var request api.QueryMembershipForUserRequest + var response api.QueryMembershipForUserResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryMembershipForUser(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle( + RoomserverQueryMembershipsForRoomPath, + httputil.MakeInternalAPI("queryMembershipsForRoom", func(req *http.Request) util.JSONResponse { + var request api.QueryMembershipsForRoomRequest + var response api.QueryMembershipsForRoomResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryMembershipsForRoom(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle( + RoomserverQueryServerAllowedToSeeEventPath, + httputil.MakeInternalAPI("queryServerAllowedToSeeEvent", func(req *http.Request) util.JSONResponse { + var request api.QueryServerAllowedToSeeEventRequest + var response api.QueryServerAllowedToSeeEventResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryServerAllowedToSeeEvent(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle( + RoomserverQueryMissingEventsPath, + httputil.MakeInternalAPI("queryMissingEvents", func(req *http.Request) util.JSONResponse { + var request api.QueryMissingEventsRequest + var response api.QueryMissingEventsResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryMissingEvents(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle( + RoomserverQueryStateAndAuthChainPath, + httputil.MakeInternalAPI("queryStateAndAuthChain", func(req *http.Request) util.JSONResponse { + var request api.QueryStateAndAuthChainRequest + var response api.QueryStateAndAuthChainResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryStateAndAuthChain(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle( + RoomserverPerformBackfillPath, + httputil.MakeInternalAPI("PerformBackfill", func(req *http.Request) util.JSONResponse { + var request api.PerformBackfillRequest + var response api.PerformBackfillResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.PerformBackfill(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle( + RoomserverQueryRoomVersionCapabilitiesPath, + httputil.MakeInternalAPI("QueryRoomVersionCapabilities", func(req *http.Request) util.JSONResponse { + var request api.QueryRoomVersionCapabilitiesRequest + var response api.QueryRoomVersionCapabilitiesResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryRoomVersionCapabilities(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle( + RoomserverQueryRoomVersionForRoomPath, + httputil.MakeInternalAPI("QueryRoomVersionForRoom", func(req *http.Request) util.JSONResponse { + var request api.QueryRoomVersionForRoomRequest + var response api.QueryRoomVersionForRoomResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryRoomVersionForRoom(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle( + RoomserverSetRoomAliasPath, + httputil.MakeInternalAPI("setRoomAlias", func(req *http.Request) util.JSONResponse { + var request api.SetRoomAliasRequest + var response api.SetRoomAliasResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.SetRoomAlias(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle( + RoomserverGetRoomIDForAliasPath, + httputil.MakeInternalAPI("GetRoomIDForAlias", func(req *http.Request) util.JSONResponse { + var request api.GetRoomIDForAliasRequest + var response api.GetRoomIDForAliasResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.GetRoomIDForAlias(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle( + RoomserverGetCreatorIDForAliasPath, + httputil.MakeInternalAPI("GetCreatorIDForAlias", func(req *http.Request) util.JSONResponse { + var request api.GetCreatorIDForAliasRequest + var response api.GetCreatorIDForAliasResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.GetCreatorIDForAlias(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle( + RoomserverGetAliasesForRoomIDPath, + httputil.MakeInternalAPI("getAliasesForRoomID", func(req *http.Request) util.JSONResponse { + var request api.GetAliasesForRoomIDRequest + var response api.GetAliasesForRoomIDResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.GetAliasesForRoomID(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle( + RoomserverRemoveRoomAliasPath, + httputil.MakeInternalAPI("removeRoomAlias", func(req *http.Request) util.JSONResponse { + var request api.RemoveRoomAliasRequest + var response api.RemoveRoomAliasResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.RemoveRoomAlias(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) +} diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go index fa4f20626..427d5ff36 100644 --- a/roomserver/roomserver.go +++ b/roomserver/roomserver.go @@ -15,57 +15,43 @@ package roomserver import ( - "net/http" - + "github.com/gorilla/mux" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/inthttp" + "github.com/matrix-org/gomatrixserverlib" - asQuery "github.com/matrix-org/dendrite/appservice/query" - "github.com/matrix-org/dendrite/common/basecomponent" - "github.com/matrix-org/dendrite/roomserver/alias" - "github.com/matrix-org/dendrite/roomserver/input" - "github.com/matrix-org/dendrite/roomserver/query" + "github.com/matrix-org/dendrite/internal/setup" + "github.com/matrix-org/dendrite/roomserver/internal" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/sirupsen/logrus" ) -// SetupRoomServerComponent sets up and registers HTTP handlers for the -// RoomServer component. Returns instances of the various roomserver APIs, -// allowing other components running in the same process to hit the query the -// APIs directly instead of having to use HTTP. -func SetupRoomServerComponent( - base *basecomponent.BaseDendrite, -) (api.RoomserverAliasAPI, api.RoomserverInputAPI, api.RoomserverQueryAPI) { - roomserverDB, err := storage.Open(string(base.Cfg.Database.RoomServer)) +// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions +// on the given input API. +func AddInternalRoutes(router *mux.Router, intAPI api.RoomserverInternalAPI) { + inthttp.AddRoutes(intAPI, router) +} + +// NewInternalAPI returns a concerete implementation of the internal API. Callers +// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. +func NewInternalAPI( + base *setup.BaseDendrite, + keyRing gomatrixserverlib.JSONVerifier, + fedClient *gomatrixserverlib.FederationClient, +) api.RoomserverInternalAPI { + roomserverDB, err := storage.Open(string(base.Cfg.Database.RoomServer), base.Cfg.DbProperties()) if err != nil { logrus.WithError(err).Panicf("failed to connect to room server db") } - inputAPI := input.RoomserverInputAPI{ + return &internal.RoomserverInternalAPI{ DB: roomserverDB, + Cfg: base.Cfg, Producer: base.KafkaProducer, OutputRoomEventTopic: string(base.Cfg.Kafka.Topics.OutputRoomEvent), + Cache: base.Caches, + ServerName: base.Cfg.Matrix.ServerName, + FedClient: fedClient, + KeyRing: keyRing, } - - inputAPI.SetupHTTP(http.DefaultServeMux) - - queryAPI := query.RoomserverQueryAPI{ - DB: roomserverDB, - ImmutableCache: base.ImmutableCache, - } - - queryAPI.SetupHTTP(http.DefaultServeMux) - - asAPI := asQuery.AppServiceQueryAPI{Cfg: base.Cfg} - - aliasAPI := alias.RoomserverAliasAPI{ - DB: roomserverDB, - Cfg: base.Cfg, - InputAPI: &inputAPI, - QueryAPI: &queryAPI, - AppserviceAPI: &asAPI, - } - - aliasAPI.SetupHTTP(http.DefaultServeMux) - - return &aliasAPI, &inputAPI, &queryAPI } diff --git a/roomserver/state/database/database.go b/roomserver/state/database/database.go deleted file mode 100644 index 80f1b14f4..000000000 --- a/roomserver/state/database/database.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// Copyright 2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package database - -import ( - "context" - - "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -// A RoomStateDatabase has the storage APIs needed to load state from the database -type RoomStateDatabase interface { - // Store the room state at an event in the database - AddState( - ctx context.Context, - roomNID types.RoomNID, - stateBlockNIDs []types.StateBlockNID, - state []types.StateEntry, - ) (types.StateSnapshotNID, error) - // Look up the state of a room at each event for a list of string event IDs. - // Returns an error if there is an error talking to the database - // Returns a types.MissingEventError if the room state for the event IDs aren't in the database - StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) - // Look up the numeric IDs for a list of string event types. - // Returns a map from string event type to numeric ID for the event type. - EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) - // Look up the numeric IDs for a list of string event state keys. - // Returns a map from string state key to numeric ID for the state key. - EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) - // Look up the numeric state data IDs for each numeric state snapshot ID - // The returned slice is sorted by numeric state snapshot ID. - StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) - // Look up the state data for each numeric state data ID - // The returned slice is sorted by numeric state data ID. - StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) - // Look up the state data for the state key tuples for each numeric state block ID - // This is used to fetch a subset of the room state at a snapshot. - // If a block doesn't contain any of the requested tuples then it can be discarded from the result. - // The returned slice is sorted by numeric state block ID. - StateEntriesForTuples( - ctx context.Context, - stateBlockNIDs []types.StateBlockNID, - stateKeyTuples []types.StateKeyTuple, - ) ([]types.StateEntryList, error) - // Look up the Events for a list of numeric event IDs. - // Returns a sorted list of events. - Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) - // Look up snapshot NID for an event ID string - SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) - // Look up a room version from the room NID. - GetRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) -} diff --git a/roomserver/state/shared/shared.go b/roomserver/state/shared/shared.go deleted file mode 100644 index a29b5e403..000000000 --- a/roomserver/state/shared/shared.go +++ /dev/null @@ -1 +0,0 @@ -package shared diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 3f68e0747..d5be4a901 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -22,7 +22,7 @@ import ( "sort" "time" - "github.com/matrix-org/dendrite/roomserver/state/database" + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" @@ -31,10 +31,10 @@ import ( ) type StateResolution struct { - db database.RoomStateDatabase + db storage.Database } -func NewStateResolution(db database.RoomStateDatabase) StateResolution { +func NewStateResolution(db storage.Database) StateResolution { return StateResolution{ db: db, } @@ -86,7 +86,10 @@ func (v StateResolution) LoadStateAtEvent( ) ([]types.StateEntry, error) { snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID) if err != nil { - return nil, err + return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %s", eventID, err) + } + if snapshotNID == 0 { + return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID) } stateEntries, err := v.LoadStateAtSnapshot(ctx, snapshotNID) @@ -564,7 +567,7 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents( // 3) None of the previous events were state events and they all // have the same state, so this event has exactly the same state // as the previous events. - // This should be the common case. + // This should be the internal case. metrics.algorithm = "no_change" return metrics.stop(prevState.BeforeStateSnapshotNID, nil) } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 50369d806..52e6a96b7 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -18,32 +18,125 @@ import ( "context" "github.com/matrix-org/dendrite/roomserver/api" - statedb "github.com/matrix-org/dendrite/roomserver/state/database" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) type Database interface { - statedb.RoomStateDatabase + // Store the room state at an event in the database + AddState( + ctx context.Context, + roomNID types.RoomNID, + stateBlockNIDs []types.StateBlockNID, + state []types.StateEntry, + ) (types.StateSnapshotNID, error) + // Look up the state of a room at each event for a list of string event IDs. + // Returns an error if there is an error talking to the database. + // The length of []types.StateAtEvent is guaranteed to equal the length of eventIDs if no error is returned. + // Returns a types.MissingEventError if the room state for the event IDs aren't in the database + StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) + // Look up the numeric IDs for a list of string event types. + // Returns a map from string event type to numeric ID for the event type. + EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) + // Look up the numeric IDs for a list of string event state keys. + // Returns a map from string state key to numeric ID for the state key. + EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) + // Look up the numeric state data IDs for each numeric state snapshot ID + // The returned slice is sorted by numeric state snapshot ID. + StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) + // Look up the state data for each numeric state data ID + // The returned slice is sorted by numeric state data ID. + StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) + // Look up the state data for the state key tuples for each numeric state block ID + // This is used to fetch a subset of the room state at a snapshot. + // If a block doesn't contain any of the requested tuples then it can be discarded from the result. + // The returned slice is sorted by numeric state block ID. + StateEntriesForTuples( + ctx context.Context, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, + ) ([]types.StateEntryList, error) + // Look up the Events for a list of numeric event IDs. + // Returns a sorted list of events. + Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) + // Look up snapshot NID for an event ID string + SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) + // Look up a room version from the room NID. + GetRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) + // Stores a matrix room event in the database StoreEvent(ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error) + // Look up the state entries for a list of string event IDs + // Returns an error if the there is an error talking to the database + // Returns a types.MissingEventError if the event IDs aren't in the database. StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) + // Look up the string event state keys for a list of numeric event state keys + // Returns an error if there was a problem talking to the database. EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) + // Look up the numeric IDs for a list of events. + // Returns an error if there was a problem talking to the database. EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) + // Set the state at an event. FIXME TODO: "at" SetState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error + // Lookup the event IDs for a batch of event numeric IDs. + // Returns an error if the retrieval went wrong. EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) + // Look up the latest events in a room in preparation for an update. + // The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error. + // Returns the latest events in the room and the last eventID sent to the log along with an updater. + // If this returns an error then no further action is required. GetLatestEventsForUpdate(ctx context.Context, roomNID types.RoomNID) (types.RoomRecentEventsUpdater, error) + // Look up event ID by transaction's info. + // This is used to determine if the room event is processed/processing already. + // Returns an empty string if no such event exists. GetTransactionEventID(ctx context.Context, transactionID string, sessionID int64, userID string) (string, error) + // Look up the numeric ID for the room. + // Returns 0 if the room doesn't exists. + // Returns an error if there was a problem talking to the database. RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) + // RoomNIDExcludingStubs is a special variation of RoomNID that will return 0 as if the room + // does not exist if the room has no latest events. This can happen when we've received an + // invite over federation for a room that we don't know anything else about yet. + RoomNIDExcludingStubs(ctx context.Context, roomID string) (types.RoomNID, error) + // Look up event references for the latest events in the room and the current state snapshot. + // Returns the latest events, the current state and the maximum depth of the latest events plus 1. + // Returns an error if there was a problem talking to the database. LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) + // Look up the active invites targeting a user in a room and return the + // numeric state key IDs for the user IDs who sent them. + // Returns an error if there was a problem talking to the database. GetInvitesForUser(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserIDs []types.EventStateKeyNID, err error) + // Save a given room alias with the room ID it refers to. + // Returns an error if there was a problem talking to the database. SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error + // Look up the room ID a given alias refers to. + // Returns an error if there was a problem talking to the database. GetRoomIDForAlias(ctx context.Context, alias string) (string, error) + // Look up all aliases referring to a given room ID. + // Returns an error if there was a problem talking to the database. GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) + // Get the user ID of the creator of an alias. + // Returns an error if there was a problem talking to the database. GetCreatorIDForAlias(ctx context.Context, alias string) (string, error) + // Remove a given room alias. + // Returns an error if there was a problem talking to the database. RemoveRoomAlias(ctx context.Context, alias string) error - MembershipUpdater(ctx context.Context, roomID, targetUserID string, roomVersion gomatrixserverlib.RoomVersion) (types.MembershipUpdater, error) + // Build a membership updater for the target user in a room. + MembershipUpdater(ctx context.Context, roomID, targetUserID string, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion) (types.MembershipUpdater, error) + // Lookup the membership of a given user in a given room. + // Returns the numeric ID of the latest membership event sent from this user + // in this room, along a boolean set to true if the user is still in this room, + // false if not. + // Returns an error if there was a problem talking to the database. GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error) - GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error) + // Lookup the membership event numeric IDs for all user that are or have + // been members of a given room. Only lookup events of "join" membership if + // joinOnly is set to true. + // Returns an error if there was a problem talking to the database. + GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error) + // EventsFromIDs looks up the Events for a list of event IDs. Does not error if event was + // not found. + // Returns an error if the retrieval went wrong. EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) + // Look up the room version for a given room. GetRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) } diff --git a/roomserver/storage/postgres/event_json_table.go b/roomserver/storage/postgres/event_json_table.go index 616eaf318..7df175954 100644 --- a/roomserver/storage/postgres/event_json_table.go +++ b/roomserver/storage/postgres/event_json_table.go @@ -19,8 +19,9 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" - + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -58,43 +59,39 @@ type eventJSONStatements struct { bulkSelectEventJSONStmt *sql.Stmt } -func (s *eventJSONStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(eventJSONSchema) +func NewPostgresEventJSONTable(db *sql.DB) (tables.EventJSON, error) { + s := &eventJSONStatements{} + _, err := db.Exec(eventJSONSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertEventJSONStmt, insertEventJSONSQL}, {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *eventJSONStatements) insertEventJSON( - ctx context.Context, eventNID types.EventNID, eventJSON []byte, +func (s *eventJSONStatements) InsertEventJSON( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, ) error { _, err := s.insertEventJSONStmt.ExecContext(ctx, int64(eventNID), eventJSON) return err } -type eventJSONPair struct { - EventNID types.EventNID - EventJSON []byte -} - -func (s *eventJSONStatements) bulkSelectEventJSON( +func (s *eventJSONStatements) BulkSelectEventJSON( ctx context.Context, eventNIDs []types.EventNID, -) ([]eventJSONPair, error) { +) ([]tables.EventJSONPair, error) { rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectEventJSON: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventJSON: rows.close() failed") // We know that we will only get as many results as event NIDs // because of the unique constraint on event NIDs. // So we can allocate an array of the correct size now. // We might get fewer results than NIDs so we adjust the length of the slice before returning it. - results := make([]eventJSONPair, len(eventNIDs)) + results := make([]tables.EventJSONPair, len(eventNIDs)) i := 0 for ; rows.Next(); i++ { result := &results[i] diff --git a/roomserver/storage/postgres/event_state_keys_table.go b/roomserver/storage/postgres/event_state_keys_table.go index 4c3496d91..500ff20e4 100644 --- a/roomserver/storage/postgres/event_state_keys_table.go +++ b/roomserver/storage/postgres/event_state_keys_table.go @@ -20,7 +20,10 @@ import ( "database/sql" "github.com/lib/pq" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -74,38 +77,39 @@ type eventStateKeyStatements struct { bulkSelectEventStateKeyStmt *sql.Stmt } -func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(eventStateKeysSchema) +func NewPostgresEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { + s := &eventStateKeyStatements{} + _, err := db.Exec(eventStateKeysSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, {&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL}, {&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL}, - }.prepare(db) + }.Prepare(db) } -func (s *eventStateKeyStatements) insertEventStateKeyNID( +func (s *eventStateKeyStatements) InsertEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 - stmt := common.TxStmt(txn, s.insertEventStateKeyNIDStmt) + stmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt) err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID) return types.EventStateKeyNID(eventStateKeyNID), err } -func (s *eventStateKeyStatements) selectEventStateKeyNID( +func (s *eventStateKeyStatements) SelectEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 - stmt := common.TxStmt(txn, s.selectEventStateKeyNIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectEventStateKeyNIDStmt) err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID) return types.EventStateKeyNID(eventStateKeyNID), err } -func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( +func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( ctx context.Context, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext( @@ -114,7 +118,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKeyNID: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKeyNID: rows.close() failed") result := make(map[string]types.EventStateKeyNID, len(eventStateKeys)) for rows.Next() { @@ -128,7 +132,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( return result, rows.Err() } -func (s *eventStateKeyStatements) bulkSelectEventStateKey( +func (s *eventStateKeyStatements) BulkSelectEventStateKey( ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { nIDs := make(pq.Int64Array, len(eventStateKeyNIDs)) @@ -139,7 +143,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKey( if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKey: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKey: rows.close() failed") result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs)) for rows.Next() { diff --git a/roomserver/storage/postgres/event_types_table.go b/roomserver/storage/postgres/event_types_table.go index 6537a5457..037d98fe7 100644 --- a/roomserver/storage/postgres/event_types_table.go +++ b/roomserver/storage/postgres/event_types_table.go @@ -19,15 +19,16 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" - "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) const eventTypesSchema = ` -- Numeric versions of the event "type"s. Event types tend to be taken from a --- small common pool. Assigning each a numeric ID should reduce the amount of +-- small internal pool. Assigning each a numeric ID should reduce the amount of -- data that needs to be stored and fetched from the database. -- It also means that many operations can work with int64 arrays rather than -- string arrays which may help reduce GC pressure. @@ -42,7 +43,7 @@ const eventTypesSchema = ` -- Picking well-known numeric IDs for the events types that require special -- attention during state conflict resolution means that we write that code -- using numeric constants. --- It also means that the numeric IDs for common event types should be +-- It also means that the numeric IDs for internal event types should be -- consistent between different instances which might make ad-hoc debugging -- easier. -- Other event types are automatically assigned numeric IDs starting from 2**16. @@ -98,43 +99,44 @@ type eventTypeStatements struct { bulkSelectEventTypeNIDStmt *sql.Stmt } -func (s *eventTypeStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(eventTypesSchema) +func NewPostgresEventTypesTable(db *sql.DB) (tables.EventTypes, error) { + s := &eventTypeStatements{} + _, err := db.Exec(eventTypesSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL}, {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL}, {&s.bulkSelectEventTypeNIDStmt, bulkSelectEventTypeNIDSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *eventTypeStatements) insertEventTypeNID( - ctx context.Context, eventType string, +func (s *eventTypeStatements) InsertEventTypeNID( + ctx context.Context, txn *sql.Tx, eventType string, ) (types.EventTypeNID, error) { var eventTypeNID int64 - err := s.insertEventTypeNIDStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID) + err := txn.Stmt(s.insertEventTypeNIDStmt).QueryRowContext(ctx, eventType).Scan(&eventTypeNID) return types.EventTypeNID(eventTypeNID), err } -func (s *eventTypeStatements) selectEventTypeNID( - ctx context.Context, eventType string, +func (s *eventTypeStatements) SelectEventTypeNID( + ctx context.Context, txn *sql.Tx, eventType string, ) (types.EventTypeNID, error) { var eventTypeNID int64 - err := s.selectEventTypeNIDStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID) + err := txn.Stmt(s.selectEventTypeNIDStmt).QueryRowContext(ctx, eventType).Scan(&eventTypeNID) return types.EventTypeNID(eventTypeNID), err } -func (s *eventTypeStatements) bulkSelectEventTypeNID( +func (s *eventTypeStatements) BulkSelectEventTypeNID( ctx context.Context, eventTypes []string, ) (map[string]types.EventTypeNID, error) { rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes)) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectEventTypeNID: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventTypeNID: rows.close() failed") result := make(map[string]types.EventTypeNID, len(eventTypes)) for rows.Next() { diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index ecc35f37a..bdbf5e7cb 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -21,7 +21,10 @@ import ( "fmt" "github.com/lib/pq" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -136,13 +139,14 @@ type eventStatements struct { selectRoomNIDForEventNIDStmt *sql.Stmt } -func (s *eventStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(eventsSchema) +func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { + s := &eventStatements{} + _, err := db.Exec(eventsSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertEventStmt, insertEventSQL}, {&s.selectEventStmt, selectEventSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, @@ -157,11 +161,12 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) { {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, {&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL}, {&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *eventStatements) insertEvent( +func (s *eventStatements) InsertEvent( ctx context.Context, + txn *sql.Tx, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, @@ -179,8 +184,8 @@ func (s *eventStatements) insertEvent( return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } -func (s *eventStatements) selectEvent( - ctx context.Context, eventID string, +func (s *eventStatements) SelectEvent( + ctx context.Context, txn *sql.Tx, eventID string, ) (types.EventNID, types.StateSnapshotNID, error) { var eventNID int64 var stateNID int64 @@ -190,14 +195,14 @@ func (s *eventStatements) selectEvent( // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError -func (s *eventStatements) bulkSelectStateEventByID( +func (s *eventStatements) BulkSelectStateEventByID( ctx context.Context, eventIDs []string, ) ([]types.StateEntry, error) { rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed") // We know that we will only get as many results as event IDs // because of the unique constraint on event IDs. // So we can allocate an array of the correct size now. @@ -222,7 +227,7 @@ func (s *eventStatements) bulkSelectStateEventByID( // We don't know which ones were missing because we don't return the string IDs in the query. // However it should be possible debug this by replaying queries or entries from the input kafka logs. // If this turns out to be impossible and we do need the debug information here, it would be better - // to do it as a separate query rather than slowing down/complicating the common case. + // to do it as a separate query rather than slowing down/complicating the internal case. return nil, types.MissingEventError( fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventIDs)), ) @@ -233,14 +238,14 @@ func (s *eventStatements) bulkSelectStateEventByID( // bulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. -func (s *eventStatements) bulkSelectStateAtEventByID( +func (s *eventStatements) BulkSelectStateAtEventByID( ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventByID: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventByID: rows.close() failed") results := make([]types.StateAtEvent, len(eventIDs)) i := 0 for ; rows.Next(); i++ { @@ -270,44 +275,44 @@ func (s *eventStatements) bulkSelectStateAtEventByID( return results, nil } -func (s *eventStatements) updateEventState( +func (s *eventStatements) UpdateEventState( ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { _, err := s.updateEventStateStmt.ExecContext(ctx, int64(eventNID), int64(stateNID)) return err } -func (s *eventStatements) selectEventSentToOutput( +func (s *eventStatements) SelectEventSentToOutput( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (sentToOutput bool, err error) { - stmt := common.TxStmt(txn, s.selectEventSentToOutputStmt) + stmt := sqlutil.TxStmt(txn, s.selectEventSentToOutputStmt) err = stmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput) return } -func (s *eventStatements) updateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { - stmt := common.TxStmt(txn, s.updateEventSentToOutputStmt) +func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { + stmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt) _, err := stmt.ExecContext(ctx, int64(eventNID)) return err } -func (s *eventStatements) selectEventID( +func (s *eventStatements) SelectEventID( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (eventID string, err error) { - stmt := common.TxStmt(txn, s.selectEventIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectEventIDStmt) err = stmt.QueryRowContext(ctx, int64(eventNID)).Scan(&eventID) return } -func (s *eventStatements) bulkSelectStateAtEventAndReference( +func (s *eventStatements) BulkSelectStateAtEventAndReference( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]types.StateAtEventAndReference, error) { - stmt := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt) + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt) rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventAndReference: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventAndReference: rows.close() failed") results := make([]types.StateAtEventAndReference, len(eventNIDs)) i := 0 for ; rows.Next(); i++ { @@ -341,14 +346,14 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference( return results, nil } -func (s *eventStatements) bulkSelectEventReference( - ctx context.Context, eventNIDs []types.EventNID, +func (s *eventStatements) BulkSelectEventReference( + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]gomatrixserverlib.EventReference, error) { rows, err := s.bulkSelectEventReferenceStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectEventReference: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventReference: rows.close() failed") results := make([]gomatrixserverlib.EventReference, len(eventNIDs)) i := 0 for ; rows.Next(); i++ { @@ -367,12 +372,12 @@ func (s *eventStatements) bulkSelectEventReference( } // bulkSelectEventID returns a map from numeric event ID to string event ID. -func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { +func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectEventID: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventID: rows.close() failed") results := make(map[types.EventNID]string, len(eventNIDs)) i := 0 for ; rows.Next(); i++ { @@ -394,12 +399,12 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []typ // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { +func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed") results := make(map[string]types.EventNID, len(eventIDs)) for rows.Next() { var eventID string @@ -412,7 +417,7 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []str return results, rows.Err() } -func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []types.EventNID) (int64, error) { +func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) { var result int64 stmt := s.selectMaxEventDepthStmt err := stmt.QueryRowContext(ctx, eventNIDsAsArray(eventNIDs)).Scan(&result) @@ -422,11 +427,10 @@ func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []t return result, nil } -func (s *eventStatements) selectRoomNIDForEventNID( - ctx context.Context, txn *sql.Tx, eventNID types.EventNID, +func (s *eventStatements) SelectRoomNIDForEventNID( + ctx context.Context, eventNID types.EventNID, ) (roomNID types.RoomNID, err error) { - selectStmt := common.TxStmt(txn, s.selectRoomNIDForEventNIDStmt) - err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID) + err = s.selectRoomNIDForEventNIDStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID) return } diff --git a/roomserver/storage/postgres/invite_table.go b/roomserver/storage/postgres/invite_table.go index f764b1561..048a094dc 100644 --- a/roomserver/storage/postgres/invite_table.go +++ b/roomserver/storage/postgres/invite_table.go @@ -19,7 +19,10 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -79,26 +82,27 @@ type inviteStatements struct { updateInviteRetiredStmt *sql.Stmt } -func (s *inviteStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(inviteSchema) +func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) { + s := &inviteStatements{} + _, err := db.Exec(inviteSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertInviteEventStmt, insertInviteEventSQL}, {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, {&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *inviteStatements) insertInviteEvent( +func (s *inviteStatements) InsertInviteEvent( ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte, ) (bool, error) { - result, err := common.TxStmt(txn, s.insertInviteEventStmt).ExecContext( + result, err := sqlutil.TxStmt(txn, s.insertInviteEventStmt).ExecContext( ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, ) if err != nil { @@ -111,16 +115,16 @@ func (s *inviteStatements) insertInviteEvent( return count != 0, nil } -func (s *inviteStatements) updateInviteRetired( +func (s *inviteStatements) UpdateInviteRetired( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) ([]string, error) { - stmt := common.TxStmt(txn, s.updateInviteRetiredStmt) + stmt := sqlutil.TxStmt(txn, s.updateInviteRetiredStmt) rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "updateInviteRetired: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "updateInviteRetired: rows.close() failed") var eventIDs []string for rows.Next() { @@ -133,8 +137,8 @@ func (s *inviteStatements) updateInviteRetired( return eventIDs, rows.Err() } -// selectInviteActiveForUserInRoom returns a list of sender state key NIDs -func (s *inviteStatements) selectInviteActiveForUserInRoom( +// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs +func (s *inviteStatements) SelectInviteActiveForUserInRoom( ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, ) ([]types.EventStateKeyNID, error) { @@ -144,7 +148,7 @@ func (s *inviteStatements) selectInviteActiveForUserInRoom( if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed") var result []types.EventStateKeyNID for rows.Next() { var senderUserNID int64 diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 9c8a4c259..13cef638f 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -19,18 +19,13 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) -type membershipState int64 - -const ( - membershipStateLeaveOrBan membershipState = 1 - membershipStateInvite membershipState = 2 - membershipStateJoin membershipState = 3 -) - const membershipSchema = ` -- The membership table is used to coordinate updates between the invite table -- and the room state tables. @@ -59,6 +54,10 @@ CREATE TABLE IF NOT EXISTS roomserver_membership ( -- This NID is updated if the join event gets updated (e.g. profile update), -- or if the user leaves/joins the room. event_nid BIGINT NOT NULL DEFAULT 0, + -- Local target is true if the target_nid refers to a local user rather than + -- a federated one. This is an optimisation for resetting state on federated + -- room joins. + target_local BOOLEAN NOT NULL DEFAULT false, UNIQUE (room_nid, target_nid) ); ` @@ -66,8 +65,8 @@ CREATE TABLE IF NOT EXISTS roomserver_membership ( // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE const insertMembershipSQL = "" + - "INSERT INTO roomserver_membership (room_nid, target_nid)" + - " VALUES ($1, $2)" + + "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" + + " VALUES ($1, $2, $3)" + " ON CONFLICT DO NOTHING" const selectMembershipFromRoomAndTargetSQL = "" + @@ -78,10 +77,20 @@ const selectMembershipsFromRoomAndMembershipSQL = "" + "SELECT event_nid FROM roomserver_membership" + " WHERE room_nid = $1 AND membership_nid = $2" +const selectLocalMembershipsFromRoomAndMembershipSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1 AND membership_nid = $2" + + " AND target_local = true" + const selectMembershipsFromRoomSQL = "" + "SELECT event_nid FROM roomserver_membership" + " WHERE room_nid = $1" +const selectLocalMembershipsFromRoomSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1" + + " AND target_local = true" + const selectMembershipForUpdateSQL = "" + "SELECT membership_nid FROM roomserver_membership" + " WHERE room_nid = $1 AND target_nid = $2 FOR UPDATE" @@ -91,67 +100,79 @@ const updateMembershipSQL = "" + " WHERE room_nid = $1 AND target_nid = $2" type membershipStatements struct { - insertMembershipStmt *sql.Stmt - selectMembershipForUpdateStmt *sql.Stmt - selectMembershipFromRoomAndTargetStmt *sql.Stmt - selectMembershipsFromRoomAndMembershipStmt *sql.Stmt - selectMembershipsFromRoomStmt *sql.Stmt - updateMembershipStmt *sql.Stmt + insertMembershipStmt *sql.Stmt + selectMembershipForUpdateStmt *sql.Stmt + selectMembershipFromRoomAndTargetStmt *sql.Stmt + selectMembershipsFromRoomAndMembershipStmt *sql.Stmt + selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt + selectMembershipsFromRoomStmt *sql.Stmt + selectLocalMembershipsFromRoomStmt *sql.Stmt + updateMembershipStmt *sql.Stmt } -func (s *membershipStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(membershipSchema) +func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { + s := &membershipStatements{} + _, err := db.Exec(membershipSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertMembershipStmt, insertMembershipSQL}, {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, {&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL}, + {&s.selectLocalMembershipsFromRoomAndMembershipStmt, selectLocalMembershipsFromRoomAndMembershipSQL}, {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, + {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *membershipStatements) insertMembership( +func (s *membershipStatements) InsertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + localTarget bool, ) error { - stmt := common.TxStmt(txn, s.insertMembershipStmt) - _, err := stmt.ExecContext(ctx, roomNID, targetUserNID) + stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt) + _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget) return err } -func (s *membershipStatements) selectMembershipForUpdate( +func (s *membershipStatements) SelectMembershipForUpdate( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, -) (membership membershipState, err error) { - err = common.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext( +) (membership tables.MembershipState, err error) { + err = sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership) return } -func (s *membershipStatements) selectMembershipFromRoomAndTarget( +func (s *membershipStatements) SelectMembershipFromRoomAndTarget( ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, -) (eventNID types.EventNID, membership membershipState, err error) { +) (eventNID types.EventNID, membership tables.MembershipState, err error) { err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership, &eventNID) return } -func (s *membershipStatements) selectMembershipsFromRoom( - ctx context.Context, roomNID types.RoomNID, +func (s *membershipStatements) SelectMembershipsFromRoom( + ctx context.Context, roomNID types.RoomNID, localOnly bool, ) (eventNIDs []types.EventNID, err error) { - rows, err := s.selectMembershipsFromRoomStmt.QueryContext(ctx, roomNID) + var stmt *sql.Stmt + if localOnly { + stmt = s.selectLocalMembershipsFromRoomStmt + } else { + stmt = s.selectMembershipsFromRoomStmt + } + rows, err := stmt.QueryContext(ctx, roomNID) if err != nil { return } - defer common.CloseAndLogIfError(ctx, rows, "selectMembershipsFromRoom: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectMembershipsFromRoom: rows.close() failed") for rows.Next() { var eNID types.EventNID @@ -163,16 +184,22 @@ func (s *membershipStatements) selectMembershipsFromRoom( return eventNIDs, rows.Err() } -func (s *membershipStatements) selectMembershipsFromRoomAndMembership( +func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( ctx context.Context, - roomNID types.RoomNID, membership membershipState, + roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, ) (eventNIDs []types.EventNID, err error) { - stmt := s.selectMembershipsFromRoomAndMembershipStmt - rows, err := stmt.QueryContext(ctx, roomNID, membership) + var rows *sql.Rows + var stmt *sql.Stmt + if localOnly { + stmt = s.selectLocalMembershipsFromRoomAndMembershipStmt + } else { + stmt = s.selectMembershipsFromRoomAndMembershipStmt + } + rows, err = stmt.QueryContext(ctx, roomNID, membership) if err != nil { return } - defer common.CloseAndLogIfError(ctx, rows, "selectMembershipsFromRoomAndMembership: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectMembershipsFromRoomAndMembership: rows.close() failed") for rows.Next() { var eNID types.EventNID @@ -184,13 +211,13 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership( return eventNIDs, rows.Err() } -func (s *membershipStatements) updateMembership( +func (s *membershipStatements) UpdateMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, - senderUserNID types.EventStateKeyNID, membership membershipState, + senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, ) error { - _, err := common.TxStmt(txn, s.updateMembershipStmt).ExecContext( + _, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext( ctx, roomNID, targetUserNID, senderUserNID, membership, eventNID, ) return err diff --git a/roomserver/storage/postgres/prepare.go b/roomserver/storage/postgres/prepare.go deleted file mode 100644 index 70b6e5161..000000000 --- a/roomserver/storage/postgres/prepare.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "database/sql" -) - -// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. -type statementList []struct { - statement **sql.Stmt - sql string -} - -// prepare the SQL for each statement in the list and assign the result to the prepared statement. -func (s statementList) prepare(db *sql.DB) (err error) { - for _, statement := range s { - if *statement.statement, err = db.Prepare(statement.sql); err != nil { - return - } - } - return -} diff --git a/roomserver/storage/postgres/previous_events_table.go b/roomserver/storage/postgres/previous_events_table.go index 4c21b3081..1a4ba6732 100644 --- a/roomserver/storage/postgres/previous_events_table.go +++ b/roomserver/storage/postgres/previous_events_table.go @@ -19,7 +19,9 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -63,26 +65,27 @@ type previousEventStatements struct { selectPreviousEventExistsStmt *sql.Stmt } -func (s *previousEventStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(previousEventSchema) +func NewPostgresPreviousEventsTable(db *sql.DB) (tables.PreviousEvents, error) { + s := &previousEventStatements{} + _, err := db.Exec(previousEventSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertPreviousEventStmt, insertPreviousEventSQL}, {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *previousEventStatements) insertPreviousEvent( +func (s *previousEventStatements) InsertPreviousEvent( ctx context.Context, txn *sql.Tx, previousEventID string, previousEventReferenceSHA256 []byte, eventNID types.EventNID, ) error { - stmt := common.TxStmt(txn, s.insertPreviousEventStmt) + stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) _, err := stmt.ExecContext( ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), ) @@ -91,10 +94,10 @@ func (s *previousEventStatements) insertPreviousEvent( // Check if the event reference exists // Returns sql.ErrNoRows if the event reference doesn't exist. -func (s *previousEventStatements) selectPreviousEventExists( +func (s *previousEventStatements) SelectPreviousEventExists( ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, ) error { var ok int64 - stmt := common.TxStmt(txn, s.selectPreviousEventExistsStmt) + stmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt) return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok) } diff --git a/roomserver/storage/postgres/room_aliases_table.go b/roomserver/storage/postgres/room_aliases_table.go index c37f383c9..85042c54f 100644 --- a/roomserver/storage/postgres/room_aliases_table.go +++ b/roomserver/storage/postgres/room_aliases_table.go @@ -19,7 +19,9 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" ) const roomAliasesSchema = ` @@ -59,28 +61,29 @@ type roomAliasesStatements struct { deleteRoomAliasStmt *sql.Stmt } -func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(roomAliasesSchema) +func NewPostgresRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { + s := &roomAliasesStatements{} + _, err := db.Exec(roomAliasesSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertRoomAliasStmt, insertRoomAliasSQL}, {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, {&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL}, {&s.selectCreatorIDFromAliasStmt, selectCreatorIDFromAliasSQL}, {&s.deleteRoomAliasStmt, deleteRoomAliasSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *roomAliasesStatements) insertRoomAlias( +func (s *roomAliasesStatements) InsertRoomAlias( ctx context.Context, alias string, roomID string, creatorUserID string, ) (err error) { _, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID) return } -func (s *roomAliasesStatements) selectRoomIDFromAlias( +func (s *roomAliasesStatements) SelectRoomIDFromAlias( ctx context.Context, alias string, ) (roomID string, err error) { err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) @@ -90,14 +93,14 @@ func (s *roomAliasesStatements) selectRoomIDFromAlias( return } -func (s *roomAliasesStatements) selectAliasesFromRoomID( +func (s *roomAliasesStatements) SelectAliasesFromRoomID( ctx context.Context, roomID string, ) ([]string, error) { rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed") var aliases []string for rows.Next() { @@ -111,7 +114,7 @@ func (s *roomAliasesStatements) selectAliasesFromRoomID( return aliases, rows.Err() } -func (s *roomAliasesStatements) selectCreatorIDFromAlias( +func (s *roomAliasesStatements) SelectCreatorIDFromAlias( ctx context.Context, alias string, ) (creatorID string, err error) { err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) @@ -121,7 +124,7 @@ func (s *roomAliasesStatements) selectCreatorIDFromAlias( return } -func (s *roomAliasesStatements) deleteRoomAlias( +func (s *roomAliasesStatements) DeleteRoomAlias( ctx context.Context, alias string, ) (err error) { _, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias) diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index ef5b510c9..8e00cfdb8 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -21,7 +21,9 @@ import ( "errors" "github.com/lib/pq" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -82,12 +84,13 @@ type roomStatements struct { selectRoomVersionForRoomNIDStmt *sql.Stmt } -func (s *roomStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(roomsSchema) +func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) { + s := &roomStatements{} + _, err := db.Exec(roomsSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, @@ -95,30 +98,30 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, {&s.selectRoomVersionForRoomIDStmt, selectRoomVersionForRoomIDSQL}, {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *roomStatements) insertRoomNID( +func (s *roomStatements) InsertRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, ) (types.RoomNID, error) { var roomNID int64 - stmt := common.TxStmt(txn, s.insertRoomNIDStmt) + stmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt) err := stmt.QueryRowContext(ctx, roomID, roomVersion).Scan(&roomNID) return types.RoomNID(roomNID), err } -func (s *roomStatements) selectRoomNID( +func (s *roomStatements) SelectRoomNID( ctx context.Context, txn *sql.Tx, roomID string, ) (types.RoomNID, error) { var roomNID int64 - stmt := common.TxStmt(txn, s.selectRoomNIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectRoomNIDStmt) err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) return types.RoomNID(roomNID), err } -func (s *roomStatements) selectLatestEventNIDs( - ctx context.Context, roomNID types.RoomNID, +func (s *roomStatements) SelectLatestEventNIDs( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.StateSnapshotNID, error) { var nids pq.Int64Array var stateSnapshotNID int64 @@ -134,13 +137,13 @@ func (s *roomStatements) selectLatestEventNIDs( return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil } -func (s *roomStatements) selectLatestEventsNIDsForUpdate( +func (s *roomStatements) SelectLatestEventsNIDsForUpdate( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) { var nids pq.Int64Array var lastEventSentNID int64 var stateSnapshotNID int64 - stmt := common.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt) + stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt) err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID) if err != nil { return nil, 0, 0, err @@ -152,7 +155,7 @@ func (s *roomStatements) selectLatestEventsNIDsForUpdate( return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil } -func (s *roomStatements) updateLatestEventNIDs( +func (s *roomStatements) UpdateLatestEventNIDs( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, @@ -160,7 +163,7 @@ func (s *roomStatements) updateLatestEventNIDs( lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID, ) error { - stmt := common.TxStmt(txn, s.updateLatestEventNIDsStmt) + stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt) _, err := stmt.ExecContext( ctx, roomNID, @@ -171,11 +174,11 @@ func (s *roomStatements) updateLatestEventNIDs( return err } -func (s *roomStatements) selectRoomVersionForRoomID( +func (s *roomStatements) SelectRoomVersionForRoomID( ctx context.Context, txn *sql.Tx, roomID string, ) (gomatrixserverlib.RoomVersion, error) { var roomVersion gomatrixserverlib.RoomVersion - stmt := common.TxStmt(txn, s.selectRoomVersionForRoomIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectRoomVersionForRoomIDStmt) err := stmt.QueryRowContext(ctx, roomID).Scan(&roomVersion) if err == sql.ErrNoRows { return roomVersion, errors.New("room not found") @@ -183,12 +186,11 @@ func (s *roomStatements) selectRoomVersionForRoomID( return roomVersion, err } -func (s *roomStatements) selectRoomVersionForRoomNID( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +func (s *roomStatements) SelectRoomVersionForRoomNID( + ctx context.Context, roomNID types.RoomNID, ) (gomatrixserverlib.RoomVersion, error) { var roomVersion gomatrixserverlib.RoomVersion - stmt := common.TxStmt(txn, s.selectRoomVersionForRoomNIDStmt) - err := stmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion) + err := s.selectRoomVersionForRoomNIDStmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion) if err == sql.ErrNoRows { return roomVersion, errors.New("room not found") } diff --git a/roomserver/storage/postgres/sql.go b/roomserver/storage/postgres/sql.go deleted file mode 100644 index 5956886ce..000000000 --- a/roomserver/storage/postgres/sql.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "database/sql" -) - -type statements struct { - eventTypeStatements - eventStateKeyStatements - roomStatements - eventStatements - eventJSONStatements - stateSnapshotStatements - stateBlockStatements - previousEventStatements - roomAliasesStatements - inviteStatements - membershipStatements - transactionStatements -} - -func (s *statements) prepare(db *sql.DB) error { - var err error - - for _, prepare := range []func(db *sql.DB) error{ - s.eventTypeStatements.prepare, - s.eventStateKeyStatements.prepare, - s.roomStatements.prepare, - s.eventStatements.prepare, - s.eventJSONStatements.prepare, - s.stateSnapshotStatements.prepare, - s.stateBlockStatements.prepare, - s.previousEventStatements.prepare, - s.roomAliasesStatements.prepare, - s.inviteStatements.prepare, - s.membershipStatements.prepare, - s.transactionStatements.prepare, - } { - if err = prepare(db); err != nil { - return err - } - } - - return nil -} diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go index b9246b763..d618686f7 100644 --- a/roomserver/storage/postgres/state_block_table.go +++ b/roomserver/storage/postgres/state_block_table.go @@ -21,9 +21,10 @@ import ( "fmt" "sort" - "github.com/matrix-org/dendrite/common" - "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" ) @@ -31,7 +32,7 @@ import ( const stateDataSchema = ` -- The state data map. -- Designed to give enough information to run the state resolution algorithm --- without hitting the database in the common case. +-- without hitting the database in the internal case. -- TODO: Is it worth replacing the unique btree index with a covering index so -- that postgres could lookup the state using an index-only scan? -- The type and state_key are included in the index to make it easier to @@ -87,25 +88,30 @@ type stateBlockStatements struct { bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt } -func (s *stateBlockStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(stateDataSchema) +func NewPostgresStateBlockTable(db *sql.DB) (tables.StateBlock, error) { + s := &stateBlockStatements{} + _, err := db.Exec(stateDataSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertStateDataStmt, insertStateDataSQL}, {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *stateBlockStatements) bulkInsertStateData( +func (s *stateBlockStatements) BulkInsertStateData( ctx context.Context, - stateBlockNID types.StateBlockNID, + txn *sql.Tx, entries []types.StateEntry, -) error { +) (types.StateBlockNID, error) { + stateBlockNID, err := s.selectNextStateBlockNID(ctx) + if err != nil { + return 0, err + } for _, entry := range entries { _, err := s.insertStateDataStmt.ExecContext( ctx, @@ -115,10 +121,10 @@ func (s *stateBlockStatements) bulkInsertStateData( int64(entry.EventNID), ) if err != nil { - return err + return 0, err } } - return nil + return stateBlockNID, nil } func (s *stateBlockStatements) selectNextStateBlockNID( @@ -129,7 +135,7 @@ func (s *stateBlockStatements) selectNextStateBlockNID( return types.StateBlockNID(stateBlockNID), err } -func (s *stateBlockStatements) bulkSelectStateBlockEntries( +func (s *stateBlockStatements) BulkSelectStateBlockEntries( ctx context.Context, stateBlockNIDs []types.StateBlockNID, ) ([]types.StateEntryList, error) { nids := make([]int64, len(stateBlockNIDs)) @@ -140,7 +146,7 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries( if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed") results := make([]types.StateEntryList, len(stateBlockNIDs)) // current is a pointer to the StateEntryList to append the state entries to. @@ -180,7 +186,7 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries( return results, err } -func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( +func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries( ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple, @@ -199,7 +205,7 @@ func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed") var results []types.StateEntryList var current types.StateEntryList diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index a1f26e228..0f8f1c51e 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -21,6 +21,8 @@ import ( "fmt" "github.com/lib/pq" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -64,30 +66,31 @@ type stateSnapshotStatements struct { bulkSelectStateBlockNIDsStmt *sql.Stmt } -func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(stateSnapshotSchema) +func NewPostgresStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { + s := &stateSnapshotStatements{} + _, err := db.Exec(stateSnapshotSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *stateSnapshotStatements) insertState( - ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, +func (s *stateSnapshotStatements) InsertState( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, ) (stateNID types.StateSnapshotNID, err error) { nids := make([]int64, len(stateBlockNIDs)) for i := range stateBlockNIDs { nids[i] = int64(stateBlockNIDs[i]) } - err = s.insertStateStmt.QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID) + err = txn.Stmt(s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID) return } -func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs( +func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( ctx context.Context, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { nids := make([]int64, len(stateNIDs)) diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 6f2b96610..d76ee0a92 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -9,780 +9,98 @@ // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implie // See the License for the specific language governing permissions and // limitations under the License. package postgres import ( - "context" "database/sql" - "encoding/json" "github.com/matrix-org/dendrite/internal/sqlutil" // Import the postgres database driver. _ "github.com/lib/pq" - "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/roomserver/storage/shared" ) // A Database is used to store room events and stream offsets. type Database struct { - statements statements - db *sql.DB + shared.Database } // Open a postgres database. -func Open(dataSourceName string) (*Database, error) { +// nolint: gocyclo +func Open(dataSourceName string, dbProperties sqlutil.DbProperties) (*Database, error) { var d Database + var db *sql.DB var err error - if d.db, err = sqlutil.Open("postgres", dataSourceName); err != nil { + if db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { return nil, err } - if err = d.statements.prepare(d.db); err != nil { + eventStateKeys, err := NewPostgresEventStateKeysTable(db) + if err != nil { return nil, err } + eventTypes, err := NewPostgresEventTypesTable(db) + if err != nil { + return nil, err + } + eventJSON, err := NewPostgresEventJSONTable(db) + if err != nil { + return nil, err + } + events, err := NewPostgresEventsTable(db) + if err != nil { + return nil, err + } + rooms, err := NewPostgresRoomsTable(db) + if err != nil { + return nil, err + } + transactions, err := NewPostgresTransactionsTable(db) + if err != nil { + return nil, err + } + stateBlock, err := NewPostgresStateBlockTable(db) + if err != nil { + return nil, err + } + stateSnapshot, err := NewPostgresStateSnapshotTable(db) + if err != nil { + return nil, err + } + roomAliases, err := NewPostgresRoomAliasesTable(db) + if err != nil { + return nil, err + } + prevEvents, err := NewPostgresPreviousEventsTable(db) + if err != nil { + return nil, err + } + invites, err := NewPostgresInvitesTable(db) + if err != nil { + return nil, err + } + membership, err := NewPostgresMembershipTable(db) + if err != nil { + return nil, err + } + d.Database = shared.Database{ + DB: db, + EventTypesTable: eventTypes, + EventStateKeysTable: eventStateKeys, + EventJSONTable: eventJSON, + EventsTable: events, + RoomsTable: rooms, + TransactionsTable: transactions, + StateBlockTable: stateBlock, + StateSnapshotTable: stateSnapshot, + PrevEventsTable: prevEvents, + RoomAliasesTable: roomAliases, + InvitesTable: invites, + MembershipTable: membership, + } return &d, nil } - -// StoreEvent implements input.EventDatabase -func (d *Database) StoreEvent( - ctx context.Context, event gomatrixserverlib.Event, - txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, -) (types.RoomNID, types.StateAtEvent, error) { - var ( - roomNID types.RoomNID - eventTypeNID types.EventTypeNID - eventStateKeyNID types.EventStateKeyNID - eventNID types.EventNID - stateNID types.StateSnapshotNID - err error - ) - - if txnAndSessionID != nil { - if err = d.statements.insertTransaction( - ctx, txnAndSessionID.TransactionID, - txnAndSessionID.SessionID, event.Sender(), event.EventID(), - ); err != nil { - return 0, types.StateAtEvent{}, err - } - } - - // TODO: Here we should aim to have two different code paths for new rooms - // vs existing ones. - - // Get the default room version. If the client doesn't supply a room_version - // then we will use our configured default to create the room. - // https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom - // Note that the below logic depends on the m.room.create event being the - // first event that is persisted to the database when creating or joining a - // room. - var roomVersion gomatrixserverlib.RoomVersion - if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil { - return 0, types.StateAtEvent{}, err - } - - if roomNID, err = d.assignRoomNID(ctx, nil, event.RoomID(), roomVersion); err != nil { - return 0, types.StateAtEvent{}, err - } - - if eventTypeNID, err = d.assignEventTypeNID(ctx, event.Type()); err != nil { - return 0, types.StateAtEvent{}, err - } - - eventStateKey := event.StateKey() - // Assigned a numeric ID for the state_key if there is one present. - // Otherwise set the numeric ID for the state_key to 0. - if eventStateKey != nil { - if eventStateKeyNID, err = d.assignStateKeyNID(ctx, nil, *eventStateKey); err != nil { - return 0, types.StateAtEvent{}, err - } - } - - if eventNID, stateNID, err = d.statements.insertEvent( - ctx, - roomNID, - eventTypeNID, - eventStateKeyNID, - event.EventID(), - event.EventReference().EventSHA256, - authEventNIDs, - event.Depth(), - ); err != nil { - if err == sql.ErrNoRows { - // We've already inserted the event so select the numeric event ID - eventNID, stateNID, err = d.statements.selectEvent(ctx, event.EventID()) - } - if err != nil { - return 0, types.StateAtEvent{}, err - } - } - - if err = d.statements.insertEventJSON(ctx, eventNID, event.JSON()); err != nil { - return 0, types.StateAtEvent{}, err - } - - return roomNID, types.StateAtEvent{ - BeforeStateSnapshotNID: stateNID, - StateEntry: types.StateEntry{ - StateKeyTuple: types.StateKeyTuple{ - EventTypeNID: eventTypeNID, - EventStateKeyNID: eventStateKeyNID, - }, - EventNID: eventNID, - }, - }, nil -} - -func extractRoomVersionFromCreateEvent(event gomatrixserverlib.Event) ( - gomatrixserverlib.RoomVersion, error, -) { - var err error - var roomVersion gomatrixserverlib.RoomVersion - // Look for m.room.create events. - if event.Type() != gomatrixserverlib.MRoomCreate { - return gomatrixserverlib.RoomVersion(""), nil - } - roomVersion = gomatrixserverlib.RoomVersionV1 - var createContent gomatrixserverlib.CreateContent - // The m.room.create event contains an optional "room_version" key in - // the event content, so we need to unmarshal that first. - if err = json.Unmarshal(event.Content(), &createContent); err != nil { - return gomatrixserverlib.RoomVersion(""), err - } - // A room version was specified in the event content? - if createContent.RoomVersion != nil { - roomVersion = gomatrixserverlib.RoomVersion(*createContent.RoomVersion) - } - return roomVersion, err -} - -func (d *Database) assignRoomNID( - ctx context.Context, txn *sql.Tx, - roomID string, roomVersion gomatrixserverlib.RoomVersion, -) (types.RoomNID, error) { - // Check if we already have a numeric ID in the database. - roomNID, err := d.statements.selectRoomNID(ctx, txn, roomID) - if err == sql.ErrNoRows { - // We don't have a numeric ID so insert one into the database. - roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID, roomVersion) - if err == sql.ErrNoRows { - // We raced with another insert so run the select again. - roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) - } - } - return roomNID, err -} - -func (d *Database) assignEventTypeNID( - ctx context.Context, eventType string, -) (types.EventTypeNID, error) { - // Check if we already have a numeric ID in the database. - eventTypeNID, err := d.statements.selectEventTypeNID(ctx, eventType) - if err == sql.ErrNoRows { - // We don't have a numeric ID so insert one into the database. - eventTypeNID, err = d.statements.insertEventTypeNID(ctx, eventType) - if err == sql.ErrNoRows { - // We raced with another insert so run the select again. - eventTypeNID, err = d.statements.selectEventTypeNID(ctx, eventType) - } - } - return eventTypeNID, err -} - -func (d *Database) assignStateKeyNID( - ctx context.Context, txn *sql.Tx, eventStateKey string, -) (types.EventStateKeyNID, error) { - // Check if we already have a numeric ID in the database. - eventStateKeyNID, err := d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) - if err == sql.ErrNoRows { - // We don't have a numeric ID so insert one into the database. - eventStateKeyNID, err = d.statements.insertEventStateKeyNID(ctx, txn, eventStateKey) - if err == sql.ErrNoRows { - // We raced with another insert so run the select again. - eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) - } - } - return eventStateKeyNID, err -} - -// StateEntriesForEventIDs implements input.EventDatabase -func (d *Database) StateEntriesForEventIDs( - ctx context.Context, eventIDs []string, -) ([]types.StateEntry, error) { - return d.statements.bulkSelectStateEventByID(ctx, eventIDs) -} - -// EventTypeNIDs implements state.RoomStateDatabase -func (d *Database) EventTypeNIDs( - ctx context.Context, eventTypes []string, -) (map[string]types.EventTypeNID, error) { - return d.statements.bulkSelectEventTypeNID(ctx, eventTypes) -} - -// EventStateKeyNIDs implements state.RoomStateDatabase -func (d *Database) EventStateKeyNIDs( - ctx context.Context, eventStateKeys []string, -) (map[string]types.EventStateKeyNID, error) { - return d.statements.bulkSelectEventStateKeyNID(ctx, eventStateKeys) -} - -// EventStateKeys implements query.RoomserverQueryAPIDatabase -func (d *Database) EventStateKeys( - ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, -) (map[types.EventStateKeyNID]string, error) { - return d.statements.bulkSelectEventStateKey(ctx, eventStateKeyNIDs) -} - -// EventNIDs implements query.RoomserverQueryAPIDatabase -func (d *Database) EventNIDs( - ctx context.Context, eventIDs []string, -) (map[string]types.EventNID, error) { - return d.statements.bulkSelectEventNID(ctx, eventIDs) -} - -// Events implements input.EventDatabase -func (d *Database) Events( - ctx context.Context, eventNIDs []types.EventNID, -) ([]types.Event, error) { - eventJSONs, err := d.statements.bulkSelectEventJSON(ctx, eventNIDs) - if err != nil { - return nil, err - } - results := make([]types.Event, len(eventJSONs)) - for i, eventJSON := range eventJSONs { - var roomNID types.RoomNID - var roomVersion gomatrixserverlib.RoomVersion - result := &results[i] - result.EventNID = eventJSON.EventNID - roomNID, err = d.statements.selectRoomNIDForEventNID(ctx, nil, eventJSON.EventNID) - if err != nil { - return nil, err - } - roomVersion, err = d.statements.selectRoomVersionForRoomNID(ctx, nil, roomNID) - if err != nil { - return nil, err - } - result.Event, err = gomatrixserverlib.NewEventFromTrustedJSON( - eventJSON.EventJSON, false, roomVersion, - ) - if err != nil { - return nil, err - } - } - return results, nil -} - -// AddState implements input.EventDatabase -func (d *Database) AddState( - ctx context.Context, - roomNID types.RoomNID, - stateBlockNIDs []types.StateBlockNID, - state []types.StateEntry, -) (types.StateSnapshotNID, error) { - if len(state) > 0 { - stateBlockNID, err := d.statements.selectNextStateBlockNID(ctx) - if err != nil { - return 0, err - } - if err = d.statements.bulkInsertStateData(ctx, stateBlockNID, state); err != nil { - return 0, err - } - stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) - } - - return d.statements.insertState(ctx, roomNID, stateBlockNIDs) -} - -// SetState implements input.EventDatabase -func (d *Database) SetState( - ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, -) error { - return d.statements.updateEventState(ctx, eventNID, stateNID) -} - -// StateAtEventIDs implements input.EventDatabase -func (d *Database) StateAtEventIDs( - ctx context.Context, eventIDs []string, -) ([]types.StateAtEvent, error) { - return d.statements.bulkSelectStateAtEventByID(ctx, eventIDs) -} - -// StateBlockNIDs implements state.RoomStateDatabase -func (d *Database) StateBlockNIDs( - ctx context.Context, stateNIDs []types.StateSnapshotNID, -) ([]types.StateBlockNIDList, error) { - return d.statements.bulkSelectStateBlockNIDs(ctx, stateNIDs) -} - -// StateEntries implements state.RoomStateDatabase -func (d *Database) StateEntries( - ctx context.Context, stateBlockNIDs []types.StateBlockNID, -) ([]types.StateEntryList, error) { - return d.statements.bulkSelectStateBlockEntries(ctx, stateBlockNIDs) -} - -// SnapshotNIDFromEventID implements state.RoomStateDatabase -func (d *Database) SnapshotNIDFromEventID( - ctx context.Context, eventID string, -) (types.StateSnapshotNID, error) { - _, stateNID, err := d.statements.selectEvent(ctx, eventID) - return stateNID, err -} - -// EventIDs implements input.RoomEventDatabase -func (d *Database) EventIDs( - ctx context.Context, eventNIDs []types.EventNID, -) (map[types.EventNID]string, error) { - return d.statements.bulkSelectEventID(ctx, eventNIDs) -} - -// GetLatestEventsForUpdate implements input.EventDatabase -func (d *Database) GetLatestEventsForUpdate( - ctx context.Context, roomNID types.RoomNID, -) (types.RoomRecentEventsUpdater, error) { - txn, err := d.db.Begin() - if err != nil { - return nil, err - } - eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := - d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID) - if err != nil { - txn.Rollback() // nolint: errcheck - return nil, err - } - stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) - if err != nil { - txn.Rollback() // nolint: errcheck - return nil, err - } - var lastEventIDSent string - if lastEventNIDSent != 0 { - lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent) - if err != nil { - txn.Rollback() // nolint: errcheck - return nil, err - } - } - return &roomRecentEventsUpdater{ - transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, - }, nil -} - -// GetTransactionEventID implements input.EventDatabase -func (d *Database) GetTransactionEventID( - ctx context.Context, transactionID string, - sessionID int64, userID string, -) (string, error) { - eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, sessionID, userID) - if err == sql.ErrNoRows { - return "", nil - } - return eventID, err -} - -type roomRecentEventsUpdater struct { - transaction - d *Database - roomNID types.RoomNID - latestEvents []types.StateAtEventAndReference - lastEventIDSent string - currentStateSnapshotNID types.StateSnapshotNID -} - -// RoomVersion implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) { - version, _ = u.d.GetRoomVersionForRoomNID(u.ctx, u.roomNID) - return -} - -// LatestEvents implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) LatestEvents() []types.StateAtEventAndReference { - return u.latestEvents -} - -// LastEventIDSent implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) LastEventIDSent() string { - return u.lastEventIDSent -} - -// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { - return u.currentStateSnapshotNID -} - -// StorePreviousEvents implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { - for _, ref := range previousEventReferences { - if err := u.d.statements.insertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { - return err - } - } - return nil -} - -// IsReferenced implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { - err := u.d.statements.selectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) - if err == nil { - return true, nil - } - if err == sql.ErrNoRows { - return false, nil - } - return false, err -} - -// SetLatestEvents implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) SetLatestEvents( - roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, - currentStateSnapshotNID types.StateSnapshotNID, -) error { - eventNIDs := make([]types.EventNID, len(latest)) - for i := range latest { - eventNIDs[i] = latest[i].EventNID - } - return u.d.statements.updateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) -} - -// HasEventBeenSent implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { - return u.d.statements.selectEventSentToOutput(u.ctx, u.txn, eventNID) -} - -// MarkEventAsSent implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { - return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID) -} - -func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) { - return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID) -} - -// RoomNID implements query.RoomserverQueryAPIDB -func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) { - roomNID, err := d.statements.selectRoomNID(ctx, nil, roomID) - if err == sql.ErrNoRows { - return 0, nil - } - return roomNID, err -} - -// LatestEventIDs implements query.RoomserverQueryAPIDatabase -func (d *Database) LatestEventIDs( - ctx context.Context, roomNID types.RoomNID, -) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) { - eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(ctx, roomNID) - if err != nil { - return nil, 0, 0, err - } - references, err := d.statements.bulkSelectEventReference(ctx, eventNIDs) - if err != nil { - return nil, 0, 0, err - } - depth, err := d.statements.selectMaxEventDepth(ctx, eventNIDs) - if err != nil { - return nil, 0, 0, err - } - return references, currentStateSnapshotNID, depth, nil -} - -// GetInvitesForUser implements query.RoomserverQueryAPIDatabase -func (d *Database) GetInvitesForUser( - ctx context.Context, - roomNID types.RoomNID, - targetUserNID types.EventStateKeyNID, -) (senderUserIDs []types.EventStateKeyNID, err error) { - return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) -} - -// SetRoomAlias implements alias.RoomserverAliasAPIDB -func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { - return d.statements.insertRoomAlias(ctx, alias, roomID, creatorUserID) -} - -// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB -func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { - return d.statements.selectRoomIDFromAlias(ctx, alias) -} - -// GetAliasesForRoomID implements alias.RoomserverAliasAPIDB -func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { - return d.statements.selectAliasesFromRoomID(ctx, roomID) -} - -// GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB -func (d *Database) GetCreatorIDForAlias( - ctx context.Context, alias string, -) (string, error) { - return d.statements.selectCreatorIDFromAlias(ctx, alias) -} - -// RemoveRoomAlias implements alias.RoomserverAliasAPIDB -func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { - return d.statements.deleteRoomAlias(ctx, alias) -} - -// StateEntriesForTuples implements state.RoomStateDatabase -func (d *Database) StateEntriesForTuples( - ctx context.Context, - stateBlockNIDs []types.StateBlockNID, - stateKeyTuples []types.StateKeyTuple, -) ([]types.StateEntryList, error) { - return d.statements.bulkSelectFilteredStateBlockEntries( - ctx, stateBlockNIDs, stateKeyTuples, - ) -} - -// MembershipUpdater implements input.RoomEventDatabase -func (d *Database) MembershipUpdater( - ctx context.Context, roomID, targetUserID string, - roomVersion gomatrixserverlib.RoomVersion, -) (types.MembershipUpdater, error) { - txn, err := d.db.Begin() - if err != nil { - return nil, err - } - succeeded := false - defer func() { - if !succeeded { - txn.Rollback() // nolint: errcheck - } - }() - - roomNID, err := d.assignRoomNID(ctx, txn, roomID, roomVersion) - if err != nil { - return nil, err - } - - targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID) - if err != nil { - return nil, err - } - - updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID) - if err != nil { - return nil, err - } - - succeeded = true - return updater, nil -} - -type membershipUpdater struct { - transaction - d *Database - roomNID types.RoomNID - targetUserNID types.EventStateKeyNID - membership membershipState -} - -func (d *Database) membershipUpdaterTxn( - ctx context.Context, - txn *sql.Tx, - roomNID types.RoomNID, - targetUserNID types.EventStateKeyNID, -) (types.MembershipUpdater, error) { - - if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil { - return nil, err - } - - membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) - if err != nil { - return nil, err - } - - return &membershipUpdater{ - transaction{ctx, txn}, d, roomNID, targetUserNID, membership, - }, nil -} - -// IsInvite implements types.MembershipUpdater -func (u *membershipUpdater) IsInvite() bool { - return u.membership == membershipStateInvite -} - -// IsJoin implements types.MembershipUpdater -func (u *membershipUpdater) IsJoin() bool { - return u.membership == membershipStateJoin -} - -// IsLeave implements types.MembershipUpdater -func (u *membershipUpdater) IsLeave() bool { - return u.membership == membershipStateLeaveOrBan -} - -// SetToInvite implements types.MembershipUpdater -func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) - if err != nil { - return false, err - } - inserted, err := u.d.statements.insertInviteEvent( - u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), - ) - if err != nil { - return false, err - } - if u.membership != membershipStateInvite { - if err = u.d.statements.updateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0, - ); err != nil { - return false, err - } - } - return inserted, nil -} - -// SetToJoin implements types.MembershipUpdater -func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) { - var inviteEventIDs []string - - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) - if err != nil { - return nil, err - } - - // If this is a join event update, there is no invite to update - if !isUpdate { - inviteEventIDs, err = u.d.statements.updateInviteRetired( - u.ctx, u.txn, u.roomNID, u.targetUserNID, - ) - if err != nil { - return nil, err - } - } - - // Look up the NID of the new join event - nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) - if err != nil { - return nil, err - } - - if u.membership != membershipStateJoin || isUpdate { - if err = u.d.statements.updateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, - membershipStateJoin, nIDs[eventID], - ); err != nil { - return nil, err - } - } - - return inviteEventIDs, nil -} - -// SetToLeave implements types.MembershipUpdater -func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) - if err != nil { - return nil, err - } - inviteEventIDs, err := u.d.statements.updateInviteRetired( - u.ctx, u.txn, u.roomNID, u.targetUserNID, - ) - if err != nil { - return nil, err - } - - // Look up the NID of the new leave event - nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) - if err != nil { - return nil, err - } - - if u.membership != membershipStateLeaveOrBan { - if err = u.d.statements.updateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, - membershipStateLeaveOrBan, nIDs[eventID], - ); err != nil { - return nil, err - } - } - return inviteEventIDs, nil -} - -// GetMembership implements query.RoomserverQueryAPIDB -func (d *Database) GetMembership( - ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, -) (membershipEventNID types.EventNID, stillInRoom bool, err error) { - requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID) - if err != nil { - return - } - - senderMembershipEventNID, senderMembership, err := - d.statements.selectMembershipFromRoomAndTarget( - ctx, roomNID, requestSenderUserNID, - ) - if err == sql.ErrNoRows { - // The user has never been a member of that room - return 0, false, nil - } else if err != nil { - return - } - - return senderMembershipEventNID, senderMembership == membershipStateJoin, nil -} - -// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB -func (d *Database) GetMembershipEventNIDsForRoom( - ctx context.Context, roomNID types.RoomNID, joinOnly bool, -) ([]types.EventNID, error) { - if joinOnly { - return d.statements.selectMembershipsFromRoomAndMembership( - ctx, roomNID, membershipStateJoin, - ) - } - - return d.statements.selectMembershipsFromRoom(ctx, roomNID) -} - -// EventsFromIDs implements query.RoomserverQueryAPIEventDB -func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - nidMap, err := d.EventNIDs(ctx, eventIDs) - if err != nil { - return nil, err - } - - var nids []types.EventNID - for _, nid := range nidMap { - nids = append(nids, nid) - } - - return d.Events(ctx, nids) -} - -func (d *Database) GetRoomVersionForRoom( - ctx context.Context, roomID string, -) (gomatrixserverlib.RoomVersion, error) { - return d.statements.selectRoomVersionForRoomID( - ctx, nil, roomID, - ) -} - -func (d *Database) GetRoomVersionForRoomNID( - ctx context.Context, roomNID types.RoomNID, -) (gomatrixserverlib.RoomVersion, error) { - return d.statements.selectRoomVersionForRoomNID( - ctx, nil, roomNID, - ) -} - -type transaction struct { - ctx context.Context - txn *sql.Tx -} - -// Commit implements types.Transaction -func (t *transaction) Commit() error { - return t.txn.Commit() -} - -// Rollback implements types.Transaction -func (t *transaction) Rollback() error { - return t.txn.Rollback() -} diff --git a/roomserver/storage/postgres/transactions_table.go b/roomserver/storage/postgres/transactions_table.go index 87c1cacae..5e59ae16d 100644 --- a/roomserver/storage/postgres/transactions_table.go +++ b/roomserver/storage/postgres/transactions_table.go @@ -18,6 +18,9 @@ package postgres import ( "context" "database/sql" + + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" ) const transactionsSchema = ` @@ -51,20 +54,21 @@ type transactionStatements struct { selectTransactionEventIDStmt *sql.Stmt } -func (s *transactionStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(transactionsSchema) +func NewPostgresTransactionsTable(db *sql.DB) (tables.Transactions, error) { + s := &transactionStatements{} + _, err := db.Exec(transactionsSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertTransactionStmt, insertTransactionSQL}, {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *transactionStatements) insertTransaction( - ctx context.Context, +func (s *transactionStatements) InsertTransaction( + ctx context.Context, txn *sql.Tx, transactionID string, sessionID int64, userID string, @@ -76,7 +80,7 @@ func (s *transactionStatements) insertTransaction( return } -func (s *transactionStatements) selectTransactionEventID( +func (s *transactionStatements) SelectTransactionEventID( ctx context.Context, transactionID string, sessionID int64, diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go new file mode 100644 index 000000000..5ddf6d84d --- /dev/null +++ b/roomserver/storage/shared/membership_updater.go @@ -0,0 +1,183 @@ +package shared + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +type membershipUpdater struct { + transaction + d *Database + roomNID types.RoomNID + targetUserNID types.EventStateKeyNID + membership tables.MembershipState +} + +func NewMembershipUpdater( + ctx context.Context, d *Database, roomID, targetUserID string, + targetLocal bool, roomVersion gomatrixserverlib.RoomVersion, + useTxns bool, +) (types.MembershipUpdater, error) { + txn, err := d.DB.Begin() + if err != nil { + return nil, err + } + succeeded := false + defer func() { + if !succeeded { + txn.Rollback() // nolint: errcheck + } + }() + + roomNID, err := d.assignRoomNID(ctx, txn, roomID, roomVersion) + if err != nil { + return nil, err + } + + targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID) + if err != nil { + return nil, err + } + + updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID, targetLocal) + if err != nil { + return nil, err + } + + succeeded = true + if !useTxns { + txn.Commit() // nolint: errcheck + updater.transaction.txn = nil + } + return updater, nil +} + +func (d *Database) membershipUpdaterTxn( + ctx context.Context, + txn *sql.Tx, + roomNID types.RoomNID, + targetUserNID types.EventStateKeyNID, + targetLocal bool, +) (*membershipUpdater, error) { + + if err := d.MembershipTable.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { + return nil, err + } + + membership, err := d.MembershipTable.SelectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) + if err != nil { + return nil, err + } + + return &membershipUpdater{ + transaction{ctx, txn}, d, roomNID, targetUserNID, membership, + }, nil +} + +// IsInvite implements types.MembershipUpdater +func (u *membershipUpdater) IsInvite() bool { + return u.membership == tables.MembershipStateInvite +} + +// IsJoin implements types.MembershipUpdater +func (u *membershipUpdater) IsJoin() bool { + return u.membership == tables.MembershipStateJoin +} + +// IsLeave implements types.MembershipUpdater +func (u *membershipUpdater) IsLeave() bool { + return u.membership == tables.MembershipStateLeaveOrBan +} + +// SetToInvite implements types.MembershipUpdater +func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) + if err != nil { + return false, err + } + inserted, err := u.d.InvitesTable.InsertInviteEvent( + u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), + ) + if err != nil { + return false, err + } + if u.membership != tables.MembershipStateInvite { + if err = u.d.MembershipTable.UpdateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, + ); err != nil { + return false, err + } + } + return inserted, nil +} + +// SetToJoin implements types.MembershipUpdater +func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) { + var inviteEventIDs []string + + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) + if err != nil { + return nil, err + } + + // If this is a join event update, there is no invite to update + if !isUpdate { + inviteEventIDs, err = u.d.InvitesTable.UpdateInviteRetired( + u.ctx, u.txn, u.roomNID, u.targetUserNID, + ) + if err != nil { + return nil, err + } + } + + // Look up the NID of the new join event + nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + if err != nil { + return nil, err + } + + if u.membership != tables.MembershipStateJoin || isUpdate { + if err = u.d.MembershipTable.UpdateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, + tables.MembershipStateJoin, nIDs[eventID], + ); err != nil { + return nil, err + } + } + + return inviteEventIDs, nil +} + +// SetToLeave implements types.MembershipUpdater +func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) + if err != nil { + return nil, err + } + inviteEventIDs, err := u.d.InvitesTable.UpdateInviteRetired( + u.ctx, u.txn, u.roomNID, u.targetUserNID, + ) + if err != nil { + return nil, err + } + + // Look up the NID of the new leave event + nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + if err != nil { + return nil, err + } + + if u.membership != tables.MembershipStateLeaveOrBan { + if err = u.d.MembershipTable.UpdateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, + tables.MembershipStateLeaveOrBan, nIDs[eventID], + ); err != nil { + return nil, err + } + } + return inviteEventIDs, nil +} diff --git a/roomserver/storage/shared/prepare.go b/roomserver/storage/shared/prepare.go new file mode 100644 index 000000000..65ceec1cc --- /dev/null +++ b/roomserver/storage/shared/prepare.go @@ -0,0 +1,60 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shared + +import ( + "context" + "database/sql" +) + +// StatementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. +type StatementList []struct { + Statement **sql.Stmt + SQL string +} + +// Prepare the SQL for each statement in the list and assign the result to the prepared statement. +func (s StatementList) Prepare(db *sql.DB) (err error) { + for _, statement := range s { + if *statement.Statement, err = db.Prepare(statement.SQL); err != nil { + return + } + } + return +} + +type transaction struct { + ctx context.Context + txn *sql.Tx +} + +// Commit implements types.Transaction +func (t *transaction) Commit() error { + if t.txn == nil { + // The Updater structs can operate in useTxns=false mode. The code will still call this though. + return nil + } + return t.txn.Commit() +} + +// Rollback implements types.Transaction +func (t *transaction) Rollback() error { + if t.txn == nil { + // The Updater structs can operate in useTxns=false mode. The code will still call this though. + return nil + } + return t.txn.Rollback() +} diff --git a/roomserver/storage/shared/room_recent_events_updater.go b/roomserver/storage/shared/room_recent_events_updater.go new file mode 100644 index 000000000..8131f712d --- /dev/null +++ b/roomserver/storage/shared/room_recent_events_updater.go @@ -0,0 +1,120 @@ +package shared + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +type roomRecentEventsUpdater struct { + transaction + d *Database + roomNID types.RoomNID + latestEvents []types.StateAtEventAndReference + lastEventIDSent string + currentStateSnapshotNID types.StateSnapshotNID +} + +func NewRoomRecentEventsUpdater(d *Database, ctx context.Context, roomNID types.RoomNID, useTxns bool) (types.RoomRecentEventsUpdater, error) { + txn, err := d.DB.Begin() + if err != nil { + return nil, err + } + eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := + d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomNID) + if err != nil { + txn.Rollback() // nolint: errcheck + return nil, err + } + stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) + if err != nil { + txn.Rollback() // nolint: errcheck + return nil, err + } + var lastEventIDSent string + if lastEventNIDSent != 0 { + lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent) + if err != nil { + txn.Rollback() // nolint: errcheck + return nil, err + } + } + if !useTxns { + txn.Commit() // nolint: errcheck + txn = nil + } + return &roomRecentEventsUpdater{ + transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, + }, nil +} + +// RoomVersion implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) { + version, _ = u.d.GetRoomVersionForRoomNID(u.ctx, u.roomNID) + return +} + +// LatestEvents implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) LatestEvents() []types.StateAtEventAndReference { + return u.latestEvents +} + +// LastEventIDSent implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) LastEventIDSent() string { + return u.lastEventIDSent +} + +// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { + return u.currentStateSnapshotNID +} + +// StorePreviousEvents implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { + for _, ref := range previousEventReferences { + if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + return err + } + } + return nil +} + +// IsReferenced implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { + err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) + if err == nil { + return true, nil + } + if err == sql.ErrNoRows { + return false, nil + } + return false, err +} + +// SetLatestEvents implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) SetLatestEvents( + roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, + currentStateSnapshotNID types.StateSnapshotNID, +) error { + eventNIDs := make([]types.EventNID, len(latest)) + for i := range latest { + eventNIDs[i] = latest[i].EventNID + } + return u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) +} + +// HasEventBeenSent implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { + return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID) +} + +// MarkEventAsSent implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { + return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, u.txn, eventNID) +} + +func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (types.MembershipUpdater, error) { + return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal) +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go new file mode 100644 index 000000000..2751cc557 --- /dev/null +++ b/roomserver/storage/shared/storage.go @@ -0,0 +1,493 @@ +package shared + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +type Database struct { + DB *sql.DB + EventsTable tables.Events + EventJSONTable tables.EventJSON + EventTypesTable tables.EventTypes + EventStateKeysTable tables.EventStateKeys + RoomsTable tables.Rooms + TransactionsTable tables.Transactions + StateSnapshotTable tables.StateSnapshot + StateBlockTable tables.StateBlock + RoomAliasesTable tables.RoomAliases + PrevEventsTable tables.PreviousEvents + InvitesTable tables.Invites + MembershipTable tables.Membership +} + +func (d *Database) EventTypeNIDs( + ctx context.Context, eventTypes []string, +) (map[string]types.EventTypeNID, error) { + return d.EventTypesTable.BulkSelectEventTypeNID(ctx, eventTypes) +} + +func (d *Database) EventStateKeys( + ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, +) (map[types.EventStateKeyNID]string, error) { + return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, eventStateKeyNIDs) +} + +func (d *Database) EventStateKeyNIDs( + ctx context.Context, eventStateKeys []string, +) (map[string]types.EventStateKeyNID, error) { + return d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, eventStateKeys) +} + +func (d *Database) StateEntriesForEventIDs( + ctx context.Context, eventIDs []string, +) ([]types.StateEntry, error) { + return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs) +} + +func (d *Database) StateEntriesForTuples( + ctx context.Context, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntryList, error) { + return d.StateBlockTable.BulkSelectFilteredStateBlockEntries( + ctx, stateBlockNIDs, stateKeyTuples, + ) +} + +func (d *Database) AddState( + ctx context.Context, + roomNID types.RoomNID, + stateBlockNIDs []types.StateBlockNID, + state []types.StateEntry, +) (stateNID types.StateSnapshotNID, err error) { + err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + if len(state) > 0 { + var stateBlockNID types.StateBlockNID + stateBlockNID, err = d.StateBlockTable.BulkInsertStateData(ctx, txn, state) + if err != nil { + return err + } + stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) + } + stateNID, err = d.StateSnapshotTable.InsertState(ctx, txn, roomNID, stateBlockNIDs) + return err + }) + if err != nil { + return 0, err + } + return +} + +func (d *Database) EventNIDs( + ctx context.Context, eventIDs []string, +) (map[string]types.EventNID, error) { + return d.EventsTable.BulkSelectEventNID(ctx, eventIDs) +} + +func (d *Database) SetState( + ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, +) error { + return d.EventsTable.UpdateEventState(ctx, eventNID, stateNID) +} + +func (d *Database) StateAtEventIDs( + ctx context.Context, eventIDs []string, +) ([]types.StateAtEvent, error) { + return d.EventsTable.BulkSelectStateAtEventByID(ctx, eventIDs) +} + +func (d *Database) SnapshotNIDFromEventID( + ctx context.Context, eventID string, +) (types.StateSnapshotNID, error) { + _, stateNID, err := d.EventsTable.SelectEvent(ctx, nil, eventID) + return stateNID, err +} + +func (d *Database) EventIDs( + ctx context.Context, eventNIDs []types.EventNID, +) (map[types.EventNID]string, error) { + return d.EventsTable.BulkSelectEventID(ctx, eventNIDs) +} + +func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { + nidMap, err := d.EventNIDs(ctx, eventIDs) + if err != nil { + return nil, err + } + + var nids []types.EventNID + for _, nid := range nidMap { + nids = append(nids, nid) + } + + return d.Events(ctx, nids) +} + +func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) { + roomNID, err := d.RoomsTable.SelectRoomNID(ctx, nil, roomID) + if err == sql.ErrNoRows { + return 0, nil + } + return roomNID, err +} + +func (d *Database) RoomNIDExcludingStubs(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) { + roomNID, err = d.RoomNID(ctx, roomID) + if err != nil { + return + } + latestEvents, _, err := d.RoomsTable.SelectLatestEventNIDs(ctx, nil, roomNID) + if err != nil { + return + } + if len(latestEvents) == 0 { + roomNID = 0 + return + } + return +} + +func (d *Database) LatestEventIDs( + ctx context.Context, roomNID types.RoomNID, +) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) { + err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + var eventNIDs []types.EventNID + eventNIDs, currentStateSnapshotNID, err = d.RoomsTable.SelectLatestEventNIDs(ctx, txn, roomNID) + if err != nil { + return err + } + references, err = d.EventsTable.BulkSelectEventReference(ctx, txn, eventNIDs) + if err != nil { + return err + } + depth, err = d.EventsTable.SelectMaxEventDepth(ctx, txn, eventNIDs) + if err != nil { + return err + } + return nil + }) + return +} + +func (d *Database) StateBlockNIDs( + ctx context.Context, stateNIDs []types.StateSnapshotNID, +) ([]types.StateBlockNIDList, error) { + return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, stateNIDs) +} + +func (d *Database) StateEntries( + ctx context.Context, stateBlockNIDs []types.StateBlockNID, +) ([]types.StateEntryList, error) { + return d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs) +} + +func (d *Database) GetRoomVersionForRoom( + ctx context.Context, roomID string, +) (gomatrixserverlib.RoomVersion, error) { + return d.RoomsTable.SelectRoomVersionForRoomID( + ctx, nil, roomID, + ) +} + +func (d *Database) GetRoomVersionForRoomNID( + ctx context.Context, roomNID types.RoomNID, +) (gomatrixserverlib.RoomVersion, error) { + return d.RoomsTable.SelectRoomVersionForRoomNID( + ctx, roomNID, + ) +} + +func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { + return d.RoomAliasesTable.InsertRoomAlias(ctx, alias, roomID, creatorUserID) +} + +func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { + return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, alias) +} + +func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { + return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, roomID) +} + +func (d *Database) GetCreatorIDForAlias( + ctx context.Context, alias string, +) (string, error) { + return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, alias) +} + +func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { + return d.RoomAliasesTable.DeleteRoomAlias(ctx, alias) +} + +func (d *Database) GetMembership( + ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, +) (membershipEventNID types.EventNID, stillInRoom bool, err error) { + requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID) + if err != nil { + return + } + + senderMembershipEventNID, senderMembership, err := + d.MembershipTable.SelectMembershipFromRoomAndTarget( + ctx, roomNID, requestSenderUserNID, + ) + if err == sql.ErrNoRows { + // The user has never been a member of that room + return 0, false, nil + } else if err != nil { + return + } + + return senderMembershipEventNID, senderMembership == tables.MembershipStateJoin, nil +} + +func (d *Database) GetMembershipEventNIDsForRoom( + ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, +) ([]types.EventNID, error) { + if joinOnly { + return d.MembershipTable.SelectMembershipsFromRoomAndMembership( + ctx, roomNID, tables.MembershipStateJoin, localOnly, + ) + } + + return d.MembershipTable.SelectMembershipsFromRoom(ctx, roomNID, localOnly) +} + +func (d *Database) GetInvitesForUser( + ctx context.Context, + roomNID types.RoomNID, + targetUserNID types.EventStateKeyNID, +) (senderUserIDs []types.EventStateKeyNID, err error) { + return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) +} + +func (d *Database) Events( + ctx context.Context, eventNIDs []types.EventNID, +) ([]types.Event, error) { + eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs) + if err != nil { + return nil, err + } + results := make([]types.Event, len(eventJSONs)) + for i, eventJSON := range eventJSONs { + var roomNID types.RoomNID + var roomVersion gomatrixserverlib.RoomVersion + result := &results[i] + result.EventNID = eventJSON.EventNID + roomNID, err = d.EventsTable.SelectRoomNIDForEventNID(ctx, eventJSON.EventNID) + if err != nil { + return nil, err + } + roomVersion, err = d.RoomsTable.SelectRoomVersionForRoomNID(ctx, roomNID) + if err != nil { + return nil, err + } + result.Event, err = gomatrixserverlib.NewEventFromTrustedJSON( + eventJSON.EventJSON, false, roomVersion, + ) + if err != nil { + return nil, err + } + } + return results, nil +} + +func (d *Database) GetTransactionEventID( + ctx context.Context, transactionID string, + sessionID int64, userID string, +) (string, error) { + eventID, err := d.TransactionsTable.SelectTransactionEventID(ctx, transactionID, sessionID, userID) + if err == sql.ErrNoRows { + return "", nil + } + return eventID, err +} + +func (d *Database) MembershipUpdater( + ctx context.Context, roomID, targetUserID string, + targetLocal bool, roomVersion gomatrixserverlib.RoomVersion, +) (types.MembershipUpdater, error) { + return NewMembershipUpdater(ctx, d, roomID, targetUserID, targetLocal, roomVersion, true) +} + +func (d *Database) GetLatestEventsForUpdate( + ctx context.Context, roomNID types.RoomNID, +) (types.RoomRecentEventsUpdater, error) { + return NewRoomRecentEventsUpdater(d, ctx, roomNID, true) +} + +func (d *Database) StoreEvent( + ctx context.Context, event gomatrixserverlib.Event, + txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, +) (types.RoomNID, types.StateAtEvent, error) { + var ( + roomNID types.RoomNID + eventTypeNID types.EventTypeNID + eventStateKeyNID types.EventStateKeyNID + eventNID types.EventNID + stateNID types.StateSnapshotNID + err error + ) + + err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + if txnAndSessionID != nil { + if err = d.TransactionsTable.InsertTransaction( + ctx, txn, txnAndSessionID.TransactionID, + txnAndSessionID.SessionID, event.Sender(), event.EventID(), + ); err != nil { + return err + } + } + + // TODO: Here we should aim to have two different code paths for new rooms + // vs existing ones. + + // Get the default room version. If the client doesn't supply a room_version + // then we will use our configured default to create the room. + // https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom + // Note that the below logic depends on the m.room.create event being the + // first event that is persisted to the database when creating or joining a + // room. + var roomVersion gomatrixserverlib.RoomVersion + if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil { + return err + } + + if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion); err != nil { + return err + } + + if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil { + return err + } + + eventStateKey := event.StateKey() + // Assigned a numeric ID for the state_key if there is one present. + // Otherwise set the numeric ID for the state_key to 0. + if eventStateKey != nil { + if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil { + return err + } + } + + if eventNID, stateNID, err = d.EventsTable.InsertEvent( + ctx, + txn, + roomNID, + eventTypeNID, + eventStateKeyNID, + event.EventID(), + event.EventReference().EventSHA256, + authEventNIDs, + event.Depth(), + ); err != nil { + if err == sql.ErrNoRows { + // We've already inserted the event so select the numeric event ID + eventNID, stateNID, err = d.EventsTable.SelectEvent(ctx, txn, event.EventID()) + } + if err != nil { + return err + } + } + + if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { + return err + } + + return nil + }) + if err != nil { + return 0, types.StateAtEvent{}, err + } + + return roomNID, types.StateAtEvent{ + BeforeStateSnapshotNID: stateNID, + StateEntry: types.StateEntry{ + StateKeyTuple: types.StateKeyTuple{ + EventTypeNID: eventTypeNID, + EventStateKeyNID: eventStateKeyNID, + }, + EventNID: eventNID, + }, + }, nil +} + +func (d *Database) assignRoomNID( + ctx context.Context, txn *sql.Tx, + roomID string, roomVersion gomatrixserverlib.RoomVersion, +) (types.RoomNID, error) { + // Check if we already have a numeric ID in the database. + roomNID, err := d.RoomsTable.SelectRoomNID(ctx, txn, roomID) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + roomNID, err = d.RoomsTable.InsertRoomNID(ctx, txn, roomID, roomVersion) + if err == sql.ErrNoRows { + // We raced with another insert so run the select again. + roomNID, err = d.RoomsTable.SelectRoomNID(ctx, txn, roomID) + } + } + return roomNID, err +} + +func (d *Database) assignEventTypeNID( + ctx context.Context, txn *sql.Tx, eventType string, +) (eventTypeNID types.EventTypeNID, err error) { + // Check if we already have a numeric ID in the database. + eventTypeNID, err = d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + eventTypeNID, err = d.EventTypesTable.InsertEventTypeNID(ctx, txn, eventType) + if err == sql.ErrNoRows { + // We raced with another insert so run the select again. + eventTypeNID, err = d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType) + } + } + return +} + +func (d *Database) assignStateKeyNID( + ctx context.Context, txn *sql.Tx, eventStateKey string, +) (types.EventStateKeyNID, error) { + // Check if we already have a numeric ID in the database. + eventStateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + eventStateKeyNID, err = d.EventStateKeysTable.InsertEventStateKeyNID(ctx, txn, eventStateKey) + if err == sql.ErrNoRows { + // We raced with another insert so run the select again. + eventStateKeyNID, err = d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey) + } + } + return eventStateKeyNID, err +} + +func extractRoomVersionFromCreateEvent(event gomatrixserverlib.Event) ( + gomatrixserverlib.RoomVersion, error, +) { + var err error + var roomVersion gomatrixserverlib.RoomVersion + // Look for m.room.create events. + if event.Type() != gomatrixserverlib.MRoomCreate { + return gomatrixserverlib.RoomVersion(""), nil + } + roomVersion = gomatrixserverlib.RoomVersionV1 + var createContent gomatrixserverlib.CreateContent + // The m.room.create event contains an optional "room_version" key in + // the event content, so we need to unmarshal that first. + if err = json.Unmarshal(event.Content(), &createContent); err != nil { + return gomatrixserverlib.RoomVersion(""), err + } + // A room version was specified in the event content? + if createContent.RoomVersion != nil { + roomVersion = gomatrixserverlib.RoomVersion(*createContent.RoomVersion) + } + return roomVersion, err +} diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index fc661c1da..da0c448dc 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -20,7 +20,10 @@ import ( "database/sql" "strings" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -51,50 +54,46 @@ type eventJSONStatements struct { bulkSelectEventJSONStmt *sql.Stmt } -func (s *eventJSONStatements) prepare(db *sql.DB) (err error) { +func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) { + s := &eventJSONStatements{} s.db = db - _, err = db.Exec(eventJSONSchema) + _, err := db.Exec(eventJSONSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertEventJSONStmt, insertEventJSONSQL}, {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *eventJSONStatements) insertEventJSON( +func (s *eventJSONStatements) InsertEventJSON( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, ) error { - _, err := common.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON) + _, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON) return err } -type eventJSONPair struct { - EventNID types.EventNID - EventJSON []byte -} - -func (s *eventJSONStatements) bulkSelectEventJSON( - ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, -) ([]eventJSONPair, error) { +func (s *eventJSONStatements) BulkSelectEventJSON( + ctx context.Context, eventNIDs []types.EventNID, +) ([]tables.EventJSONPair, error) { iEventNIDs := make([]interface{}, len(eventNIDs)) for k, v := range eventNIDs { iEventNIDs[k] = v } - selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1) + selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) - rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...) + rows, err := s.db.QueryContext(ctx, selectOrig, iEventNIDs...) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectEventJSON: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventJSON: rows.close() failed") // We know that we will only get as many results as event NIDs // because of the unique constraint on event NIDs. // So we can allocate an array of the correct size now. // We might get fewer results than NIDs so we adjust the length of the slice before returning it. - results := make([]eventJSONPair, len(eventNIDs)) + results := make([]tables.EventJSONPair, len(eventNIDs)) i := 0 for ; rows.Next(); i++ { result := &results[i] diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index fa8fc57eb..cbea8428c 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -20,7 +20,10 @@ import ( "database/sql" "strings" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -47,14 +50,14 @@ const selectEventStateKeyNIDSQL = ` // Bulk lookup from string state key to numeric ID for that state key. // Takes an array of strings as the query parameter. -const bulkSelectEventStateKeyNIDSQL = ` +const bulkSelectEventStateKeySQL = ` SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys WHERE event_state_key IN ($1) ` // Bulk lookup from numeric ID to string state key for that state key. // Takes an array of strings as the query parameter. -const bulkSelectEventStateKeySQL = ` +const bulkSelectEventStateKeyNIDSQL = ` SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys WHERE event_state_key_nid IN ($1) ` @@ -67,56 +70,57 @@ type eventStateKeyStatements struct { bulkSelectEventStateKeyStmt *sql.Stmt } -func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) { +func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { + s := &eventStateKeyStatements{} s.db = db - _, err = db.Exec(eventStateKeysSchema) + _, err := db.Exec(eventStateKeysSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, {&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL}, {&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL}, - }.prepare(db) + }.Prepare(db) } -func (s *eventStateKeyStatements) insertEventStateKeyNID( +func (s *eventStateKeyStatements) InsertEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 var err error var res sql.Result - insertStmt := txn.Stmt(s.insertEventStateKeyNIDStmt) + insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt) if res, err = insertStmt.ExecContext(ctx, eventStateKey); err == nil { eventStateKeyNID, err = res.LastInsertId() } return types.EventStateKeyNID(eventStateKeyNID), err } -func (s *eventStateKeyStatements) selectEventStateKeyNID( +func (s *eventStateKeyStatements) SelectEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 - stmt := txn.Stmt(s.selectEventStateKeyNIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectEventStateKeyNIDStmt) err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID) return types.EventStateKeyNID(eventStateKeyNID), err } -func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( - ctx context.Context, txn *sql.Tx, eventStateKeys []string, +func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( + ctx context.Context, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { iEventStateKeys := make([]interface{}, len(eventStateKeys)) for k, v := range eventStateKeys { iEventStateKeys[k] = v } - selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", common.QueryVariadic(len(eventStateKeys)), 1) + selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeys)), 1) - rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeys...) + rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeys...) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKeyNID: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKeyNID: rows.close() failed") result := make(map[string]types.EventStateKeyNID, len(eventStateKeys)) for rows.Next() { var stateKey string @@ -129,20 +133,20 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( return result, nil } -func (s *eventStateKeyStatements) bulkSelectEventStateKey( - ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID, +func (s *eventStateKeyStatements) BulkSelectEventStateKey( + ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs)) for k, v := range eventStateKeyNIDs { iEventStateKeyNIDs[k] = v } - selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", common.QueryVariadic(len(eventStateKeyNIDs)), 1) + selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeyNIDs)), 1) - rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...) + rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKey: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKey: rows.close() failed") result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs)) for rows.Next() { var stateKey string diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index 777f8be79..c9a461f99 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -20,7 +20,10 @@ import ( "database/sql" "strings" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -81,64 +84,64 @@ type eventTypeStatements struct { bulkSelectEventTypeNIDStmt *sql.Stmt } -func (s *eventTypeStatements) prepare(db *sql.DB) (err error) { +func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) { + s := &eventTypeStatements{} s.db = db - _, err = db.Exec(eventTypesSchema) + _, err := db.Exec(eventTypesSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL}, {&s.insertEventTypeNIDResultStmt, insertEventTypeNIDResultSQL}, {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL}, {&s.bulkSelectEventTypeNIDStmt, bulkSelectEventTypeNIDSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *eventTypeStatements) insertEventTypeNID( +func (s *eventTypeStatements) InsertEventTypeNID( ctx context.Context, tx *sql.Tx, eventType string, ) (types.EventTypeNID, error) { var eventTypeNID int64 var err error - insertStmt := common.TxStmt(tx, s.insertEventTypeNIDStmt) - resultStmt := common.TxStmt(tx, s.insertEventTypeNIDResultStmt) + insertStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDStmt) + resultStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDResultStmt) if _, err = insertStmt.ExecContext(ctx, eventType); err == nil { err = resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID) } return types.EventTypeNID(eventTypeNID), err } -func (s *eventTypeStatements) selectEventTypeNID( +func (s *eventTypeStatements) SelectEventTypeNID( ctx context.Context, tx *sql.Tx, eventType string, ) (types.EventTypeNID, error) { var eventTypeNID int64 - selectStmt := common.TxStmt(tx, s.selectEventTypeNIDStmt) + selectStmt := sqlutil.TxStmt(tx, s.selectEventTypeNIDStmt) err := selectStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID) return types.EventTypeNID(eventTypeNID), err } -func (s *eventTypeStatements) bulkSelectEventTypeNID( - ctx context.Context, tx *sql.Tx, eventTypes []string, +func (s *eventTypeStatements) BulkSelectEventTypeNID( + ctx context.Context, eventTypes []string, ) (map[string]types.EventTypeNID, error) { /////////////// iEventTypes := make([]interface{}, len(eventTypes)) for k, v := range eventTypes { iEventTypes[k] = v } - selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", common.QueryVariadic(len(iEventTypes)), 1) + selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventTypes)), 1) selectPrep, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } /////////////// - selectStmt := common.TxStmt(tx, selectPrep) - rows, err := selectStmt.QueryContext(ctx, iEventTypes...) + rows, err := selectPrep.QueryContext(ctx, iEventTypes...) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectEventTypeNID: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventTypeNID: rows.close() failed") result := make(map[string]types.EventTypeNID, len(eventTypes)) for rows.Next() { diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index d881fa91f..d66db4694 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -22,7 +22,10 @@ import ( "fmt" "strings" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -48,11 +51,6 @@ const insertEventSQL = ` ON CONFLICT DO NOTHING; ` -const insertEventResultSQL = ` - SELECT event_nid, state_snapshot_nid FROM roomserver_events - WHERE rowid = last_insert_rowid(); -` - const selectEventSQL = "" + "SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1" @@ -102,7 +100,6 @@ const selectRoomNIDForEventNIDSQL = "" + type eventStatements struct { db *sql.DB insertEventStmt *sql.Stmt - insertEventResultStmt *sql.Stmt selectEventStmt *sql.Stmt bulkSelectStateEventByIDStmt *sql.Stmt bulkSelectStateAtEventByIDStmt *sql.Stmt @@ -117,16 +114,16 @@ type eventStatements struct { selectRoomNIDForEventNIDStmt *sql.Stmt } -func (s *eventStatements) prepare(db *sql.DB) (err error) { +func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) { + s := &eventStatements{} s.db = db - _, err = db.Exec(eventsSchema) + _, err := db.Exec(eventsSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertEventStmt, insertEventSQL}, - {&s.insertEventResultStmt, insertEventResultSQL}, {&s.selectEventStmt, selectEventSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, {&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL}, @@ -139,10 +136,10 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) { {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, {&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *eventStatements) insertEvent( +func (s *eventStatements) InsertEvent( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, @@ -153,53 +150,55 @@ func (s *eventStatements) insertEvent( authEventNIDs []types.EventNID, depth int64, ) (types.EventNID, types.StateSnapshotNID, error) { - var eventNID int64 - var stateNID int64 - var err error - insertStmt := common.TxStmt(txn, s.insertEventStmt) - resultStmt := common.TxStmt(txn, s.insertEventResultStmt) - if _, err = insertStmt.ExecContext( + // attempt to insert: the last_row_id is the event NID + insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) + result, err := insertStmt.ExecContext( ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, - ); err == nil { - err = resultStmt.QueryRowContext(ctx).Scan(&eventNID, &stateNID) + ) + if err != nil { + return 0, 0, err } - return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err + modified, err := result.RowsAffected() + if modified == 0 && err == nil { + return 0, 0, sql.ErrNoRows + } + eventNID, err := result.LastInsertId() + return types.EventNID(eventNID), 0, err } -func (s *eventStatements) selectEvent( +func (s *eventStatements) SelectEvent( ctx context.Context, txn *sql.Tx, eventID string, ) (types.EventNID, types.StateSnapshotNID, error) { var eventNID int64 var stateNID int64 - selectStmt := common.TxStmt(txn, s.selectEventStmt) + selectStmt := sqlutil.TxStmt(txn, s.selectEventStmt) err := selectStmt.QueryRowContext(ctx, eventID).Scan(&eventNID, &stateNID) return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError -func (s *eventStatements) bulkSelectStateEventByID( - ctx context.Context, txn *sql.Tx, eventIDs []string, +func (s *eventStatements) BulkSelectStateEventByID( + ctx context.Context, eventIDs []string, ) ([]types.StateEntry, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) for k, v := range eventIDs { iEventIDs[k] = v } - selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1) - selectPrep, err := txn.Prepare(selectOrig) + selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) + selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } /////////////// - selectStmt := common.TxStmt(txn, selectPrep) rows, err := selectStmt.QueryContext(ctx, iEventIDs...) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed") // We know that we will only get as many results as event IDs // because of the unique constraint on event IDs. // So we can allocate an array of the correct size now. @@ -221,7 +220,7 @@ func (s *eventStatements) bulkSelectStateEventByID( // We don't know which ones were missing because we don't return the string IDs in the query. // However it should be possible debug this by replaying queries or entries from the input kafka logs. // If this turns out to be impossible and we do need the debug information here, it would be better - // to do it as a separate query rather than slowing down/complicating the common case. + // to do it as a separate query rather than slowing down/complicating the internal case. return nil, types.MissingEventError( fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventIDs)), ) @@ -232,27 +231,25 @@ func (s *eventStatements) bulkSelectStateEventByID( // bulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. -func (s *eventStatements) bulkSelectStateAtEventByID( - ctx context.Context, txn *sql.Tx, eventIDs []string, +func (s *eventStatements) BulkSelectStateAtEventByID( + ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) for k, v := range eventIDs { iEventIDs[k] = v } - selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1) - selectPrep, err := txn.Prepare(selectOrig) + selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) + selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } /////////////// - - selectStmt := common.TxStmt(txn, selectPrep) rows, err := selectStmt.QueryContext(ctx, iEventIDs...) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventByID: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventByID: rows.close() failed") results := make([]types.StateAtEvent, len(eventIDs)) i := 0 for ; rows.Next(); i++ { @@ -279,18 +276,17 @@ func (s *eventStatements) bulkSelectStateAtEventByID( return results, err } -func (s *eventStatements) updateEventState( - ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID, +func (s *eventStatements) UpdateEventState( + ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { - updateStmt := common.TxStmt(txn, s.updateEventStateStmt) - _, err := updateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) + _, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) return err } -func (s *eventStatements) selectEventSentToOutput( +func (s *eventStatements) SelectEventSentToOutput( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (sentToOutput bool, err error) { - selectStmt := common.TxStmt(txn, s.selectEventSentToOutputStmt) + selectStmt := sqlutil.TxStmt(txn, s.selectEventSentToOutputStmt) err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput) //err = s.selectEventSentToOutputStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput) if err != nil { @@ -298,22 +294,22 @@ func (s *eventStatements) selectEventSentToOutput( return } -func (s *eventStatements) updateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { - updateStmt := common.TxStmt(txn, s.updateEventSentToOutputStmt) +func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { + updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt) _, err := updateStmt.ExecContext(ctx, int64(eventNID)) //_, err := s.updateEventSentToOutputStmt.ExecContext(ctx, int64(eventNID)) return err } -func (s *eventStatements) selectEventID( +func (s *eventStatements) SelectEventID( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (eventID string, err error) { - selectStmt := common.TxStmt(txn, s.selectEventIDStmt) + selectStmt := sqlutil.TxStmt(txn, s.selectEventIDStmt) err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&eventID) return } -func (s *eventStatements) bulkSelectStateAtEventAndReference( +func (s *eventStatements) BulkSelectStateAtEventAndReference( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]types.StateAtEventAndReference, error) { /////////////// @@ -321,14 +317,14 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference( for k, v := range eventNIDs { iEventNIDs[k] = v } - selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1) + selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) ////////////// rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventAndReference: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventAndReference: rows.close() failed") results := make([]types.StateAtEventAndReference, len(eventNIDs)) i := 0 for ; rows.Next(); i++ { @@ -359,7 +355,7 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference( return results, nil } -func (s *eventStatements) bulkSelectEventReference( +func (s *eventStatements) BulkSelectEventReference( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]gomatrixserverlib.EventReference, error) { /////////////// @@ -367,19 +363,19 @@ func (s *eventStatements) bulkSelectEventReference( for k, v := range eventNIDs { iEventNIDs[k] = v } - selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1) + selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) selectPrep, err := txn.Prepare(selectOrig) if err != nil { return nil, err } /////////////// - selectStmt := common.TxStmt(txn, selectPrep) + selectStmt := sqlutil.TxStmt(txn, selectPrep) rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectEventReference: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventReference: rows.close() failed") results := make([]gomatrixserverlib.EventReference, len(eventNIDs)) i := 0 for ; rows.Next(); i++ { @@ -395,25 +391,24 @@ func (s *eventStatements) bulkSelectEventReference( } // bulkSelectEventID returns a map from numeric event ID to string event ID. -func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { +func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { /////////////// iEventNIDs := make([]interface{}, len(eventNIDs)) for k, v := range eventNIDs { iEventNIDs[k] = v } - selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1) - selectPrep, err := txn.Prepare(selectOrig) + selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) + selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } /////////////// - selectStmt := common.TxStmt(txn, selectPrep) rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectEventID: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventID: rows.close() failed") results := make(map[types.EventNID]string, len(eventNIDs)) i := 0 for ; rows.Next(); i++ { @@ -432,25 +427,23 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, ev // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { +func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) for k, v := range eventIDs { iEventIDs[k] = v } - selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1) - selectPrep, err := txn.Prepare(selectOrig) + selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) + selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } /////////////// - - selectStmt := common.TxStmt(txn, selectPrep) rows, err := selectStmt.QueryContext(ctx, iEventIDs...) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed") results := make(map[string]types.EventNID, len(eventIDs)) for rows.Next() { var eventID string @@ -463,13 +456,13 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e return results, nil } -func (s *eventStatements) selectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) { +func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) { var result int64 iEventIDs := make([]interface{}, len(eventNIDs)) for i, v := range eventNIDs { iEventIDs[i] = v } - sqlStr := strings.Replace(selectMaxEventDepthSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1) + sqlStr := strings.Replace(selectMaxEventDepthSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) err := txn.QueryRowContext(ctx, sqlStr, iEventIDs...).Scan(&result) if err != nil { return 0, err @@ -477,11 +470,10 @@ func (s *eventStatements) selectMaxEventDepth(ctx context.Context, txn *sql.Tx, return result, nil } -func (s *eventStatements) selectRoomNIDForEventNID( - ctx context.Context, txn *sql.Tx, eventNID types.EventNID, +func (s *eventStatements) SelectRoomNIDForEventNID( + ctx context.Context, eventNID types.EventNID, ) (roomNID types.RoomNID, err error) { - selectStmt := common.TxStmt(txn, s.selectRoomNIDForEventNIDStmt) - err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID) + err = s.selectRoomNIDForEventNIDStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID) return } diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go index 0ab3e6f36..21745d1b0 100644 --- a/roomserver/storage/sqlite3/invite_table.go +++ b/roomserver/storage/sqlite3/invite_table.go @@ -19,7 +19,10 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -66,28 +69,28 @@ type inviteStatements struct { selectInvitesAboutToRetireStmt *sql.Stmt } -func (s *inviteStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(inviteSchema) +func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) { + s := &inviteStatements{} + _, err := db.Exec(inviteSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertInviteEventStmt, insertInviteEventSQL}, {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, {&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, {&s.selectInvitesAboutToRetireStmt, selectInvitesAboutToRetireSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *inviteStatements) insertInviteEvent( +func (s *inviteStatements) InsertInviteEvent( ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte, ) (bool, error) { - stmt := common.TxStmt(txn, s.insertInviteEventStmt) - defer stmt.Close() // nolint: errcheck + stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt) result, err := stmt.ExecContext( ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, ) @@ -101,12 +104,12 @@ func (s *inviteStatements) insertInviteEvent( return count != 0, nil } -func (s *inviteStatements) updateInviteRetired( +func (s *inviteStatements) UpdateInviteRetired( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventIDs []string, err error) { // gather all the event IDs we will retire - stmt := txn.Stmt(s.selectInvitesAboutToRetireStmt) + stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt) rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) if err != nil { return nil, err @@ -121,13 +124,13 @@ func (s *inviteStatements) updateInviteRetired( } // now retire the invites - stmt = txn.Stmt(s.updateInviteRetiredStmt) + stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt) _, err = stmt.ExecContext(ctx, roomNID, targetUserNID) return } // selectInviteActiveForUserInRoom returns a list of sender state key NIDs -func (s *inviteStatements) selectInviteActiveForUserInRoom( +func (s *inviteStatements) SelectInviteActiveForUserInRoom( ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, ) ([]types.EventStateKeyNID, error) { @@ -137,7 +140,7 @@ func (s *inviteStatements) selectInviteActiveForUserInRoom( if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed") var result []types.EventStateKeyNID for rows.Next() { var senderUserNID int64 diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 7ae28e4b8..6f0d763e7 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -19,18 +19,13 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) -type membershipState int64 - -const ( - membershipStateLeaveOrBan membershipState = 1 - membershipStateInvite membershipState = 2 - membershipStateJoin membershipState = 3 -) - const membershipSchema = ` CREATE TABLE IF NOT EXISTS roomserver_membership ( room_nid INTEGER NOT NULL, @@ -38,6 +33,7 @@ const membershipSchema = ` sender_nid INTEGER NOT NULL DEFAULT 0, membership_nid INTEGER NOT NULL DEFAULT 1, event_nid INTEGER NOT NULL DEFAULT 0, + target_local BOOLEAN NOT NULL DEFAULT false, UNIQUE (room_nid, target_nid) ); ` @@ -45,8 +41,8 @@ const membershipSchema = ` // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE const insertMembershipSQL = "" + - "INSERT INTO roomserver_membership (room_nid, target_nid)" + - " VALUES ($1, $2)" + + "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" + + " VALUES ($1, $2, $3)" + " ON CONFLICT DO NOTHING" const selectMembershipFromRoomAndTargetSQL = "" + @@ -57,10 +53,20 @@ const selectMembershipsFromRoomAndMembershipSQL = "" + "SELECT event_nid FROM roomserver_membership" + " WHERE room_nid = $1 AND membership_nid = $2" +const selectLocalMembershipsFromRoomAndMembershipSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1 AND membership_nid = $2" + + " AND target_local = true" + const selectMembershipsFromRoomSQL = "" + "SELECT event_nid FROM roomserver_membership" + " WHERE room_nid = $1" +const selectLocalMembershipsFromRoomSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1" + + " AND target_local = true" + const selectMembershipForUpdateSQL = "" + "SELECT membership_nid FROM roomserver_membership" + " WHERE room_nid = $1 AND target_nid = $2" @@ -70,71 +76,81 @@ const updateMembershipSQL = "" + " WHERE room_nid = $4 AND target_nid = $5" type membershipStatements struct { - insertMembershipStmt *sql.Stmt - selectMembershipForUpdateStmt *sql.Stmt - selectMembershipFromRoomAndTargetStmt *sql.Stmt - selectMembershipsFromRoomAndMembershipStmt *sql.Stmt - selectMembershipsFromRoomStmt *sql.Stmt - updateMembershipStmt *sql.Stmt + insertMembershipStmt *sql.Stmt + selectMembershipForUpdateStmt *sql.Stmt + selectMembershipFromRoomAndTargetStmt *sql.Stmt + selectMembershipsFromRoomAndMembershipStmt *sql.Stmt + selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt + selectMembershipsFromRoomStmt *sql.Stmt + selectLocalMembershipsFromRoomStmt *sql.Stmt + updateMembershipStmt *sql.Stmt } -func (s *membershipStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(membershipSchema) +func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { + s := &membershipStatements{} + _, err := db.Exec(membershipSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertMembershipStmt, insertMembershipSQL}, {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, {&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL}, + {&s.selectLocalMembershipsFromRoomAndMembershipStmt, selectLocalMembershipsFromRoomAndMembershipSQL}, {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, + {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *membershipStatements) insertMembership( +func (s *membershipStatements) InsertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + localTarget bool, ) error { - stmt := common.TxStmt(txn, s.insertMembershipStmt) - _, err := stmt.ExecContext(ctx, roomNID, targetUserNID) + stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt) + _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget) return err } -func (s *membershipStatements) selectMembershipForUpdate( +func (s *membershipStatements) SelectMembershipForUpdate( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, -) (membership membershipState, err error) { - stmt := common.TxStmt(txn, s.selectMembershipForUpdateStmt) +) (membership tables.MembershipState, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt) err = stmt.QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership) return } -func (s *membershipStatements) selectMembershipFromRoomAndTarget( - ctx context.Context, txn *sql.Tx, +func (s *membershipStatements) SelectMembershipFromRoomAndTarget( + ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, -) (eventNID types.EventNID, membership membershipState, err error) { - selectStmt := common.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt) - err = selectStmt.QueryRowContext( +) (eventNID types.EventNID, membership tables.MembershipState, err error) { + err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership, &eventNID) return } -func (s *membershipStatements) selectMembershipsFromRoom( - ctx context.Context, txn *sql.Tx, - roomNID types.RoomNID, +func (s *membershipStatements) SelectMembershipsFromRoom( + ctx context.Context, + roomNID types.RoomNID, localOnly bool, ) (eventNIDs []types.EventNID, err error) { - selectStmt := common.TxStmt(txn, s.selectMembershipsFromRoomStmt) + var selectStmt *sql.Stmt + if localOnly { + selectStmt = s.selectLocalMembershipsFromRoomStmt + } else { + selectStmt = s.selectMembershipsFromRoomStmt + } rows, err := selectStmt.QueryContext(ctx, roomNID) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "selectMembershipsFromRoom: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectMembershipsFromRoom: rows.close() failed") for rows.Next() { var eNID types.EventNID @@ -145,16 +161,22 @@ func (s *membershipStatements) selectMembershipsFromRoom( } return } -func (s *membershipStatements) selectMembershipsFromRoomAndMembership( - ctx context.Context, txn *sql.Tx, - roomNID types.RoomNID, membership membershipState, + +func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( + ctx context.Context, + roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, ) (eventNIDs []types.EventNID, err error) { - stmt := common.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt) + var stmt *sql.Stmt + if localOnly { + stmt = s.selectLocalMembershipsFromRoomAndMembershipStmt + } else { + stmt = s.selectMembershipsFromRoomAndMembershipStmt + } rows, err := stmt.QueryContext(ctx, roomNID, membership) if err != nil { return } - defer common.CloseAndLogIfError(ctx, rows, "selectMembershipsFromRoomAndMembership: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectMembershipsFromRoomAndMembership: rows.close() failed") for rows.Next() { var eNID types.EventNID @@ -166,13 +188,13 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership( return } -func (s *membershipStatements) updateMembership( +func (s *membershipStatements) UpdateMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, - senderUserNID types.EventStateKeyNID, membership membershipState, + senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, ) error { - stmt := common.TxStmt(txn, s.updateMembershipStmt) + stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt) _, err := stmt.ExecContext( ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID, ) diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go index 9ed64a38e..549aecfb7 100644 --- a/roomserver/storage/sqlite3/previous_events_table.go +++ b/roomserver/storage/sqlite3/previous_events_table.go @@ -19,7 +19,9 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -55,26 +57,27 @@ type previousEventStatements struct { selectPreviousEventExistsStmt *sql.Stmt } -func (s *previousEventStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(previousEventSchema) +func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { + s := &previousEventStatements{} + _, err := db.Exec(previousEventSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertPreviousEventStmt, insertPreviousEventSQL}, {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *previousEventStatements) insertPreviousEvent( +func (s *previousEventStatements) InsertPreviousEvent( ctx context.Context, txn *sql.Tx, previousEventID string, previousEventReferenceSHA256 []byte, eventNID types.EventNID, ) error { - stmt := common.TxStmt(txn, s.insertPreviousEventStmt) + stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) _, err := stmt.ExecContext( ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), ) @@ -83,10 +86,10 @@ func (s *previousEventStatements) insertPreviousEvent( // Check if the event reference exists // Returns sql.ErrNoRows if the event reference doesn't exist. -func (s *previousEventStatements) selectPreviousEventExists( +func (s *previousEventStatements) SelectPreviousEventExists( ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, ) error { var ok int64 - stmt := common.TxStmt(txn, s.selectPreviousEventExistsStmt) + stmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt) return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok) } diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index d29833918..da5f9161a 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -19,7 +19,9 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" ) const roomAliasesSchema = ` @@ -60,50 +62,48 @@ type roomAliasesStatements struct { deleteRoomAliasStmt *sql.Stmt } -func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(roomAliasesSchema) +func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { + s := &roomAliasesStatements{} + _, err := db.Exec(roomAliasesSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertRoomAliasStmt, insertRoomAliasSQL}, {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, {&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL}, {&s.selectCreatorIDFromAliasStmt, selectCreatorIDFromAliasSQL}, {&s.deleteRoomAliasStmt, deleteRoomAliasSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *roomAliasesStatements) insertRoomAlias( - ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string, +func (s *roomAliasesStatements) InsertRoomAlias( + ctx context.Context, alias string, roomID string, creatorUserID string, ) (err error) { - insertStmt := common.TxStmt(txn, s.insertRoomAliasStmt) - _, err = insertStmt.ExecContext(ctx, alias, roomID, creatorUserID) + _, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID) return } -func (s *roomAliasesStatements) selectRoomIDFromAlias( - ctx context.Context, txn *sql.Tx, alias string, +func (s *roomAliasesStatements) SelectRoomIDFromAlias( + ctx context.Context, alias string, ) (roomID string, err error) { - selectStmt := common.TxStmt(txn, s.selectRoomIDFromAliasStmt) - err = selectStmt.QueryRowContext(ctx, alias).Scan(&roomID) + err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) if err == sql.ErrNoRows { return "", nil } return } -func (s *roomAliasesStatements) selectAliasesFromRoomID( - ctx context.Context, txn *sql.Tx, roomID string, +func (s *roomAliasesStatements) SelectAliasesFromRoomID( + ctx context.Context, roomID string, ) (aliases []string, err error) { aliases = []string{} - selectStmt := common.TxStmt(txn, s.selectAliasesFromRoomIDStmt) - rows, err := selectStmt.QueryContext(ctx, roomID) + rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) if err != nil { return } - defer common.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed") for rows.Next() { var alias string @@ -117,21 +117,19 @@ func (s *roomAliasesStatements) selectAliasesFromRoomID( return } -func (s *roomAliasesStatements) selectCreatorIDFromAlias( - ctx context.Context, txn *sql.Tx, alias string, +func (s *roomAliasesStatements) SelectCreatorIDFromAlias( + ctx context.Context, alias string, ) (creatorID string, err error) { - selectStmt := common.TxStmt(txn, s.selectCreatorIDFromAliasStmt) - err = selectStmt.QueryRowContext(ctx, alias).Scan(&creatorID) + err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) if err == sql.ErrNoRows { return "", nil } return } -func (s *roomAliasesStatements) deleteRoomAlias( - ctx context.Context, txn *sql.Tx, alias string, +func (s *roomAliasesStatements) DeleteRoomAlias( + ctx context.Context, alias string, ) (err error) { - deleteStmt := common.TxStmt(txn, s.deleteRoomAliasStmt) - _, err = deleteStmt.ExecContext(ctx, alias) + _, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias) return } diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 427eeeb70..ab695c5d2 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -21,7 +21,9 @@ import ( "encoding/json" "errors" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -71,12 +73,13 @@ type roomStatements struct { selectRoomVersionForRoomNIDStmt *sql.Stmt } -func (s *roomStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(roomsSchema) +func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { + s := &roomStatements{} + _, err := db.Exec(roomsSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, @@ -84,38 +87,38 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, {&s.selectRoomVersionForRoomIDStmt, selectRoomVersionForRoomIDSQL}, {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *roomStatements) insertRoomNID( +func (s *roomStatements) InsertRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, ) (types.RoomNID, error) { var err error - insertStmt := common.TxStmt(txn, s.insertRoomNIDStmt) + insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt) if _, err = insertStmt.ExecContext(ctx, roomID, roomVersion); err == nil { - return s.selectRoomNID(ctx, txn, roomID) + return s.SelectRoomNID(ctx, txn, roomID) } else { return types.RoomNID(0), err } } -func (s *roomStatements) selectRoomNID( +func (s *roomStatements) SelectRoomNID( ctx context.Context, txn *sql.Tx, roomID string, ) (types.RoomNID, error) { var roomNID int64 - stmt := common.TxStmt(txn, s.selectRoomNIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectRoomNIDStmt) err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) return types.RoomNID(roomNID), err } -func (s *roomStatements) selectLatestEventNIDs( +func (s *roomStatements) SelectLatestEventNIDs( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.StateSnapshotNID, error) { var eventNIDs []types.EventNID var nidsJSON string var stateSnapshotNID int64 - stmt := common.TxStmt(txn, s.selectLatestEventNIDsStmt) + stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsStmt) err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nidsJSON, &stateSnapshotNID) if err != nil { return nil, 0, err @@ -126,14 +129,14 @@ func (s *roomStatements) selectLatestEventNIDs( return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil } -func (s *roomStatements) selectLatestEventsNIDsForUpdate( +func (s *roomStatements) SelectLatestEventsNIDsForUpdate( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) { var eventNIDs []types.EventNID var nidsJSON string var lastEventSentNID int64 var stateSnapshotNID int64 - stmt := common.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt) + stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt) err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nidsJSON, &lastEventSentNID, &stateSnapshotNID) if err != nil { return nil, 0, 0, err @@ -144,7 +147,7 @@ func (s *roomStatements) selectLatestEventsNIDsForUpdate( return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil } -func (s *roomStatements) updateLatestEventNIDs( +func (s *roomStatements) UpdateLatestEventNIDs( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, @@ -152,7 +155,7 @@ func (s *roomStatements) updateLatestEventNIDs( lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID, ) error { - stmt := common.TxStmt(txn, s.updateLatestEventNIDsStmt) + stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt) _, err := stmt.ExecContext( ctx, eventNIDsAsArray(eventNIDs), @@ -163,11 +166,11 @@ func (s *roomStatements) updateLatestEventNIDs( return err } -func (s *roomStatements) selectRoomVersionForRoomID( +func (s *roomStatements) SelectRoomVersionForRoomID( ctx context.Context, txn *sql.Tx, roomID string, ) (gomatrixserverlib.RoomVersion, error) { var roomVersion gomatrixserverlib.RoomVersion - stmt := common.TxStmt(txn, s.selectRoomVersionForRoomIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectRoomVersionForRoomIDStmt) err := stmt.QueryRowContext(ctx, roomID).Scan(&roomVersion) if err == sql.ErrNoRows { return roomVersion, errors.New("room not found") @@ -175,12 +178,11 @@ func (s *roomStatements) selectRoomVersionForRoomID( return roomVersion, err } -func (s *roomStatements) selectRoomVersionForRoomNID( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +func (s *roomStatements) SelectRoomVersionForRoomNID( + ctx context.Context, roomNID types.RoomNID, ) (gomatrixserverlib.RoomVersion, error) { var roomVersion gomatrixserverlib.RoomVersion - stmt := common.TxStmt(txn, s.selectRoomVersionForRoomNIDStmt) - err := stmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion) + err := s.selectRoomVersionForRoomNIDStmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion) if err == sql.ErrNoRows { return roomVersion, errors.New("room not found") } diff --git a/roomserver/storage/sqlite3/sql.go b/roomserver/storage/sqlite3/sql.go deleted file mode 100644 index 0d49432b8..000000000 --- a/roomserver/storage/sqlite3/sql.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "database/sql" -) - -type statements struct { - eventTypeStatements - eventStateKeyStatements - roomStatements - eventStatements - eventJSONStatements - stateSnapshotStatements - stateBlockStatements - previousEventStatements - roomAliasesStatements - inviteStatements - membershipStatements - transactionStatements -} - -func (s *statements) prepare(db *sql.DB) error { - var err error - - for _, prepare := range []func(db *sql.DB) error{ - s.eventTypeStatements.prepare, - s.eventStateKeyStatements.prepare, - s.roomStatements.prepare, - s.eventStatements.prepare, - s.eventJSONStatements.prepare, - s.stateSnapshotStatements.prepare, - s.stateBlockStatements.prepare, - s.previousEventStatements.prepare, - s.roomAliasesStatements.prepare, - s.inviteStatements.prepare, - s.membershipStatements.prepare, - s.transactionStatements.prepare, - } { - if err = prepare(db); err != nil { - return err - } - } - - return nil -} diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index cc7c75733..c058c783a 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -22,7 +22,10 @@ import ( "sort" "strings" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" ) @@ -77,22 +80,23 @@ type stateBlockStatements struct { bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt } -func (s *stateBlockStatements) prepare(db *sql.DB) (err error) { +func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) { + s := &stateBlockStatements{} s.db = db - _, err = db.Exec(stateDataSchema) + _, err := db.Exec(stateDataSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertStateDataStmt, insertStateDataSQL}, {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *stateBlockStatements) bulkInsertStateData( +func (s *stateBlockStatements) BulkInsertStateData( ctx context.Context, txn *sql.Tx, entries []types.StateEntry, ) (types.StateBlockNID, error) { @@ -120,24 +124,23 @@ func (s *stateBlockStatements) bulkInsertStateData( return stateBlockNID, nil } -func (s *stateBlockStatements) bulkSelectStateBlockEntries( - ctx context.Context, txn *sql.Tx, stateBlockNIDs []types.StateBlockNID, +func (s *stateBlockStatements) BulkSelectStateBlockEntries( + ctx context.Context, stateBlockNIDs []types.StateBlockNID, ) ([]types.StateEntryList, error) { nids := make([]interface{}, len(stateBlockNIDs)) for k, v := range stateBlockNIDs { nids[k] = v } - selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", common.QueryVariadic(len(nids)), 1) - selectPrep, err := s.db.Prepare(selectOrig) + selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) + selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } - selectStmt := common.TxStmt(txn, selectPrep) rows, err := selectStmt.QueryContext(ctx, nids...) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed") results := make([]types.StateEntryList, len(stateBlockNIDs)) // current is a pointer to the StateEntryList to append the state entries to. @@ -174,8 +177,8 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries( return results, nil } -func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( - ctx context.Context, txn *sql.Tx, // nolint: unparam +func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries( + ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntryList, error) { @@ -184,9 +187,9 @@ func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( sort.Sort(tuples) eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() - sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", common.QueryVariadic(len(stateBlockNIDs)), 1) - sqlStatement = strings.Replace(sqlStatement, "($2)", common.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1) - sqlStatement = strings.Replace(sqlStatement, "($3)", common.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1) + sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1) + sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1) + sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1) var params []interface{} for _, val := range stateBlockNIDs { @@ -207,7 +210,7 @@ func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed") var results []types.StateEntryList var current types.StateEntryList diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index f367a779b..d077b6171 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -22,7 +22,10 @@ import ( "fmt" "strings" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -51,20 +54,21 @@ type stateSnapshotStatements struct { bulkSelectStateBlockNIDsStmt *sql.Stmt } -func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) { +func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { + s := &stateSnapshotStatements{} s.db = db - _, err = db.Exec(stateSnapshotSchema) + _, err := db.Exec(stateSnapshotSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *stateSnapshotStatements) insertState( +func (s *stateSnapshotStatements) InsertState( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, ) (stateNID types.StateSnapshotNID, err error) { stateBlockNIDsJSON, err := json.Marshal(stateBlockNIDs) @@ -82,15 +86,15 @@ func (s *stateSnapshotStatements) insertState( return } -func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs( - ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID, +func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( + ctx context.Context, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { nids := make([]interface{}, len(stateNIDs)) for k, v := range stateNIDs { nids[k] = v } - selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", common.QueryVariadic(len(nids)), 1) - selectStmt, err := txn.Prepare(selectOrig) + selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) + selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } @@ -99,7 +103,7 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs( if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed") results := make([]types.StateBlockNIDList, len(stateNIDs)) i := 0 for ; rows.Next(); i++ { diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 444a8fdd5..8e9352192 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -18,14 +18,10 @@ package sqlite3 import ( "context" "database/sql" - "encoding/json" - "errors" - "net/url" "github.com/matrix-org/dendrite/internal/sqlutil" - - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" _ "github.com/mattn/go-sqlite3" @@ -33,26 +29,28 @@ import ( // A Database is used to store room events and stream offsets. type Database struct { - statements statements - db *sql.DB + shared.Database + events tables.Events + eventJSON tables.EventJSON + eventTypes tables.EventTypes + eventStateKeys tables.EventStateKeys + rooms tables.Rooms + transactions tables.Transactions + prevEvents tables.PreviousEvents + invites tables.Invites + membership tables.Membership + db *sql.DB } -// Open a postgres database. +// Open a sqlite database. +// nolint: gocyclo func Open(dataSourceName string) (*Database, error) { var d Database - uri, err := url.Parse(dataSourceName) + cs, err := sqlutil.ParseFileURI(dataSourceName) if err != nil { return nil, err } - var cs string - if uri.Opaque != "" { // file:filename.db - cs = uri.Opaque - } else if uri.Path != "" { // file:///path/to/filename.db - cs = uri.Path - } else { - return nil, errors.New("no filename or path in connect string") - } - if d.db, err = sqlutil.Open(common.SQLiteDriverName(), cs); err != nil { + if d.db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil); err != nil { return nil, err } //d.db.Exec("PRAGMA journal_mode=WAL;") @@ -63,886 +61,94 @@ func Open(dataSourceName string) (*Database, error) { // acquire the global mutex and never unlock it because it is waiting for a connection // which it will never obtain. d.db.SetMaxOpenConns(20) - if err = d.statements.prepare(d.db); err != nil { + + d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db) + if err != nil { return nil, err } + d.eventTypes, err = NewSqliteEventTypesTable(d.db) + if err != nil { + return nil, err + } + d.eventJSON, err = NewSqliteEventJSONTable(d.db) + if err != nil { + return nil, err + } + d.events, err = NewSqliteEventsTable(d.db) + if err != nil { + return nil, err + } + d.rooms, err = NewSqliteRoomsTable(d.db) + if err != nil { + return nil, err + } + d.transactions, err = NewSqliteTransactionsTable(d.db) + if err != nil { + return nil, err + } + stateBlock, err := NewSqliteStateBlockTable(d.db) + if err != nil { + return nil, err + } + stateSnapshot, err := NewSqliteStateSnapshotTable(d.db) + if err != nil { + return nil, err + } + d.prevEvents, err = NewSqlitePrevEventsTable(d.db) + if err != nil { + return nil, err + } + roomAliases, err := NewSqliteRoomAliasesTable(d.db) + if err != nil { + return nil, err + } + d.invites, err = NewSqliteInvitesTable(d.db) + if err != nil { + return nil, err + } + d.membership, err = NewSqliteMembershipTable(d.db) + if err != nil { + return nil, err + } + d.Database = shared.Database{ + DB: d.db, + EventsTable: d.events, + EventTypesTable: d.eventTypes, + EventStateKeysTable: d.eventStateKeys, + EventJSONTable: d.eventJSON, + RoomsTable: d.rooms, + TransactionsTable: d.transactions, + StateBlockTable: stateBlock, + StateSnapshotTable: stateSnapshot, + PrevEventsTable: d.prevEvents, + RoomAliasesTable: roomAliases, + InvitesTable: d.invites, + MembershipTable: d.membership, + } return &d, nil } -// StoreEvent implements input.EventDatabase -func (d *Database) StoreEvent( - ctx context.Context, event gomatrixserverlib.Event, - txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, -) (types.RoomNID, types.StateAtEvent, error) { - var ( - roomNID types.RoomNID - eventTypeNID types.EventTypeNID - eventStateKeyNID types.EventStateKeyNID - eventNID types.EventNID - stateNID types.StateSnapshotNID - err error - ) - - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - if txnAndSessionID != nil { - if err = d.statements.insertTransaction( - ctx, txn, txnAndSessionID.TransactionID, - txnAndSessionID.SessionID, event.Sender(), event.EventID(), - ); err != nil { - return err - } - } - - // TODO: Here we should aim to have two different code paths for new rooms - // vs existing ones. - - // Get the default room version. If the client doesn't supply a room_version - // then we will use our configured default to create the room. - // https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom - // Note that the below logic depends on the m.room.create event being the - // first event that is persisted to the database when creating or joining a - // room. - var roomVersion gomatrixserverlib.RoomVersion - if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil { - return err - } - - if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion); err != nil { - return err - } - - if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil { - return err - } - - eventStateKey := event.StateKey() - // Assigned a numeric ID for the state_key if there is one present. - // Otherwise set the numeric ID for the state_key to 0. - if eventStateKey != nil { - if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil { - return err - } - } - - if eventNID, stateNID, err = d.statements.insertEvent( - ctx, - txn, - roomNID, - eventTypeNID, - eventStateKeyNID, - event.EventID(), - event.EventReference().EventSHA256, - authEventNIDs, - event.Depth(), - ); err != nil { - if err == sql.ErrNoRows { - // We've already inserted the event so select the numeric event ID - eventNID, stateNID, err = d.statements.selectEvent(ctx, txn, event.EventID()) - } - if err != nil { - return err - } - } - - if err = d.statements.insertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { - return err - } - - return nil - }) - if err != nil { - return 0, types.StateAtEvent{}, err - } - - return roomNID, types.StateAtEvent{ - BeforeStateSnapshotNID: stateNID, - StateEntry: types.StateEntry{ - StateKeyTuple: types.StateKeyTuple{ - EventTypeNID: eventTypeNID, - EventStateKeyNID: eventStateKeyNID, - }, - EventNID: eventNID, - }, - }, nil -} - -func extractRoomVersionFromCreateEvent(event gomatrixserverlib.Event) ( - gomatrixserverlib.RoomVersion, error, -) { - var err error - var roomVersion gomatrixserverlib.RoomVersion - // Look for m.room.create events. - if event.Type() != gomatrixserverlib.MRoomCreate { - return gomatrixserverlib.RoomVersion(""), nil - } - roomVersion = gomatrixserverlib.RoomVersionV1 - var createContent gomatrixserverlib.CreateContent - // The m.room.create event contains an optional "room_version" key in - // the event content, so we need to unmarshal that first. - if err = json.Unmarshal(event.Content(), &createContent); err != nil { - return gomatrixserverlib.RoomVersion(""), err - } - // A room version was specified in the event content? - if createContent.RoomVersion != nil { - roomVersion = gomatrixserverlib.RoomVersion(*createContent.RoomVersion) - } - return roomVersion, err -} - -func (d *Database) assignRoomNID( - ctx context.Context, txn *sql.Tx, - roomID string, roomVersion gomatrixserverlib.RoomVersion, -) (roomNID types.RoomNID, err error) { - // Check if we already have a numeric ID in the database. - roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) - if err == sql.ErrNoRows { - // We don't have a numeric ID so insert one into the database. - roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID, roomVersion) - if err == nil { - // Now get the numeric ID back out of the database - roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) - } - } - return -} - -func (d *Database) assignEventTypeNID( - ctx context.Context, txn *sql.Tx, eventType string, -) (eventTypeNID types.EventTypeNID, err error) { - // Check if we already have a numeric ID in the database. - eventTypeNID, err = d.statements.selectEventTypeNID(ctx, txn, eventType) - if err == sql.ErrNoRows { - // We don't have a numeric ID so insert one into the database. - eventTypeNID, err = d.statements.insertEventTypeNID(ctx, txn, eventType) - if err == sql.ErrNoRows { - // We raced with another insert so run the select again. - eventTypeNID, err = d.statements.selectEventTypeNID(ctx, txn, eventType) - } - } - return -} - -func (d *Database) assignStateKeyNID( - ctx context.Context, txn *sql.Tx, eventStateKey string, -) (eventStateKeyNID types.EventStateKeyNID, err error) { - // Check if we already have a numeric ID in the database. - eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) - if err == sql.ErrNoRows { - // We don't have a numeric ID so insert one into the database. - eventStateKeyNID, err = d.statements.insertEventStateKeyNID(ctx, txn, eventStateKey) - if err == sql.ErrNoRows { - // We raced with another insert so run the select again. - eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) - } - } - return -} - -// StateEntriesForEventIDs implements input.EventDatabase -func (d *Database) StateEntriesForEventIDs( - ctx context.Context, eventIDs []string, -) (se []types.StateEntry, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - se, err = d.statements.bulkSelectStateEventByID(ctx, txn, eventIDs) - return err - }) - return -} - -// EventTypeNIDs implements state.RoomStateDatabase -func (d *Database) EventTypeNIDs( - ctx context.Context, eventTypes []string, -) (etnids map[string]types.EventTypeNID, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - etnids, err = d.statements.bulkSelectEventTypeNID(ctx, txn, eventTypes) - return err - }) - return -} - -// EventStateKeyNIDs implements state.RoomStateDatabase -func (d *Database) EventStateKeyNIDs( - ctx context.Context, eventStateKeys []string, -) (esknids map[string]types.EventStateKeyNID, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - esknids, err = d.statements.bulkSelectEventStateKeyNID(ctx, txn, eventStateKeys) - return err - }) - return -} - -// EventStateKeys implements query.RoomserverQueryAPIDatabase -func (d *Database) EventStateKeys( - ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, -) (out map[types.EventStateKeyNID]string, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - out, err = d.statements.bulkSelectEventStateKey(ctx, txn, eventStateKeyNIDs) - return err - }) - return -} - -// EventNIDs implements query.RoomserverQueryAPIDatabase -func (d *Database) EventNIDs( - ctx context.Context, eventIDs []string, -) (out map[string]types.EventNID, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - out, err = d.statements.bulkSelectEventNID(ctx, txn, eventIDs) - return err - }) - return -} - -// Events implements input.EventDatabase -func (d *Database) Events( - ctx context.Context, eventNIDs []types.EventNID, -) ([]types.Event, error) { - var eventJSONs []eventJSONPair - var err error - var results []types.Event - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - eventJSONs, err = d.statements.bulkSelectEventJSON(ctx, txn, eventNIDs) - if err != nil || len(eventJSONs) == 0 { - return nil - } - results = make([]types.Event, len(eventJSONs)) - for i, eventJSON := range eventJSONs { - var roomNID types.RoomNID - var roomVersion gomatrixserverlib.RoomVersion - result := &results[i] - result.EventNID = eventJSON.EventNID - roomNID, err = d.statements.selectRoomNIDForEventNID(ctx, txn, eventJSON.EventNID) - if err != nil { - return err - } - roomVersion, err = d.statements.selectRoomVersionForRoomNID(ctx, txn, roomNID) - if err != nil { - return err - } - result.Event, err = gomatrixserverlib.NewEventFromTrustedJSON( - eventJSON.EventJSON, false, roomVersion, - ) - if err != nil { - return nil - } - } - return nil - }) - if err != nil { - return []types.Event{}, err - } - return results, nil -} - -// AddState implements input.EventDatabase -func (d *Database) AddState( - ctx context.Context, - roomNID types.RoomNID, - stateBlockNIDs []types.StateBlockNID, - state []types.StateEntry, -) (stateNID types.StateSnapshotNID, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - if len(state) > 0 { - var stateBlockNID types.StateBlockNID - stateBlockNID, err = d.statements.bulkInsertStateData(ctx, txn, state) - if err != nil { - return err - } - stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) - } - stateNID, err = d.statements.insertState(ctx, txn, roomNID, stateBlockNIDs) - return err - }) - if err != nil { - return 0, err - } - return -} - -// SetState implements input.EventDatabase -func (d *Database) SetState( - ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, -) error { - e := common.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.statements.updateEventState(ctx, txn, eventNID, stateNID) - }) - return e -} - -// StateAtEventIDs implements input.EventDatabase -func (d *Database) StateAtEventIDs( - ctx context.Context, eventIDs []string, -) (se []types.StateAtEvent, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - se, err = d.statements.bulkSelectStateAtEventByID(ctx, txn, eventIDs) - return err - }) - return -} - -// StateBlockNIDs implements state.RoomStateDatabase -func (d *Database) StateBlockNIDs( - ctx context.Context, stateNIDs []types.StateSnapshotNID, -) (sl []types.StateBlockNIDList, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - sl, err = d.statements.bulkSelectStateBlockNIDs(ctx, txn, stateNIDs) - return err - }) - return -} - -// StateEntries implements state.RoomStateDatabase -func (d *Database) StateEntries( - ctx context.Context, stateBlockNIDs []types.StateBlockNID, -) (sel []types.StateEntryList, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - sel, err = d.statements.bulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs) - return err - }) - return -} - -// SnapshotNIDFromEventID implements state.RoomStateDatabase -func (d *Database) SnapshotNIDFromEventID( - ctx context.Context, eventID string, -) (stateNID types.StateSnapshotNID, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - _, stateNID, err = d.statements.selectEvent(ctx, txn, eventID) - return err - }) - return -} - -// EventIDs implements input.RoomEventDatabase -func (d *Database) EventIDs( - ctx context.Context, eventNIDs []types.EventNID, -) (out map[types.EventNID]string, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - out, err = d.statements.bulkSelectEventID(ctx, txn, eventNIDs) - return err - }) - return -} - -// GetLatestEventsForUpdate implements input.EventDatabase func (d *Database) GetLatestEventsForUpdate( ctx context.Context, roomNID types.RoomNID, ) (types.RoomRecentEventsUpdater, error) { - txn, err := d.db.Begin() - if err != nil { - return nil, err - } - eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := - d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID) - if err != nil { - txn.Rollback() // nolint: errcheck - return nil, err - } - stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) - if err != nil { - txn.Rollback() // nolint: errcheck - return nil, err - } - var lastEventIDSent string - if lastEventNIDSent != 0 { - lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent) - if err != nil { - txn.Rollback() // nolint: errcheck - return nil, err - } - } - - // FIXME: we probably want to support long-lived txns in sqlite somehow, but we don't because we get - // 'database is locked' errors caused by multiple write txns (one being the long-lived txn created here) - // so for now let's not use a long-lived txn at all, and just commit it here and set the txn to nil so - // we fail fast if someone tries to use the underlying txn object. - err = txn.Commit() - if err != nil { - return nil, err - } - return &roomRecentEventsUpdater{ - transaction{ctx, nil}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, - }, nil + // TODO: Do not use transactions. We should be holding open this transaction but we cannot have + // multiple write transactions on sqlite. The code will perform additional + // write transactions independent of this one which will consistently cause + // 'database is locked' errors. As sqlite doesn't support multi-process on the + // same DB anyway, and we only execute updates sequentially, the only worries + // are for rolling back when things go wrong. (atomicity) + return shared.NewRoomRecentEventsUpdater(&d.Database, ctx, roomNID, false) } -// GetTransactionEventID implements input.EventDatabase -func (d *Database) GetTransactionEventID( - ctx context.Context, transactionID string, - sessionID int64, userID string, -) (string, error) { - eventID, err := d.statements.selectTransactionEventID(ctx, nil, transactionID, sessionID, userID) - if err == sql.ErrNoRows { - return "", nil - } - return eventID, err -} - -type roomRecentEventsUpdater struct { - transaction - d *Database - roomNID types.RoomNID - latestEvents []types.StateAtEventAndReference - lastEventIDSent string - currentStateSnapshotNID types.StateSnapshotNID -} - -// RoomVersion implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) { - version, _ = u.d.GetRoomVersionForRoomNID(u.ctx, u.roomNID) - return -} - -// LatestEvents implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) LatestEvents() []types.StateAtEventAndReference { - return u.latestEvents -} - -// LastEventIDSent implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) LastEventIDSent() string { - return u.lastEventIDSent -} - -// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { - return u.currentStateSnapshotNID -} - -// StorePreviousEvents implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { - err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error { - for _, ref := range previousEventReferences { - if err := u.d.statements.insertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { - return err - } - } - return nil - }) - return err -} - -// IsReferenced implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (res bool, err error) { - err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { - err := u.d.statements.selectPreviousEventExists(u.ctx, txn, eventReference.EventID, eventReference.EventSHA256) - if err == nil { - res = true - err = nil - } - if err == sql.ErrNoRows { - res = false - err = nil - } - return err - }) - return -} - -// SetLatestEvents implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) SetLatestEvents( - roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, - currentStateSnapshotNID types.StateSnapshotNID, -) error { - err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error { - eventNIDs := make([]types.EventNID, len(latest)) - for i := range latest { - eventNIDs[i] = latest[i].EventNID - } - return u.d.statements.updateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) - }) - return err -} - -// HasEventBeenSent implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (res bool, err error) { - err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { - res, err = u.d.statements.selectEventSentToOutput(u.ctx, txn, eventNID) - return err - }) - return -} - -// MarkEventAsSent implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { - err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error { - return u.d.statements.updateEventSentToOutput(u.ctx, txn, eventNID) - }) - return err -} - -func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (mu types.MembershipUpdater, err error) { - err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { - mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID) - return err - }) - return -} - -// RoomNID implements query.RoomserverQueryAPIDB -func (d *Database) RoomNID(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) - if err == sql.ErrNoRows { - roomNID = 0 - err = nil - } - return err - }) - return -} - -// LatestEventIDs implements query.RoomserverQueryAPIDatabase -func (d *Database) LatestEventIDs( - ctx context.Context, roomNID types.RoomNID, -) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - var eventNIDs []types.EventNID - eventNIDs, currentStateSnapshotNID, err = d.statements.selectLatestEventNIDs(ctx, txn, roomNID) - if err != nil { - return err - } - references, err = d.statements.bulkSelectEventReference(ctx, txn, eventNIDs) - if err != nil { - return err - } - depth, err = d.statements.selectMaxEventDepth(ctx, txn, eventNIDs) - if err != nil { - return err - } - return nil - }) - return -} - -// GetInvitesForUser implements query.RoomserverQueryAPIDatabase -func (d *Database) GetInvitesForUser( - ctx context.Context, - roomNID types.RoomNID, - targetUserNID types.EventStateKeyNID, -) (senderUserIDs []types.EventStateKeyNID, err error) { - return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) -} - -// SetRoomAlias implements alias.RoomserverAliasAPIDB -func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { - return d.statements.insertRoomAlias(ctx, nil, alias, roomID, creatorUserID) -} - -// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB -func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { - return d.statements.selectRoomIDFromAlias(ctx, nil, alias) -} - -// GetAliasesForRoomID implements alias.RoomserverAliasAPIDB -func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { - return d.statements.selectAliasesFromRoomID(ctx, nil, roomID) -} - -// GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB -func (d *Database) GetCreatorIDForAlias( - ctx context.Context, alias string, -) (string, error) { - return d.statements.selectCreatorIDFromAlias(ctx, nil, alias) -} - -// RemoveRoomAlias implements alias.RoomserverAliasAPIDB -func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { - return d.statements.deleteRoomAlias(ctx, nil, alias) -} - -// StateEntriesForTuples implements state.RoomStateDatabase -func (d *Database) StateEntriesForTuples( - ctx context.Context, - stateBlockNIDs []types.StateBlockNID, - stateKeyTuples []types.StateKeyTuple, -) ([]types.StateEntryList, error) { - return d.statements.bulkSelectFilteredStateBlockEntries( - ctx, nil, stateBlockNIDs, stateKeyTuples, - ) -} - -// MembershipUpdater implements input.RoomEventDatabase func (d *Database) MembershipUpdater( ctx context.Context, roomID, targetUserID string, - roomVersion gomatrixserverlib.RoomVersion, + targetLocal bool, roomVersion gomatrixserverlib.RoomVersion, ) (updater types.MembershipUpdater, err error) { - var txn *sql.Tx - txn, err = d.db.Begin() - if err != nil { - return nil, err - } - succeeded := false - defer func() { - if !succeeded { - txn.Rollback() // nolint: errcheck - } else { - // TODO: We should be holding open this transaction but we cannot have - // multiple write transactions on sqlite. The code will perform additional - // write transactions independent of this one which will consistently cause - // 'database is locked' errors. For now, we'll break up the transaction and - // hope we don't race too catastrophically. Long term, we should be able to - // thread in txn objects where appropriate (either at the interface level or - // bring matrix business logic into the storage layer). - txerr := txn.Commit() - if err == nil && txerr != nil { - err = txerr - } - } - }() - - roomNID, err := d.assignRoomNID(ctx, txn, roomID, roomVersion) - if err != nil { - return nil, err - } - - targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID) - if err != nil { - return nil, err - } - - updater, err = d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID) - if err != nil { - return nil, err - } - - succeeded = true - return updater, nil -} - -type membershipUpdater struct { - transaction - d *Database - roomNID types.RoomNID - targetUserNID types.EventStateKeyNID - membership membershipState -} - -func (d *Database) membershipUpdaterTxn( - ctx context.Context, - txn *sql.Tx, - roomNID types.RoomNID, - targetUserNID types.EventStateKeyNID, -) (types.MembershipUpdater, error) { - - if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil { - return nil, err - } - - membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) - if err != nil { - return nil, err - } - - return &membershipUpdater{ - // purposefully set the txn to nil so if we try to use it we panic and fail fast - transaction{ctx, nil}, d, roomNID, targetUserNID, membership, - }, nil -} - -// IsInvite implements types.MembershipUpdater -func (u *membershipUpdater) IsInvite() bool { - return u.membership == membershipStateInvite -} - -// IsJoin implements types.MembershipUpdater -func (u *membershipUpdater) IsJoin() bool { - return u.membership == membershipStateJoin -} - -// IsLeave implements types.MembershipUpdater -func (u *membershipUpdater) IsLeave() bool { - return u.membership == membershipStateLeaveOrBan -} - -// SetToInvite implements types.MembershipUpdater -func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (inserted bool, err error) { - err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, event.Sender()) - if err != nil { - return err - } - inserted, err = u.d.statements.insertInviteEvent( - u.ctx, txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), - ) - if err != nil { - return err - } - if u.membership != membershipStateInvite { - if err = u.d.statements.updateMembership( - u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0, - ); err != nil { - return err - } - } - return nil - }) - return -} - -// SetToJoin implements types.MembershipUpdater -func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) (inviteEventIDs []string, err error) { - err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, senderUserID) - if err != nil { - return err - } - - // If this is a join event update, there is no invite to update - if !isUpdate { - inviteEventIDs, err = u.d.statements.updateInviteRetired( - u.ctx, txn, u.roomNID, u.targetUserNID, - ) - if err != nil { - return err - } - } - - // Look up the NID of the new join event - nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) - if err != nil { - return err - } - - if u.membership != membershipStateJoin || isUpdate { - if err = u.d.statements.updateMembership( - u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, - membershipStateJoin, nIDs[eventID], - ); err != nil { - return err - } - } - return nil - }) - - return -} - -// SetToLeave implements types.MembershipUpdater -func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) (inviteEventIDs []string, err error) { - err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, senderUserID) - if err != nil { - return err - } - inviteEventIDs, err = u.d.statements.updateInviteRetired( - u.ctx, txn, u.roomNID, u.targetUserNID, - ) - if err != nil { - return err - } - - // Look up the NID of the new leave event - nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) - if err != nil { - return err - } - - if u.membership != membershipStateLeaveOrBan { - if err = u.d.statements.updateMembership( - u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, - membershipStateLeaveOrBan, nIDs[eventID], - ); err != nil { - return err - } - } - return nil - }) - return -} - -// GetMembership implements query.RoomserverQueryAPIDB -func (d *Database) GetMembership( - ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, -) (membershipEventNID types.EventNID, stillInRoom bool, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - requestSenderUserNID, err := d.assignStateKeyNID(ctx, txn, requestSenderUserID) - if err != nil { - return err - } - - membershipEventNID, _, err = - d.statements.selectMembershipFromRoomAndTarget( - ctx, txn, roomNID, requestSenderUserNID, - ) - if err == sql.ErrNoRows { - // The user has never been a member of that room - return nil - } - if err != nil { - return err - } - stillInRoom = true - return nil - }) - - return -} - -// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB -func (d *Database) GetMembershipEventNIDsForRoom( - ctx context.Context, roomNID types.RoomNID, joinOnly bool, -) (eventNIDs []types.EventNID, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - if joinOnly { - eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership( - ctx, txn, roomNID, membershipStateJoin, - ) - return nil - } - - eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID) - return nil - }) - return -} - -// EventsFromIDs implements query.RoomserverQueryAPIEventDB -func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - nidMap, err := d.EventNIDs(ctx, eventIDs) - if err != nil { - return nil, err - } - - var nids []types.EventNID - for _, nid := range nidMap { - nids = append(nids, nid) - } - - return d.Events(ctx, nids) -} - -func (d *Database) GetRoomVersionForRoom( - ctx context.Context, roomID string, -) (gomatrixserverlib.RoomVersion, error) { - return d.statements.selectRoomVersionForRoomID( - ctx, nil, roomID, - ) -} - -func (d *Database) GetRoomVersionForRoomNID( - ctx context.Context, roomNID types.RoomNID, -) (gomatrixserverlib.RoomVersion, error) { - return d.statements.selectRoomVersionForRoomNID( - ctx, nil, roomNID, - ) -} - -type transaction struct { - ctx context.Context - txn *sql.Tx -} - -// Commit implements types.Transaction -func (t *transaction) Commit() error { - if t.txn == nil { - return nil - } - return t.txn.Commit() -} - -// Rollback implements types.Transaction -func (t *transaction) Rollback() error { - if t.txn == nil { - return nil - } - return t.txn.Rollback() + // TODO: Do not use transactions. We should be holding open this transaction but we cannot have + // multiple write transactions on sqlite. The code will perform additional + // write transactions independent of this one which will consistently cause + // 'database is locked' errors. As sqlite doesn't support multi-process on the + // same DB anyway, and we only execute updates sequentially, the only worries + // are for rolling back when things go wrong. (atomicity) + return shared.NewMembershipUpdater(ctx, &d.Database, roomID, targetUserID, targetLocal, roomVersion, false) } diff --git a/roomserver/storage/sqlite3/transactions_table.go b/roomserver/storage/sqlite3/transactions_table.go index 7740e5f07..1e8de1ca8 100644 --- a/roomserver/storage/sqlite3/transactions_table.go +++ b/roomserver/storage/sqlite3/transactions_table.go @@ -19,7 +19,9 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" ) const transactionsSchema = ` @@ -46,40 +48,40 @@ type transactionStatements struct { selectTransactionEventIDStmt *sql.Stmt } -func (s *transactionStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(transactionsSchema) +func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) { + s := &transactionStatements{} + _, err := db.Exec(transactionsSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertTransactionStmt, insertTransactionSQL}, {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *transactionStatements) insertTransaction( +func (s *transactionStatements) InsertTransaction( ctx context.Context, txn *sql.Tx, transactionID string, sessionID int64, userID string, eventID string, ) (err error) { - stmt := common.TxStmt(txn, s.insertTransactionStmt) + stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt) _, err = stmt.ExecContext( ctx, transactionID, sessionID, userID, eventID, ) return } -func (s *transactionStatements) selectTransactionEventID( - ctx context.Context, txn *sql.Tx, +func (s *transactionStatements) SelectTransactionEventID( + ctx context.Context, transactionID string, sessionID int64, userID string, ) (eventID string, err error) { - stmt := common.TxStmt(txn, s.selectTransactionEventIDStmt) - err = stmt.QueryRowContext( + err = s.selectTransactionEventIDStmt.QueryRowContext( ctx, transactionID, sessionID, userID, ).Scan(&eventID) return diff --git a/roomserver/storage/storage.go b/roomserver/storage/storage.go index 7b9109aa0..d7367e4c7 100644 --- a/roomserver/storage/storage.go +++ b/roomserver/storage/storage.go @@ -19,22 +19,23 @@ package storage import ( "net/url" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/postgres" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" ) -// NewPublicRoomsServerDatabase opens a database connection. -func Open(dataSourceName string) (Database, error) { +// Open opens a database connection. +func Open(dataSourceName string, dbProperties sqlutil.DbProperties) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { - return postgres.Open(dataSourceName) + return postgres.Open(dataSourceName, dbProperties) } switch uri.Scheme { case "postgres": - return postgres.Open(dataSourceName) + return postgres.Open(dataSourceName, dbProperties) case "file": return sqlite3.Open(dataSourceName) default: - return postgres.Open(dataSourceName) + return postgres.Open(dataSourceName, dbProperties) } } diff --git a/roomserver/storage/storage_wasm.go b/roomserver/storage/storage_wasm.go index d7fc352e8..78405b20e 100644 --- a/roomserver/storage/storage_wasm.go +++ b/roomserver/storage/storage_wasm.go @@ -18,11 +18,15 @@ import ( "fmt" "net/url" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" ) // NewPublicRoomsServerDatabase opens a database connection. -func Open(dataSourceName string) (Database, error) { +func Open( + dataSourceName string, + dbProperties sqlutil.DbProperties, // nolint:unparam +) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { return nil, fmt.Errorf("Cannot use postgres implementation") diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go new file mode 100644 index 000000000..11cff8a8b --- /dev/null +++ b/roomserver/storage/tables/interface.go @@ -0,0 +1,122 @@ +package tables + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +type EventJSONPair struct { + EventNID types.EventNID + EventJSON []byte +} + +type EventJSON interface { + InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error + BulkSelectEventJSON(ctx context.Context, eventNIDs []types.EventNID) ([]EventJSONPair, error) +} + +type EventTypes interface { + InsertEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error) + SelectEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error) + BulkSelectEventTypeNID(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) +} + +type EventStateKeys interface { + InsertEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) + SelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) + BulkSelectEventStateKeyNID(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) + BulkSelectEventStateKey(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) +} + +type Events interface { + InsertEvent(c context.Context, txn *sql.Tx, i types.RoomNID, j types.EventTypeNID, k types.EventStateKeyNID, eventID string, referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64) (types.EventNID, types.StateSnapshotNID, error) + SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error) + // bulkSelectStateEventByID lookups a list of state events by event ID. + // If any of the requested events are missing from the database it returns a types.MissingEventError + BulkSelectStateEventByID(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) + // BulkSelectStateAtEventByID lookups the state at a list of events by event ID. + // If any of the requested events are missing from the database it returns a types.MissingEventError. + // If we do not have the state for any of the requested events it returns a types.MissingEventError. + BulkSelectStateAtEventByID(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) + UpdateEventState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error + SelectEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) + UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error + SelectEventID(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (eventID string, err error) + BulkSelectStateAtEventAndReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) + BulkSelectEventReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error) + // BulkSelectEventID returns a map from numeric event ID to string event ID. + BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) + // BulkSelectEventNIDs returns a map from string event ID to numeric event ID. + // If an event ID is not in the database then it is omitted from the map. + BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) + SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) + SelectRoomNIDForEventNID(ctx context.Context, eventNID types.EventNID) (roomNID types.RoomNID, err error) +} + +type Rooms interface { + InsertRoomNID(ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion) (types.RoomNID, error) + SelectRoomNID(ctx context.Context, txn *sql.Tx, roomID string) (types.RoomNID, error) + SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) + SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) + UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error + SelectRoomVersionForRoomID(ctx context.Context, txn *sql.Tx, roomID string) (gomatrixserverlib.RoomVersion, error) + SelectRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) +} + +type Transactions interface { + InsertTransaction(ctx context.Context, txn *sql.Tx, transactionID string, sessionID int64, userID string, eventID string) error + SelectTransactionEventID(ctx context.Context, transactionID string, sessionID int64, userID string) (eventID string, err error) +} + +type StateSnapshot interface { + InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID) (stateNID types.StateSnapshotNID, err error) + BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) +} + +type StateBlock interface { + BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries []types.StateEntry) (types.StateBlockNID, error) + BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) + BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) +} + +type RoomAliases interface { + InsertRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) (err error) + SelectRoomIDFromAlias(ctx context.Context, alias string) (roomID string, err error) + SelectAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error) + SelectCreatorIDFromAlias(ctx context.Context, alias string) (creatorID string, err error) + DeleteRoomAlias(ctx context.Context, alias string) (err error) +} + +type PreviousEvents interface { + InsertPreviousEvent(ctx context.Context, txn *sql.Tx, previousEventID string, previousEventReferenceSHA256 []byte, eventNID types.EventNID) error + // Check if the event reference exists + // Returns sql.ErrNoRows if the event reference doesn't exist. + SelectPreviousEventExists(ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error +} + +type Invites interface { + InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error) + UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error) + // SelectInviteActiveForUserInRoom returns a list of sender state key NIDs + SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, error) +} + +type MembershipState int64 + +const ( + MembershipStateLeaveOrBan MembershipState = 1 + MembershipStateInvite MembershipState = 2 + MembershipStateJoin MembershipState = 3 +) + +type Membership interface { + InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error + SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error) + SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, error) + SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) + SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) + UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error +} diff --git a/roomserver/types/types.go b/roomserver/types/types.go index dfc112cfd..241e1e15d 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -16,7 +16,7 @@ package types import ( - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" ) @@ -75,6 +75,10 @@ func (a StateEntry) LessThan(b StateEntry) bool { // StateAtEvent is the state before and after a matrix event. type StateAtEvent struct { + // Should this state overwrite the latest events and memberships of the room? + // This might be necessary when rejoining a federated room after a period of + // absence, as our state and latest events will be out of date. + Overwrite bool // The state before the event. BeforeStateSnapshotNID StateSnapshotNID // The state entry for the event itself, allows us to calculate the state after the event. @@ -168,9 +172,9 @@ type RoomRecentEventsUpdater interface { MarkEventAsSent(eventNID EventNID) error // Build a membership updater for the target user in this room. // It will share the same transaction as this updater. - MembershipUpdater(targetUserNID EventStateKeyNID) (MembershipUpdater, error) + MembershipUpdater(targetUserNID EventStateKeyNID, isTargetLocalUser bool) (MembershipUpdater, error) // Implements Transaction so it can be committed or rolledback - common.Transaction + sqlutil.Transaction } // A MembershipUpdater is used to update the membership of a user in a room. @@ -195,7 +199,7 @@ type MembershipUpdater interface { // Returns a list of invite event IDs that this state change retired. SetToLeave(senderUserID string, eventID string) (inviteEventIDs []string, err error) // Implements Transaction so it can be committed or rolledback. - common.Transaction + sqlutil.Transaction } // A MissingEventError is an error that happened because the roomserver was diff --git a/roomserver/version/version.go b/roomserver/version/version.go index f2a67e74d..f2b15ec39 100644 --- a/roomserver/version/version.go +++ b/roomserver/version/version.go @@ -20,87 +20,45 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) -// RoomVersionDescription contains information about a room version, -// namely whether it is marked as supported or stable in this server -// version. -// A version is supported if the server has some support for rooms -// that are this version. A version is marked as stable or unstable -// in order to hint whether the version should be used to clients -// calling the /capabilities endpoint. -// https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-capabilities -type RoomVersionDescription struct { - Supported bool - Stable bool -} - -var roomVersions = map[gomatrixserverlib.RoomVersion]RoomVersionDescription{ - gomatrixserverlib.RoomVersionV1: RoomVersionDescription{ - Supported: true, - Stable: true, - }, - gomatrixserverlib.RoomVersionV2: RoomVersionDescription{ - Supported: true, - Stable: true, - }, - gomatrixserverlib.RoomVersionV3: RoomVersionDescription{ - Supported: true, - Stable: true, - }, - gomatrixserverlib.RoomVersionV4: RoomVersionDescription{ - Supported: true, - Stable: true, - }, - gomatrixserverlib.RoomVersionV5: RoomVersionDescription{ - Supported: false, - Stable: false, - }, -} - // DefaultRoomVersion contains the room version that will, by // default, be used to create new rooms on this server. func DefaultRoomVersion() gomatrixserverlib.RoomVersion { - return gomatrixserverlib.RoomVersionV4 + return gomatrixserverlib.RoomVersionV5 } // RoomVersions returns a map of all known room versions to this // server. -func RoomVersions() map[gomatrixserverlib.RoomVersion]RoomVersionDescription { - return roomVersions +func RoomVersions() map[gomatrixserverlib.RoomVersion]gomatrixserverlib.RoomVersionDescription { + return gomatrixserverlib.RoomVersions() } // SupportedRoomVersions returns a map of descriptions for room // versions that are supported by this homeserver. -func SupportedRoomVersions() map[gomatrixserverlib.RoomVersion]RoomVersionDescription { - versions := make(map[gomatrixserverlib.RoomVersion]RoomVersionDescription) - for id, version := range RoomVersions() { - if version.Supported { - versions[id] = version - } - } - return versions +func SupportedRoomVersions() map[gomatrixserverlib.RoomVersion]gomatrixserverlib.RoomVersionDescription { + return gomatrixserverlib.SupportedRoomVersions() } // RoomVersion returns information about a specific room version. // An UnknownVersionError is returned if the version is not known // to the server. -func RoomVersion(version gomatrixserverlib.RoomVersion) (RoomVersionDescription, error) { - if version, ok := roomVersions[version]; ok { +func RoomVersion(version gomatrixserverlib.RoomVersion) (gomatrixserverlib.RoomVersionDescription, error) { + if version, ok := gomatrixserverlib.RoomVersions()[version]; ok { return version, nil } - return RoomVersionDescription{}, UnknownVersionError{version} + return gomatrixserverlib.RoomVersionDescription{}, UnknownVersionError{version} } // SupportedRoomVersion returns information about a specific room // version. An UnknownVersionError is returned if the version is not // known to the server, or an UnsupportedVersionError is returned if // the version is known but specifically marked as unsupported. -func SupportedRoomVersion(version gomatrixserverlib.RoomVersion) (RoomVersionDescription, error) { +func SupportedRoomVersion(version gomatrixserverlib.RoomVersion) (gomatrixserverlib.RoomVersionDescription, error) { result, err := RoomVersion(version) if err != nil { - return RoomVersionDescription{}, err + return gomatrixserverlib.RoomVersionDescription{}, err } if !result.Supported { - return RoomVersionDescription{}, UnsupportedVersionError{version} + return gomatrixserverlib.RoomVersionDescription{}, UnsupportedVersionError{version} } return result, nil } diff --git a/serverkeyapi/api/api.go b/serverkeyapi/api/api.go new file mode 100644 index 000000000..7af626345 --- /dev/null +++ b/serverkeyapi/api/api.go @@ -0,0 +1,40 @@ +package api + +import ( + "context" + + "github.com/matrix-org/gomatrixserverlib" +) + +type ServerKeyInternalAPI interface { + gomatrixserverlib.KeyDatabase + + KeyRing() *gomatrixserverlib.KeyRing + + InputPublicKeys( + ctx context.Context, + request *InputPublicKeysRequest, + response *InputPublicKeysResponse, + ) error + + QueryPublicKeys( + ctx context.Context, + request *QueryPublicKeysRequest, + response *QueryPublicKeysResponse, + ) error +} + +type QueryPublicKeysRequest struct { + Requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp `json:"requests"` +} + +type QueryPublicKeysResponse struct { + Results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult `json:"results"` +} + +type InputPublicKeysRequest struct { + Keys map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult `json:"keys"` +} + +type InputPublicKeysResponse struct { +} diff --git a/serverkeyapi/internal/api.go b/serverkeyapi/internal/api.go new file mode 100644 index 000000000..02028c60e --- /dev/null +++ b/serverkeyapi/internal/api.go @@ -0,0 +1,250 @@ +package internal + +import ( + "context" + "crypto/ed25519" + "fmt" + "time" + + "github.com/matrix-org/dendrite/serverkeyapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +type ServerKeyAPI struct { + api.ServerKeyInternalAPI + + ServerName gomatrixserverlib.ServerName + ServerPublicKey ed25519.PublicKey + ServerKeyID gomatrixserverlib.KeyID + ServerKeyValidity time.Duration + + OurKeyRing gomatrixserverlib.KeyRing + FedClient *gomatrixserverlib.FederationClient +} + +func (s *ServerKeyAPI) KeyRing() *gomatrixserverlib.KeyRing { + // Return a keyring that forces requests to be proxied through the + // below functions. That way we can enforce things like validity + // and keeping the cache up-to-date. + return &gomatrixserverlib.KeyRing{ + KeyDatabase: s, + KeyFetchers: []gomatrixserverlib.KeyFetcher{}, + } +} + +func (s *ServerKeyAPI) StoreKeys( + _ context.Context, + results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, +) error { + // Run in a background context - we don't want to stop this work just + // because the caller gives up waiting. + ctx := context.Background() + + // Store any keys that we were given in our database. + return s.OurKeyRing.KeyDatabase.StoreKeys(ctx, results) +} + +func (s *ServerKeyAPI) FetchKeys( + _ context.Context, + requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, +) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { + // Run in a background context - we don't want to stop this work just + // because the caller gives up waiting. + ctx := context.Background() + now := gomatrixserverlib.AsTimestamp(time.Now()) + results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} + origRequests := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{} + for k, v := range requests { + origRequests[k] = v + } + + // First, check if any of these key checks are for our own keys. If + // they are then we will satisfy them directly. + s.handleLocalKeys(ctx, requests, results) + + // Then consult our local database and see if we have the requested + // keys. These might come from a cache, depending on the database + // implementation used. + if err := s.handleDatabaseKeys(ctx, now, requests, results); err != nil { + return nil, err + } + + // For any key requests that we still have outstanding, next try to + // fetch them directly. We'll go through each of the key fetchers to + // ask for the remaining keys + for _, fetcher := range s.OurKeyRing.KeyFetchers { + // If there are no more keys to look up then stop. + if len(requests) == 0 { + break + } + + // Ask the fetcher to look up our keys. + if err := s.handleFetcherKeys(ctx, now, fetcher, requests, results); err != nil { + logrus.WithError(err).WithFields(logrus.Fields{ + "fetcher_name": fetcher.FetcherName(), + }).Errorf("Failed to retrieve %d key(s)", len(requests)) + continue + } + } + + // Check that we've actually satisfied all of the key requests that we + // were given. We should report an error if we didn't. + for req := range origRequests { + if _, ok := results[req]; !ok { + // The results don't contain anything for this specific request, so + // we've failed to satisfy it from local keys, database keys or from + // all of the fetchers. Report an error. + logrus.Warnf("Failed to retrieve key %q for server %q", req.KeyID, req.ServerName) + return results, fmt.Errorf( + "server key API failed to satisfy key request for server %q key ID %q", + req.ServerName, req.KeyID, + ) + } + } + + // Return the keys. + return results, nil +} + +func (s *ServerKeyAPI) FetcherName() string { + return fmt.Sprintf("ServerKeyAPI (wrapping %q)", s.OurKeyRing.KeyDatabase.FetcherName()) +} + +// handleLocalKeys handles cases where the key request contains +// a request for our own server keys. +func (s *ServerKeyAPI) handleLocalKeys( + _ context.Context, + requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, +) { + for req := range requests { + if req.ServerName == s.ServerName { + // We found a key request that is supposed to be for our own + // keys. Remove it from the request list so we don't hit the + // database or the fetchers for it. + delete(requests, req) + + // Insert our own key into the response. + results[req] = gomatrixserverlib.PublicKeyLookupResult{ + VerifyKey: gomatrixserverlib.VerifyKey{ + Key: gomatrixserverlib.Base64Bytes(s.ServerPublicKey), + }, + ExpiredTS: gomatrixserverlib.PublicKeyNotExpired, + ValidUntilTS: gomatrixserverlib.AsTimestamp(time.Now().Add(s.ServerKeyValidity)), + } + } + } +} + +// handleDatabaseKeys handles cases where the key requests can be +// satisfied from our local database/cache. +func (s *ServerKeyAPI) handleDatabaseKeys( + ctx context.Context, + now gomatrixserverlib.Timestamp, + requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, +) error { + // Ask the database/cache for the keys. + dbResults, err := s.OurKeyRing.KeyDatabase.FetchKeys(ctx, requests) + if err != nil { + return err + } + + // We successfully got some keys. Add them to the results. + for req, res := range dbResults { + // The key we've retrieved from the database/cache might + // have passed its validity period, but right now, it's + // the best thing we've got, and it might be sufficient to + // verify a past event. + results[req] = res + + // If the key is valid right now then we can also remove it + // from the request list as we don't need to fetch it again + // in that case. If the key isn't valid right now, then by + // leaving it in the 'requests' map, we'll try to update the + // key using the fetchers in handleFetcherKeys. + if res.WasValidAt(now, true) { + delete(requests, req) + } + } + return nil +} + +// handleFetcherKeys handles cases where a fetcher can satisfy +// the remaining requests. +func (s *ServerKeyAPI) handleFetcherKeys( + ctx context.Context, + now gomatrixserverlib.Timestamp, + fetcher gomatrixserverlib.KeyFetcher, + requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, +) error { + logrus.WithFields(logrus.Fields{ + "fetcher_name": fetcher.FetcherName(), + }).Infof("Fetching %d key(s)", len(requests)) + + // Create a context that limits our requests to 30 seconds. + fetcherCtx, fetcherCancel := context.WithTimeout(ctx, time.Second*30) + defer fetcherCancel() + + // Try to fetch the keys. + fetcherResults, err := fetcher.FetchKeys(fetcherCtx, requests) + if err != nil { + return err + } + + // Build a map of the results that we want to commit to the + // database. We do this in a separate map because otherwise we + // might end up trying to rewrite database entries. + storeResults := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} + + // Now let's look at the results that we got from this fetcher. + for req, res := range fetcherResults { + if prev, ok := results[req]; ok { + // We've already got a previous entry for this request + // so let's see if the newly retrieved one contains a more + // up-to-date validity period. + if res.ValidUntilTS > prev.ValidUntilTS { + // This key is newer than the one we had so let's store + // it in the database. + if req.ServerName != s.ServerName { + storeResults[req] = res + } + } + } else { + // We didn't already have a previous entry for this request + // so store it in the database anyway for now. + if req.ServerName != s.ServerName { + storeResults[req] = res + } + } + + // Update the results map with this new result. If nothing + // else, we can try verifying against this key. + results[req] = res + + // If the key is valid right now then we can remove it from the + // request list as we won't need to re-fetch it. + if res.WasValidAt(now, true) { + delete(requests, req) + } + } + + // Store the keys from our store map. + if err = s.OurKeyRing.KeyDatabase.StoreKeys(ctx, storeResults); err != nil { + logrus.WithError(err).WithFields(logrus.Fields{ + "fetcher_name": fetcher.FetcherName(), + "database_name": s.OurKeyRing.KeyDatabase.FetcherName(), + }).Errorf("Failed to store keys in the database") + return fmt.Errorf("server key API failed to store retrieved keys: %w", err) + } + + if len(storeResults) > 0 { + logrus.WithFields(logrus.Fields{ + "fetcher_name": fetcher.FetcherName(), + }).Infof("Updated %d of %d key(s) in database", len(storeResults), len(results)) + } + + return nil +} diff --git a/serverkeyapi/inthttp/client.go b/serverkeyapi/inthttp/client.go new file mode 100644 index 000000000..39ab8c6c5 --- /dev/null +++ b/serverkeyapi/inthttp/client.go @@ -0,0 +1,132 @@ +package inthttp + +import ( + "context" + "errors" + "net/http" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/serverkeyapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/opentracing/opentracing-go" +) + +// HTTP paths for the internal HTTP APIs +const ( + ServerKeyInputPublicKeyPath = "/serverkeyapi/inputPublicKey" + ServerKeyQueryPublicKeyPath = "/serverkeyapi/queryPublicKey" +) + +// NewServerKeyClient creates a ServerKeyInternalAPI implemented by talking to a HTTP POST API. +// If httpClient is nil an error is returned +func NewServerKeyClient( + serverKeyAPIURL string, + httpClient *http.Client, + cache caching.ServerKeyCache, +) (api.ServerKeyInternalAPI, error) { + if httpClient == nil { + return nil, errors.New("NewRoomserverInternalAPIHTTP: httpClient is ") + } + return &httpServerKeyInternalAPI{ + serverKeyAPIURL: serverKeyAPIURL, + httpClient: httpClient, + cache: cache, + }, nil +} + +type httpServerKeyInternalAPI struct { + serverKeyAPIURL string + httpClient *http.Client + cache caching.ServerKeyCache +} + +func (s *httpServerKeyInternalAPI) KeyRing() *gomatrixserverlib.KeyRing { + // This is a bit of a cheat - we tell gomatrixserverlib that this API is + // both the key database and the key fetcher. While this does have the + // rather unfortunate effect of preventing gomatrixserverlib from handling + // key fetchers directly, we can at least reimplement this behaviour on + // the other end of the API. + return &gomatrixserverlib.KeyRing{ + KeyDatabase: s, + KeyFetchers: []gomatrixserverlib.KeyFetcher{}, + } +} + +func (s *httpServerKeyInternalAPI) FetcherName() string { + return "httpServerKeyInternalAPI" +} + +func (s *httpServerKeyInternalAPI) StoreKeys( + _ context.Context, + results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, +) error { + // Run in a background context - we don't want to stop this work just + // because the caller gives up waiting. + ctx := context.Background() + request := api.InputPublicKeysRequest{ + Keys: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult), + } + response := api.InputPublicKeysResponse{} + for req, res := range results { + request.Keys[req] = res + s.cache.StoreServerKey(req, res) + } + return s.InputPublicKeys(ctx, &request, &response) +} + +func (s *httpServerKeyInternalAPI) FetchKeys( + _ context.Context, + requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, +) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { + // Run in a background context - we don't want to stop this work just + // because the caller gives up waiting. + ctx := context.Background() + result := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) + request := api.QueryPublicKeysRequest{ + Requests: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp), + } + response := api.QueryPublicKeysResponse{ + Results: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult), + } + for req, ts := range requests { + if res, ok := s.cache.GetServerKey(req, ts); ok { + result[req] = res + continue + } + request.Requests[req] = ts + } + err := s.QueryPublicKeys(ctx, &request, &response) + if err != nil { + return nil, err + } + for req, res := range response.Results { + result[req] = res + s.cache.StoreServerKey(req, res) + } + return result, nil +} + +func (h *httpServerKeyInternalAPI) InputPublicKeys( + ctx context.Context, + request *api.InputPublicKeysRequest, + response *api.InputPublicKeysResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "InputPublicKey") + defer span.Finish() + + apiURL := h.serverKeyAPIURL + ServerKeyInputPublicKeyPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpServerKeyInternalAPI) QueryPublicKeys( + ctx context.Context, + request *api.QueryPublicKeysRequest, + response *api.QueryPublicKeysResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPublicKey") + defer span.Finish() + + apiURL := h.serverKeyAPIURL + ServerKeyQueryPublicKeyPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/serverkeyapi/inthttp/server.go b/serverkeyapi/inthttp/server.go new file mode 100644 index 000000000..cd4748392 --- /dev/null +++ b/serverkeyapi/inthttp/server.go @@ -0,0 +1,43 @@ +package inthttp + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/serverkeyapi/api" + "github.com/matrix-org/util" +) + +func AddRoutes(s api.ServerKeyInternalAPI, internalAPIMux *mux.Router, cache caching.ServerKeyCache) { + internalAPIMux.Handle(ServerKeyQueryPublicKeyPath, + httputil.MakeInternalAPI("queryPublicKeys", func(req *http.Request) util.JSONResponse { + request := api.QueryPublicKeysRequest{} + response := api.QueryPublicKeysResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + keys, err := s.FetchKeys(req.Context(), request.Requests) + if err != nil { + return util.ErrorResponse(err) + } + response.Results = keys + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(ServerKeyInputPublicKeyPath, + httputil.MakeInternalAPI("inputPublicKeys", func(req *http.Request) util.JSONResponse { + request := api.InputPublicKeysRequest{} + response := api.InputPublicKeysResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.StoreKeys(req.Context(), request.Keys); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) +} diff --git a/serverkeyapi/serverkeyapi.go b/serverkeyapi/serverkeyapi.go new file mode 100644 index 000000000..cddd392ed --- /dev/null +++ b/serverkeyapi/serverkeyapi.go @@ -0,0 +1,96 @@ +package serverkeyapi + +import ( + "crypto/ed25519" + "encoding/base64" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/serverkeyapi/api" + "github.com/matrix-org/dendrite/serverkeyapi/internal" + "github.com/matrix-org/dendrite/serverkeyapi/inthttp" + "github.com/matrix-org/dendrite/serverkeyapi/storage" + "github.com/matrix-org/dendrite/serverkeyapi/storage/cache" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions +// on the given input API. +func AddInternalRoutes(router *mux.Router, intAPI api.ServerKeyInternalAPI, caches *caching.Caches) { + inthttp.AddRoutes(intAPI, router, caches) +} + +// NewInternalAPI returns a concerete implementation of the internal API. Callers +// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. +func NewInternalAPI( + cfg *config.Dendrite, + fedClient *gomatrixserverlib.FederationClient, + caches *caching.Caches, +) api.ServerKeyInternalAPI { + innerDB, err := storage.NewDatabase( + string(cfg.Database.ServerKey), + cfg.DbProperties(), + cfg.Matrix.ServerName, + cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey), + cfg.Matrix.KeyID, + ) + if err != nil { + logrus.WithError(err).Panicf("failed to connect to server key database") + } + + serverKeyDB, err := cache.NewKeyDatabase(innerDB, caches) + if err != nil { + logrus.WithError(err).Panicf("failed to set up caching wrapper for server key database") + } + + internalAPI := internal.ServerKeyAPI{ + ServerName: cfg.Matrix.ServerName, + ServerPublicKey: cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey), + ServerKeyID: cfg.Matrix.KeyID, + ServerKeyValidity: cfg.Matrix.KeyValidityPeriod, + FedClient: fedClient, + OurKeyRing: gomatrixserverlib.KeyRing{ + KeyFetchers: []gomatrixserverlib.KeyFetcher{ + &gomatrixserverlib.DirectKeyFetcher{ + Client: fedClient.Client, + }, + }, + KeyDatabase: serverKeyDB, + }, + } + + var b64e = base64.StdEncoding.WithPadding(base64.NoPadding) + for _, ps := range cfg.Matrix.KeyPerspectives { + perspective := &gomatrixserverlib.PerspectiveKeyFetcher{ + PerspectiveServerName: ps.ServerName, + PerspectiveServerKeys: map[gomatrixserverlib.KeyID]ed25519.PublicKey{}, + Client: fedClient.Client, + } + + for _, key := range ps.Keys { + rawkey, err := b64e.DecodeString(key.PublicKey) + if err != nil { + logrus.WithError(err).WithFields(logrus.Fields{ + "server_name": ps.ServerName, + "public_key": key.PublicKey, + }).Warn("Couldn't parse perspective key") + continue + } + perspective.PerspectiveServerKeys[key.KeyID] = rawkey + } + + internalAPI.OurKeyRing.KeyFetchers = append( + internalAPI.OurKeyRing.KeyFetchers, + perspective, + ) + + logrus.WithFields(logrus.Fields{ + "server_name": ps.ServerName, + "num_public_keys": len(ps.Keys), + }).Info("Enabled perspective key fetcher") + } + + return &internalAPI +} diff --git a/serverkeyapi/serverkeyapi_test.go b/serverkeyapi/serverkeyapi_test.go new file mode 100644 index 000000000..3368f5b2a --- /dev/null +++ b/serverkeyapi/serverkeyapi_test.go @@ -0,0 +1,315 @@ +package serverkeyapi + +import ( + "bytes" + "context" + "crypto/ed25519" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "os" + "reflect" + "testing" + "time" + + "github.com/matrix-org/dendrite/federationapi/routing" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/serverkeyapi/api" + "github.com/matrix-org/gomatrixserverlib" +) + +type server struct { + name gomatrixserverlib.ServerName // server name + validity time.Duration // key validity duration from now + config *config.Dendrite // skeleton config, from TestMain + fedclient *gomatrixserverlib.FederationClient // uses MockRoundTripper + cache *caching.Caches // server-specific cache + api api.ServerKeyInternalAPI // server-specific server key API +} + +func (s *server) renew() { + // This updates the validity period to be an hour in the + // future, which is particularly useful in server A and + // server C's cases which have validity either as now or + // in the past. + s.validity = time.Hour + s.config.Matrix.KeyValidityPeriod = s.validity +} + +var ( + serverKeyID = gomatrixserverlib.KeyID("ed25519:auto") + serverA = &server{name: "a.com", validity: time.Duration(0)} // expires now + serverB = &server{name: "b.com", validity: time.Hour} // expires in an hour + serverC = &server{name: "c.com", validity: -time.Hour} // expired an hour ago +) + +var servers = map[string]*server{ + "a.com": serverA, + "b.com": serverB, + "c.com": serverC, +} + +func TestMain(m *testing.M) { + // Set up the server key API for each "server" that we + // will use in our tests. + for _, s := range servers { + // Generate a new key. + _, testPriv, err := ed25519.GenerateKey(nil) + if err != nil { + panic("can't generate identity key: " + err.Error()) + } + + // Create a new cache but don't enable prometheus! + s.cache, err = caching.NewInMemoryLRUCache(false) + if err != nil { + panic("can't create cache: " + err.Error()) + } + + // Draw up just enough Dendrite config for the server key + // API to work. + s.config = &config.Dendrite{} + s.config.SetDefaults() + s.config.Matrix.ServerName = gomatrixserverlib.ServerName(s.name) + s.config.Matrix.PrivateKey = testPriv + s.config.Matrix.KeyID = serverKeyID + s.config.Matrix.KeyValidityPeriod = s.validity + s.config.Database.ServerKey = config.DataSource("file::memory:") + + // Create a transport which redirects federation requests to + // the mock round tripper. Since we're not *really* listening for + // federation requests then this will return the key instead. + transport := &http.Transport{} + transport.RegisterProtocol("matrix", &MockRoundTripper{}) + + // Create the federation client. + s.fedclient = gomatrixserverlib.NewFederationClientWithTransport( + s.config.Matrix.ServerName, serverKeyID, testPriv, transport, + ) + + // Finally, build the server key APIs. + s.api = NewInternalAPI(s.config, s.fedclient, s.cache) + } + + // Now that we have built our server key APIs, start the + // rest of the tests. + os.Exit(m.Run()) +} + +type MockRoundTripper struct{} + +func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err error) { + // Check if the request is looking for keys from a server that + // we know about in the test. The only reason this should go wrong + // is if the test is broken. + s, ok := servers[req.Host] + if !ok { + return nil, fmt.Errorf("server not known: %s", req.Host) + } + + // We're intercepting /matrix/key/v2/server requests here, so check + // that the URL supplied in the request is for that. + if req.URL.Path != "/_matrix/key/v2/server" { + return nil, fmt.Errorf("unexpected request path: %s", req.URL.Path) + } + + // Get the keys and JSON-ify them. + keys := routing.LocalKeys(s.config) + body, err := json.MarshalIndent(keys.JSON, "", " ") + if err != nil { + return nil, err + } + + // And respond. + res = &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewReader(body)), + } + return +} + +func TestServersRequestOwnKeys(t *testing.T) { + // Each server will request its own keys. There's no reason + // for this to fail as each server should know its own keys. + + for name, s := range servers { + req := gomatrixserverlib.PublicKeyLookupRequest{ + ServerName: s.name, + KeyID: serverKeyID, + } + res, err := s.api.FetchKeys( + context.Background(), + map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{ + req: gomatrixserverlib.AsTimestamp(time.Now()), + }, + ) + if err != nil { + t.Fatalf("server could not fetch own key: %s", err) + } + if _, ok := res[req]; !ok { + t.Fatalf("server didn't return its own key in the results") + } + t.Logf("%s's key expires at %s\n", name, res[req].ValidUntilTS.Time()) + } +} + +func TestCachingBehaviour(t *testing.T) { + // Server A will request Server B's key, which has a validity + // period of an hour from now. We should retrieve the key and + // it should make it into the cache automatically. + + req := gomatrixserverlib.PublicKeyLookupRequest{ + ServerName: serverB.name, + KeyID: serverKeyID, + } + ts := gomatrixserverlib.AsTimestamp(time.Now()) + + res, err := serverA.api.FetchKeys( + context.Background(), + map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{ + req: ts, + }, + ) + if err != nil { + t.Fatalf("server A failed to retrieve server B key: %s", err) + } + if len(res) != 1 { + t.Fatalf("server B should have returned one key but instead returned %d keys", len(res)) + } + if _, ok := res[req]; !ok { + t.Fatalf("server B isn't included in the key fetch response") + } + + // At this point, if the previous key request was a success, + // then the cache should now contain the key. Check if that's + // the case - if it isn't then there's something wrong with + // the cache implementation or we failed to get the key. + + cres, ok := serverA.cache.GetServerKey(req, ts) + if !ok { + t.Fatalf("server B key should be in cache but isn't") + } + if !reflect.DeepEqual(cres, res[req]) { + t.Fatalf("the cached result from server B wasn't what server B gave us") + } + + // If we ask the cache for the same key but this time for an event + // that happened in +30 minutes. Since the validity period is for + // another hour, then we should get a response back from the cache. + + _, ok = serverA.cache.GetServerKey( + req, + gomatrixserverlib.AsTimestamp(time.Now().Add(time.Minute*30)), + ) + if !ok { + t.Fatalf("server B key isn't in cache when it should be (+30 minutes)") + } + + // If we ask the cache for the same key but this time for an event + // that happened in +90 minutes then we should expect to get no + // cache result. This is because the cache shouldn't return a result + // that is obviously past the validity of the event. + + _, ok = serverA.cache.GetServerKey( + req, + gomatrixserverlib.AsTimestamp(time.Now().Add(time.Minute*90)), + ) + if ok { + t.Fatalf("server B key is in cache when it shouldn't be (+90 minutes)") + } +} + +func TestRenewalBehaviour(t *testing.T) { + // Server A will request Server C's key but their validity period + // is an hour in the past. We'll retrieve the key as, even though it's + // past its validity, it will be able to verify past events. + + req := gomatrixserverlib.PublicKeyLookupRequest{ + ServerName: serverC.name, + KeyID: serverKeyID, + } + + res, err := serverA.api.FetchKeys( + context.Background(), + map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{ + req: gomatrixserverlib.AsTimestamp(time.Now()), + }, + ) + if err != nil { + t.Fatalf("server A failed to retrieve server C key: %s", err) + } + if len(res) != 1 { + t.Fatalf("server C should have returned one key but instead returned %d keys", len(res)) + } + if _, ok := res[req]; !ok { + t.Fatalf("server C isn't included in the key fetch response") + } + + // If we ask the cache for the server key for an event that happened + // 90 minutes ago then we should get a cache result, as the key hadn't + // passed its validity by that point. The fact that the key is now in + // the cache is, in itself, proof that we successfully retrieved the + // key before. + + oldcached, ok := serverA.cache.GetServerKey( + req, + gomatrixserverlib.AsTimestamp(time.Now().Add(-time.Minute*90)), + ) + if !ok { + t.Fatalf("server C key isn't in cache when it should be (-90 minutes)") + } + + // If we now ask the cache for the same key but this time for an event + // that only happened 30 minutes ago then we shouldn't get a cached + // result, as the event happened after the key validity expired. This + // is really just for sanity checking. + + _, ok = serverA.cache.GetServerKey( + req, + gomatrixserverlib.AsTimestamp(time.Now().Add(-time.Minute*30)), + ) + if ok { + t.Fatalf("server B key is in cache when it shouldn't be (-30 minutes)") + } + + // We're now going to kick server C into renewing its key. Since we're + // happy at this point that the key that we already have is from the past + // then repeating a key fetch should cause us to try and renew the key. + // If so, then the new key will end up in our cache. + + serverC.renew() + + res, err = serverA.api.FetchKeys( + context.Background(), + map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{ + req: gomatrixserverlib.AsTimestamp(time.Now()), + }, + ) + if err != nil { + t.Fatalf("server A failed to retrieve server C key: %s", err) + } + if len(res) != 1 { + t.Fatalf("server C should have returned one key but instead returned %d keys", len(res)) + } + if _, ok = res[req]; !ok { + t.Fatalf("server C isn't included in the key fetch response") + } + + // We're now going to ask the cache what the new key validity is. If + // it is still the same as the previous validity then we've failed to + // retrieve the renewed key. If it's newer then we've successfully got + // the renewed key. + + newcached, ok := serverA.cache.GetServerKey( + req, + gomatrixserverlib.AsTimestamp(time.Now().Add(-time.Minute*30)), + ) + if !ok { + t.Fatalf("server B key isn't in cache when it shouldn't be (post-renewal)") + } + if oldcached.ValidUntilTS >= newcached.ValidUntilTS { + t.Fatalf("the server B key should have been renewed but wasn't") + } + t.Log(res) +} diff --git a/serverkeyapi/storage/cache/keydb.go b/serverkeyapi/storage/cache/keydb.go new file mode 100644 index 000000000..2063dfc55 --- /dev/null +++ b/serverkeyapi/storage/cache/keydb.go @@ -0,0 +1,68 @@ +package cache + +import ( + "context" + "errors" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/gomatrixserverlib" +) + +// A Database implements gomatrixserverlib.KeyDatabase and is used to store +// the public keys for other matrix servers. +type KeyDatabase struct { + inner gomatrixserverlib.KeyDatabase + cache caching.ServerKeyCache +} + +func NewKeyDatabase(inner gomatrixserverlib.KeyDatabase, cache caching.ServerKeyCache) (*KeyDatabase, error) { + if inner == nil { + return nil, errors.New("inner database can't be nil") + } + if cache == nil { + return nil, errors.New("cache can't be nil") + } + return &KeyDatabase{ + inner: inner, + cache: cache, + }, nil +} + +// FetcherName implements KeyFetcher +func (d KeyDatabase) FetcherName() string { + return "InMemoryKeyCache" +} + +// FetchKeys implements gomatrixserverlib.KeyDatabase +func (d *KeyDatabase) FetchKeys( + ctx context.Context, + requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, +) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { + results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) + for req, ts := range requests { + if res, cached := d.cache.GetServerKey(req, ts); cached { + results[req] = res + delete(requests, req) + } + } + fromDB, err := d.inner.FetchKeys(ctx, requests) + if err != nil { + return results, err + } + for req, res := range fromDB { + results[req] = res + d.cache.StoreServerKey(req, res) + } + return results, nil +} + +// StoreKeys implements gomatrixserverlib.KeyDatabase +func (d *KeyDatabase) StoreKeys( + ctx context.Context, + keyMap map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, +) error { + for req, res := range keyMap { + d.cache.StoreServerKey(req, res) + } + return d.inner.StoreKeys(ctx, keyMap) +} diff --git a/common/keydb/interface.go b/serverkeyapi/storage/interface.go similarity index 96% rename from common/keydb/interface.go rename to serverkeyapi/storage/interface.go index c9a20fdd9..3a67ac55a 100644 --- a/common/keydb/interface.go +++ b/serverkeyapi/storage/interface.go @@ -1,4 +1,4 @@ -package keydb +package storage import ( "context" diff --git a/common/keydb/keydb.go b/serverkeyapi/storage/keydb.go similarity index 69% rename from common/keydb/keydb.go rename to serverkeyapi/storage/keydb.go index fe6d87fc8..c28c4de1e 100644 --- a/common/keydb/keydb.go +++ b/serverkeyapi/storage/keydb.go @@ -14,35 +14,37 @@ // +build !wasm -package keydb +package storage import ( "net/url" "golang.org/x/crypto/ed25519" - "github.com/matrix-org/dendrite/common/keydb/postgres" - "github.com/matrix-org/dendrite/common/keydb/sqlite3" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/serverkeyapi/storage/postgres" + "github.com/matrix-org/dendrite/serverkeyapi/storage/sqlite3" "github.com/matrix-org/gomatrixserverlib" ) // NewDatabase opens a database connection. func NewDatabase( dataSourceName string, + dbProperties sqlutil.DbProperties, serverName gomatrixserverlib.ServerName, serverKey ed25519.PublicKey, serverKeyID gomatrixserverlib.KeyID, ) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { - return postgres.NewDatabase(dataSourceName, serverName, serverKey, serverKeyID) + return postgres.NewDatabase(dataSourceName, dbProperties, serverName, serverKey, serverKeyID) } switch uri.Scheme { case "postgres": - return postgres.NewDatabase(dataSourceName, serverName, serverKey, serverKeyID) + return postgres.NewDatabase(dataSourceName, dbProperties, serverName, serverKey, serverKeyID) case "file": return sqlite3.NewDatabase(dataSourceName, serverName, serverKey, serverKeyID) default: - return postgres.NewDatabase(dataSourceName, serverName, serverKey, serverKeyID) + return postgres.NewDatabase(dataSourceName, dbProperties, serverName, serverKey, serverKeyID) } } diff --git a/common/keydb/keydb_wasm.go b/serverkeyapi/storage/keydb_wasm.go similarity index 86% rename from common/keydb/keydb_wasm.go rename to serverkeyapi/storage/keydb_wasm.go index 807ed40b4..de66a1d63 100644 --- a/common/keydb/keydb_wasm.go +++ b/serverkeyapi/storage/keydb_wasm.go @@ -12,7 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -package keydb +// +build wasm + +package storage import ( "fmt" @@ -20,13 +22,15 @@ import ( "golang.org/x/crypto/ed25519" - "github.com/matrix-org/dendrite/common/keydb/sqlite3" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/serverkeyapi/storage/sqlite3" "github.com/matrix-org/gomatrixserverlib" ) // NewDatabase opens a database connection. func NewDatabase( dataSourceName string, + dbProperties sqlutil.DbProperties, // nolint:unparam serverName gomatrixserverlib.ServerName, serverKey ed25519.PublicKey, serverKeyID gomatrixserverlib.KeyID, diff --git a/common/keydb/sqlite3/keydb.go b/serverkeyapi/storage/postgres/keydb.go similarity index 77% rename from common/keydb/sqlite3/keydb.go rename to serverkeyapi/storage/postgres/keydb.go index 82d2a491f..aaa4409be 100644 --- a/common/keydb/sqlite3/keydb.go +++ b/serverkeyapi/storage/postgres/keydb.go @@ -13,19 +13,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -package sqlite3 +package postgres import ( "context" - "math" "golang.org/x/crypto/ed25519" - "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" - - _ "github.com/mattn/go-sqlite3" ) // A Database implements gomatrixserverlib.KeyDatabase and is used to store @@ -40,11 +36,12 @@ type Database struct { // Returns an error if there was a problem talking to the database. func NewDatabase( dataSourceName string, + dbProperties sqlutil.DbProperties, serverName gomatrixserverlib.ServerName, serverKey ed25519.PublicKey, serverKeyID gomatrixserverlib.KeyID, ) (*Database, error) { - db, err := sqlutil.Open(common.SQLiteDriverName(), dataSourceName) + db, err := sqlutil.Open("postgres", dataSourceName, dbProperties) if err != nil { return nil, err } @@ -53,34 +50,12 @@ func NewDatabase( if err != nil { return nil, err } - // Store our own keys so that we don't end up making HTTP requests to find our - // own keys - index := gomatrixserverlib.PublicKeyLookupRequest{ - ServerName: serverName, - KeyID: serverKeyID, - } - value := gomatrixserverlib.PublicKeyLookupResult{ - VerifyKey: gomatrixserverlib.VerifyKey{ - Key: gomatrixserverlib.Base64String(serverKey), - }, - ValidUntilTS: math.MaxUint64 >> 1, - ExpiredTS: gomatrixserverlib.PublicKeyNotExpired, - } - err = d.StoreKeys( - context.Background(), - map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{ - index: value, - }, - ) - if err != nil { - return nil, err - } return d, nil } // FetcherName implements KeyFetcher func (d Database) FetcherName() string { - return "KeyDatabase" + return "PostgresKeyDatabase" } // FetchKeys implements gomatrixserverlib.KeyDatabase diff --git a/common/keydb/postgres/server_key_table.go b/serverkeyapi/storage/postgres/server_key_table.go similarity index 97% rename from common/keydb/postgres/server_key_table.go rename to serverkeyapi/storage/postgres/server_key_table.go index 0434eb8b1..87f1c211c 100644 --- a/common/keydb/postgres/server_key_table.go +++ b/serverkeyapi/storage/postgres/server_key_table.go @@ -19,9 +19,8 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" - "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/gomatrixserverlib" ) @@ -93,7 +92,7 @@ func (s *serverKeyStatements) bulkSelectServerKeys( if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectServerKeys: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectServerKeys: rows.close() failed") results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} for rows.Next() { var serverName string diff --git a/common/keydb/postgres/keydb.go b/serverkeyapi/storage/sqlite3/keydb.go similarity index 80% rename from common/keydb/postgres/keydb.go rename to serverkeyapi/storage/sqlite3/keydb.go index 2879683e0..dc72b79eb 100644 --- a/common/keydb/postgres/keydb.go +++ b/serverkeyapi/storage/sqlite3/keydb.go @@ -13,16 +13,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -package postgres +package sqlite3 import ( "context" - "math" "golang.org/x/crypto/ed25519" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + + _ "github.com/mattn/go-sqlite3" ) // A Database implements gomatrixserverlib.KeyDatabase and is used to store @@ -41,7 +42,11 @@ func NewDatabase( serverKey ed25519.PublicKey, serverKeyID gomatrixserverlib.KeyID, ) (*Database, error) { - db, err := sqlutil.Open("postgres", dataSourceName) + cs, err := sqlutil.ParseFileURI(dataSourceName) + if err != nil { + return nil, err + } + db, err := sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil) if err != nil { return nil, err } @@ -50,25 +55,6 @@ func NewDatabase( if err != nil { return nil, err } - // Store our own keys so that we don't end up making HTTP requests to find our - // own keys - index := gomatrixserverlib.PublicKeyLookupRequest{ - ServerName: serverName, - KeyID: serverKeyID, - } - value := gomatrixserverlib.PublicKeyLookupResult{ - VerifyKey: gomatrixserverlib.VerifyKey{ - Key: gomatrixserverlib.Base64String(serverKey), - }, - ValidUntilTS: math.MaxUint64 >> 1, - ExpiredTS: gomatrixserverlib.PublicKeyNotExpired, - } - err = d.StoreKeys( - context.Background(), - map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{ - index: value, - }, - ) if err != nil { return nil, err } @@ -77,7 +63,7 @@ func NewDatabase( // FetcherName implements KeyFetcher func (d Database) FetcherName() string { - return "KeyDatabase" + return "SqliteKeyDatabase" } // FetchKeys implements gomatrixserverlib.KeyDatabase diff --git a/common/keydb/sqlite3/server_key_table.go b/serverkeyapi/storage/sqlite3/server_key_table.go similarity index 81% rename from common/keydb/sqlite3/server_key_table.go rename to serverkeyapi/storage/sqlite3/server_key_table.go index ba1cc0606..4f03dccbb 100644 --- a/common/keydb/sqlite3/server_key_table.go +++ b/serverkeyapi/storage/sqlite3/server_key_table.go @@ -20,10 +20,9 @@ import ( "database/sql" "strings" - lru "github.com/hashicorp/golang-lru" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" ) const serverKeysSchema = ` @@ -66,16 +65,10 @@ type serverKeyStatements struct { db *sql.DB bulkSelectServerKeysStmt *sql.Stmt upsertServerKeysStmt *sql.Stmt - - cache *lru.Cache // nameAndKeyID => gomatrixserverlib.PublicKeyLookupResult } func (s *serverKeyStatements) prepare(db *sql.DB) (err error) { s.db = db - s.cache, err = lru.New(64) - if err != nil { - return - } _, err = db.Exec(serverKeysSchema) if err != nil { return @@ -98,22 +91,7 @@ func (s *serverKeyStatements) bulkSelectServerKeys( nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request)) } - // If we can satisfy all of the requests from the cache, do so. TODO: Allow partial matches with merges. - cacheResults := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} - for request := range requests { - r, ok := s.cache.Get(nameAndKeyID(request)) - if !ok { - break - } - cacheResult := r.(gomatrixserverlib.PublicKeyLookupResult) - cacheResults[request] = cacheResult - } - if len(cacheResults) == len(requests) { - util.GetLogger(ctx).Infof("KeyDB cache hit for %d keys", len(cacheResults)) - return cacheResults, nil - } - - query := strings.Replace(bulkSelectServerKeysSQL, "($1)", common.QueryVariadic(len(nameAndKeyIDs)), 1) + query := strings.Replace(bulkSelectServerKeysSQL, "($1)", sqlutil.QueryVariadic(len(nameAndKeyIDs)), 1) iKeyIDs := make([]interface{}, len(nameAndKeyIDs)) for i, v := range nameAndKeyIDs { @@ -124,7 +102,7 @@ func (s *serverKeyStatements) bulkSelectServerKeys( if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "bulkSelectServerKeys: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectServerKeys: rows.close() failed") results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} for rows.Next() { var serverName string @@ -158,7 +136,6 @@ func (s *serverKeyStatements) upsertServerKeys( request gomatrixserverlib.PublicKeyLookupRequest, key gomatrixserverlib.PublicKeyLookupResult, ) error { - s.cache.Add(nameAndKeyID(request), key) _, err := s.upsertServerKeysStmt.ExecContext( ctx, string(request.ServerName), diff --git a/show-expected-fail-tests.sh b/show-expected-fail-tests.sh index 0a4c7be87..320d4ebd3 100755 --- a/show-expected-fail-tests.sh +++ b/show-expected-fail-tests.sh @@ -80,8 +80,8 @@ done <<< "${passed_but_expected_fail}" # TODO: Check that the same test doesn't appear twice in the whitelist|blacklist # Trim test output strings -tests_to_add=$(echo -e $tests_to_add | xargs -d '\n') -already_in_whitelist=$(echo -e $already_in_whitelist | xargs -d '\n') +tests_to_add=$(IFS=$'\n' echo "${tests_to_add[*]%%'\n'}") +already_in_whitelist=$(IFS=$'\n' echo "${already_in_whitelist[*]%%'\n'}") # Format output with markdown for buildkite annotation rendering purposes if [ -n "${tests_to_add}" ] && [ -n "${already_in_whitelist}" ]; then @@ -91,14 +91,14 @@ fi if [ -n "${tests_to_add}" ]; then echo "**ERROR**: The following tests passed but are not present in \`$2\`. Please append them to the file:" echo "\`\`\`" - echo -e "${tests_to_add}" + echo -e "${tests_to_add}" echo "\`\`\`" fi if [ -n "${already_in_whitelist}" ]; then echo "**WARN**: Tests in the whitelist still marked as **expected fail**:" echo "\`\`\`" - echo -e "${already_in_whitelist}" + echo -e "${already_in_whitelist}" echo "\`\`\`" fi diff --git a/syncapi/api/query.go b/syncapi/api/query.go deleted file mode 100644 index 2993829e0..000000000 --- a/syncapi/api/query.go +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package api - -import ( - "context" - "net/http" - - commonHTTP "github.com/matrix-org/dendrite/common/http" - "github.com/matrix-org/util" - opentracing "github.com/opentracing/opentracing-go" -) - -const ( - SyncAPIQuerySyncPath = "/api/syncapi/querySync" - SyncAPIQueryStatePath = "/api/syncapi/queryState" - SyncAPIQueryStateTypePath = "/api/syncapi/queryStateType" - SyncAPIQueryMessagesPath = "/api/syncapi/queryMessages" -) - -func NewSyncQueryAPIHTTP(syncapiURL string, httpClient *http.Client) SyncQueryAPI { - if httpClient == nil { - httpClient = http.DefaultClient - } - return &httpSyncQueryAPI{syncapiURL, httpClient} -} - -type httpSyncQueryAPI struct { - syncapiURL string - httpClient *http.Client -} - -type SyncQueryAPI interface { - QuerySync(ctx context.Context, request *QuerySyncRequest, response *QuerySyncResponse) error - QueryState(ctx context.Context, request *QueryStateRequest, response *QueryStateResponse) error - QueryStateType(ctx context.Context, request *QueryStateTypeRequest, response *QueryStateTypeResponse) error - QueryMessages(ctx context.Context, request *QueryMessagesRequest, response *QueryMessagesResponse) error -} - -type QuerySyncRequest struct{} - -type QueryStateRequest struct { - RoomID string -} - -type QueryStateTypeRequest struct { - RoomID string - EventType string - StateKey string -} - -type QueryMessagesRequest struct { - RoomID string -} - -type QuerySyncResponse util.JSONResponse -type QueryStateResponse util.JSONResponse -type QueryStateTypeResponse util.JSONResponse -type QueryMessagesResponse util.JSONResponse - -// QueryLatestEventsAndState implements SyncQueryAPI -func (h *httpSyncQueryAPI) QuerySync( - ctx context.Context, - request *QuerySyncRequest, - response *QuerySyncResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySync") - defer span.Finish() - - apiURL := h.syncapiURL + SyncAPIQuerySyncPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -// QueryStateAfterEvents implements SyncQueryAPI -func (h *httpSyncQueryAPI) QueryState( - ctx context.Context, - request *QueryStateRequest, - response *QueryStateResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryState") - defer span.Finish() - - apiURL := h.syncapiURL + SyncAPIQueryStatePath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -// QueryEventsByID implements SyncQueryAPI -func (h *httpSyncQueryAPI) QueryStateType( - ctx context.Context, - request *QueryStateTypeRequest, - response *QueryStateTypeResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryStateType") - defer span.Finish() - - apiURL := h.syncapiURL + SyncAPIQueryStateTypePath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -// QueryMembershipForUser implements SyncQueryAPI -func (h *httpSyncQueryAPI) QueryMessages( - ctx context.Context, - request *QueryMessagesRequest, - response *QueryMessagesResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMessages") - defer span.Finish() - - apiURL := h.syncapiURL + SyncAPIQueryMessagesPath - return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index 17f2c522c..ad6290e3f 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -18,18 +18,19 @@ import ( "context" "encoding/json" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" + "github.com/Shopify/sarama" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" "github.com/matrix-org/dendrite/syncapi/types" log "github.com/sirupsen/logrus" - sarama "gopkg.in/Shopify/sarama.v1" ) // OutputClientDataConsumer consumes events that originated in the client API server. type OutputClientDataConsumer struct { - clientAPIConsumer *common.ContinualConsumer + clientAPIConsumer *internal.ContinualConsumer db storage.Database notifier *sync.Notifier } @@ -42,7 +43,7 @@ func NewOutputClientDataConsumer( store storage.Database, ) *OutputClientDataConsumer { - consumer := common.ContinualConsumer{ + consumer := internal.ContinualConsumer{ Topic: string(cfg.Kafka.Topics.OutputClientData), Consumer: kafkaConsumer, PartitionStore: store, @@ -67,7 +68,7 @@ func (s *OutputClientDataConsumer) Start() error { // sync stream position may race and be incorrectly calculated. func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error { // Parse out the event JSON - var output common.AccountData + var output eventutil.AccountData if err := json.Unmarshal(msg.Value, &output); err != nil { // If the message was invalid, log it and move on to the next message in the stream log.WithError(err).Errorf("client API server output log: message parse failure") @@ -90,7 +91,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error }).Panicf("could not save account data") } - s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.PaginationToken{PDUPosition: pduPos}) + s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.NewStreamToken(pduPos, 0)) return nil } diff --git a/syncapi/consumers/eduserver_sendtodevice.go b/syncapi/consumers/eduserver_sendtodevice.go new file mode 100644 index 000000000..487018031 --- /dev/null +++ b/syncapi/consumers/eduserver_sendtodevice.go @@ -0,0 +1,113 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package consumers + +import ( + "context" + "encoding/json" + + "github.com/Shopify/sarama" + "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/sync" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + log "github.com/sirupsen/logrus" +) + +// OutputSendToDeviceEventConsumer consumes events that originated in the EDU server. +type OutputSendToDeviceEventConsumer struct { + sendToDeviceConsumer *internal.ContinualConsumer + db storage.Database + serverName gomatrixserverlib.ServerName // our server name + notifier *sync.Notifier +} + +// NewOutputSendToDeviceEventConsumer creates a new OutputSendToDeviceEventConsumer. +// Call Start() to begin consuming from the EDU server. +func NewOutputSendToDeviceEventConsumer( + cfg *config.Dendrite, + kafkaConsumer sarama.Consumer, + n *sync.Notifier, + store storage.Database, +) *OutputSendToDeviceEventConsumer { + + consumer := internal.ContinualConsumer{ + Topic: string(cfg.Kafka.Topics.OutputSendToDeviceEvent), + Consumer: kafkaConsumer, + PartitionStore: store, + } + + s := &OutputSendToDeviceEventConsumer{ + sendToDeviceConsumer: &consumer, + db: store, + serverName: cfg.Matrix.ServerName, + notifier: n, + } + + consumer.ProcessMessage = s.onMessage + + return s +} + +// Start consuming from EDU api +func (s *OutputSendToDeviceEventConsumer) Start() error { + return s.sendToDeviceConsumer.Start() +} + +func (s *OutputSendToDeviceEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { + var output api.OutputSendToDeviceEvent + if err := json.Unmarshal(msg.Value, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("EDU server output log: message parse failure") + return err + } + + _, domain, err := gomatrixserverlib.SplitID('@', output.UserID) + if err != nil { + return err + } + if domain != s.serverName { + return nil + } + + util.GetLogger(context.TODO()).WithFields(log.Fields{ + "sender": output.Sender, + "user_id": output.UserID, + "device_id": output.DeviceID, + "event_type": output.Type, + }).Info("sync API received send-to-device event from EDU server") + + streamPos := s.db.AddSendToDevice() + + _, err = s.db.StoreNewSendForDeviceMessage( + context.TODO(), streamPos, output.UserID, output.DeviceID, output.SendToDeviceEvent, + ) + if err != nil { + log.WithError(err).Errorf("failed to store send-to-device message") + return err + } + + s.notifier.OnNewSendToDevice( + output.UserID, + []string{output.DeviceID}, + types.NewStreamToken(0, streamPos), + ) + + return nil +} diff --git a/syncapi/consumers/eduserver.go b/syncapi/consumers/eduserver_typing.go similarity index 86% rename from syncapi/consumers/eduserver.go rename to syncapi/consumers/eduserver_typing.go index 5491c1e9f..12b1efbc0 100644 --- a/syncapi/consumers/eduserver.go +++ b/syncapi/consumers/eduserver_typing.go @@ -17,19 +17,19 @@ package consumers import ( "encoding/json" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" + "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" "github.com/matrix-org/dendrite/syncapi/types" log "github.com/sirupsen/logrus" - sarama "gopkg.in/Shopify/sarama.v1" ) // OutputTypingEventConsumer consumes events that originated in the EDU server. type OutputTypingEventConsumer struct { - typingConsumer *common.ContinualConsumer + typingConsumer *internal.ContinualConsumer db storage.Database notifier *sync.Notifier } @@ -43,7 +43,7 @@ func NewOutputTypingEventConsumer( store storage.Database, ) *OutputTypingEventConsumer { - consumer := common.ContinualConsumer{ + consumer := internal.ContinualConsumer{ Topic: string(cfg.Kafka.Topics.OutputTypingEvent), Consumer: kafkaConsumer, PartitionStore: store, @@ -65,9 +65,7 @@ func (s *OutputTypingEventConsumer) Start() error { s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) { s.notifier.OnNewEvent( nil, roomID, nil, - types.PaginationToken{ - EDUTypingPosition: types.StreamPosition(latestSyncPosition), - }, + types.NewStreamToken(0, types.StreamPosition(latestSyncPosition)), ) }) @@ -96,6 +94,6 @@ func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID) } - s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.PaginationToken{EDUTypingPosition: typingPos}) + s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.NewStreamToken(0, typingPos)) return nil } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index f1e68c262..98be5bb73 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -17,25 +17,24 @@ package consumers import ( "context" "encoding/json" - "fmt" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" + "github.com/Shopify/sarama" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" - sarama "gopkg.in/Shopify/sarama.v1" ) // OutputRoomEventConsumer consumes events that originated in the room server. type OutputRoomEventConsumer struct { - roomServerConsumer *common.ContinualConsumer - db storage.Database - notifier *sync.Notifier - query api.RoomserverQueryAPI + rsAPI api.RoomserverInternalAPI + rsConsumer *internal.ContinualConsumer + db storage.Database + notifier *sync.Notifier } // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. @@ -44,19 +43,19 @@ func NewOutputRoomEventConsumer( kafkaConsumer sarama.Consumer, n *sync.Notifier, store storage.Database, - queryAPI api.RoomserverQueryAPI, + rsAPI api.RoomserverInternalAPI, ) *OutputRoomEventConsumer { - consumer := common.ContinualConsumer{ + consumer := internal.ContinualConsumer{ Topic: string(cfg.Kafka.Topics.OutputRoomEvent), Consumer: kafkaConsumer, PartitionStore: store, } s := &OutputRoomEventConsumer{ - roomServerConsumer: &consumer, - db: store, - notifier: n, - query: queryAPI, + rsConsumer: &consumer, + db: store, + notifier: n, + rsAPI: rsAPI, } consumer.ProcessMessage = s.onMessage @@ -65,7 +64,7 @@ func NewOutputRoomEventConsumer( // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { - return s.roomServerConsumer.Start() + return s.rsConsumer.Start() } // onMessage is called when the sync server receives a new event from the room server output log. @@ -99,23 +98,10 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( ctx context.Context, msg api.OutputNewRoomEvent, ) error { ev := msg.Event - log.WithFields(log.Fields{ - "event_id": ev.EventID(), - "room_id": ev.RoomID(), - "room_version": ev.RoomVersion, - }).Info("received event from roomserver") - addsStateEvents, err := s.lookupStateEvents(msg.AddsStateEventIDs, ev) - if err != nil { - log.WithFields(log.Fields{ - "event": string(ev.JSON()), - log.ErrorKey: err, - "add": msg.AddsStateEventIDs, - "del": msg.RemovesStateEventIDs, - }).Panicf("roomserver output log: state event lookup failure") - } + addsStateEvents := msg.AddsState() - ev, err = s.updateStateEvent(ev) + ev, err := s.updateStateEvent(ev) if err != nil { return err } @@ -146,7 +132,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( }).Panicf("roomserver output log: write event failure") return nil } - s.notifier.OnNewEvent(&ev, "", nil, types.PaginationToken{PDUPosition: pduPos}) + s.notifier.OnNewEvent(&ev, "", nil, types.NewStreamToken(pduPos, 0)) return nil } @@ -164,7 +150,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent( }).Panicf("roomserver output log: write invite failure") return nil } - s.notifier.OnNewEvent(&msg.Event, "", nil, types.PaginationToken{PDUPosition: pduPos}) + s.notifier.OnNewEvent(&msg.Event, "", nil, types.NewStreamToken(pduPos, 0)) return nil } @@ -185,63 +171,6 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent( return nil } -// lookupStateEvents looks up the state events that are added by a new event. -func (s *OutputRoomEventConsumer) lookupStateEvents( - addsStateEventIDs []string, event gomatrixserverlib.HeaderedEvent, -) ([]gomatrixserverlib.HeaderedEvent, error) { - // Fast path if there aren't any new state events. - if len(addsStateEventIDs) == 0 { - return nil, nil - } - - // Fast path if the only state event added is the event itself. - if len(addsStateEventIDs) == 1 && addsStateEventIDs[0] == event.EventID() { - return []gomatrixserverlib.HeaderedEvent{event}, nil - } - - // Check if this is re-adding a state events that we previously processed - // If we have previously received a state event it may still be in - // our event database. - result, err := s.db.Events(context.TODO(), addsStateEventIDs) - if err != nil { - return nil, err - } - missing := missingEventsFrom(result, addsStateEventIDs) - - // Check if event itself is being added. - for _, eventID := range missing { - if eventID == event.EventID() { - result = append(result, event) - break - } - } - missing = missingEventsFrom(result, addsStateEventIDs) - - if len(missing) == 0 { - return result, nil - } - - // At this point the missing events are neither the event itself nor are - // they present in our local database. Our only option is to fetch them - // from the roomserver using the query API. - eventReq := api.QueryEventsByIDRequest{EventIDs: missing} - var eventResp api.QueryEventsByIDResponse - if err := s.query.QueryEventsByID(context.TODO(), &eventReq, &eventResp); err != nil { - return nil, err - } - - result = append(result, eventResp.Events...) - missing = missingEventsFrom(result, addsStateEventIDs) - - if len(missing) != 0 { - return nil, fmt.Errorf( - "missing %d state events IDs at event %q", len(missing), event.EventID(), - ) - } - - return result, nil -} - func (s *OutputRoomEventConsumer) updateStateEvent(event gomatrixserverlib.HeaderedEvent) (gomatrixserverlib.HeaderedEvent, error) { var stateKey string if event.StateKey() == nil { @@ -270,17 +199,3 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event gomatrixserverlib.Heade event.Event, err = event.SetUnsigned(prev) return event, err } - -func missingEventsFrom(events []gomatrixserverlib.HeaderedEvent, required []string) []string { - have := map[string]bool{} - for _, event := range events { - have[event.EventID()] = true - } - var missing []string - for _, eventID := range required { - if !have[eventID] { - missing = append(missing, eventID) - } - } - return missing -} diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 873ee9366..15add1b45 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -22,7 +22,7 @@ import ( "strconv" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" @@ -34,12 +34,13 @@ import ( type messagesReq struct { ctx context.Context db storage.Database - queryAPI api.RoomserverQueryAPI + rsAPI api.RoomserverInternalAPI federation *gomatrixserverlib.FederationClient cfg *config.Dendrite roomID string - from *types.PaginationToken - to *types.PaginationToken + from *types.TopologyToken + to *types.TopologyToken + fromStream *types.StreamingToken wasToProvided bool limit int backwardOrdering bool @@ -59,18 +60,23 @@ const defaultMessagesLimit = 10 func OnIncomingMessagesRequest( req *http.Request, db storage.Database, roomID string, federation *gomatrixserverlib.FederationClient, - queryAPI api.RoomserverQueryAPI, + rsAPI api.RoomserverInternalAPI, cfg *config.Dendrite, ) util.JSONResponse { var err error // Extract parameters from the request's URL. // Pagination tokens. - from, err := types.NewPaginationTokenFromString(req.URL.Query().Get("from")) + var fromStream *types.StreamingToken + from, err := types.NewTopologyTokenFromString(req.URL.Query().Get("from")) if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("Invalid from parameter: " + err.Error()), + fs, err2 := types.NewStreamTokenFromString(req.URL.Query().Get("from")) + fromStream = &fs + if err2 != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("Invalid from parameter: " + err2.Error()), + } } } @@ -88,10 +94,10 @@ func OnIncomingMessagesRequest( // Pagination tokens. To is optional, and its default value depends on the // direction ("b" or "f"). - var to *types.PaginationToken + var to types.TopologyToken wasToProvided := true if s := req.URL.Query().Get("to"); len(s) > 0 { - to, err = types.NewPaginationTokenFromString(s) + to, err = types.NewTopologyTokenFromString(s) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -135,12 +141,13 @@ func OnIncomingMessagesRequest( mReq := messagesReq{ ctx: req.Context(), db: db, - queryAPI: queryAPI, + rsAPI: rsAPI, federation: federation, cfg: cfg, roomID: roomID, - from: from, - to: to, + from: &from, + to: &to, + fromStream: fromStream, wasToProvided: wasToProvided, limit: limit, backwardOrdering: backwardOrdering, @@ -151,6 +158,7 @@ func OnIncomingMessagesRequest( util.GetLogger(req.Context()).WithError(err).Error("mreq.retrieveEvents failed") return jsonerror.InternalServerError() } + util.GetLogger(req.Context()).WithFields(logrus.Fields{ "from": from.String(), "to": to.String(), @@ -178,18 +186,27 @@ func OnIncomingMessagesRequest( // remote homeserver. func (r *messagesReq) retrieveEvents() ( clientEvents []gomatrixserverlib.ClientEvent, start, - end *types.PaginationToken, err error, + end types.TopologyToken, err error, ) { // Retrieve the events from the local database. - streamEvents, err := r.db.GetEventsInRange( - r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering, - ) + var streamEvents []types.StreamEvent + if r.fromStream != nil { + toStream := r.to.StreamToken() + streamEvents, err = r.db.GetEventsInStreamingRange( + r.ctx, r.fromStream, &toStream, r.roomID, r.limit, r.backwardOrdering, + ) + } else { + streamEvents, err = r.db.GetEventsInTopologicalRange( + r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering, + ) + } if err != nil { err = fmt.Errorf("GetEventsInRange: %w", err) return } var events []gomatrixserverlib.HeaderedEvent + util.GetLogger(r.ctx).WithField("start", start).WithField("end", end).Infof("Fetched %d events locally", len(streamEvents)) // There can be two reasons for streamEvents to be empty: either we've // reached the oldest event in the room (or the most recent one, depending @@ -206,16 +223,20 @@ func (r *messagesReq) retrieveEvents() ( // If we didn't get any event, we don't need to proceed any further. if len(events) == 0 { - return []gomatrixserverlib.ClientEvent{}, r.from, r.to, nil + return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, nil } // Sort the events to ensure we send them in the right order. - events = gomatrixserverlib.HeaderedReverseTopologicalOrdering(events) if r.backwardOrdering { // This reverses the array from old->new to new->old - sort.SliceStable(events, func(i, j int) bool { - return true - }) + reversed := func(in []gomatrixserverlib.HeaderedEvent) []gomatrixserverlib.HeaderedEvent { + out := make([]gomatrixserverlib.HeaderedEvent, len(in)) + for i := 0; i < len(in); i++ { + out[i] = in[len(in)-i-1] + } + return out + } + events = reversed(events) } // Convert all of the events into client events. @@ -226,45 +247,41 @@ func (r *messagesReq) retrieveEvents() ( // change the way topological positions are defined (as depth isn't the most // reliable way to define it), it would be easier and less troublesome to // only have to change it in one place, i.e. the database. - startPos, err := r.db.EventPositionInTopology( + start, end, err = r.getStartEnd(events) + + return clientEvents, start, end, err +} + +func (r *messagesReq) getStartEnd(events []gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) { + start, err = r.db.EventPositionInTopology( r.ctx, events[0].EventID(), ) if err != nil { err = fmt.Errorf("EventPositionInTopology: for start event %s: %w", events[0].EventID(), err) return } - endPos, err := r.db.EventPositionInTopology( - r.ctx, events[len(events)-1].EventID(), - ) - if err != nil { - err = fmt.Errorf("EventPositionInTopology: for end event %s: %w", events[len(events)-1].EventID(), err) - return + if r.backwardOrdering && events[len(events)-1].Type() == gomatrixserverlib.MRoomCreate { + // We've hit the beginning of the room so there's really nowhere else + // to go. This seems to fix Riot iOS from looping on /messages endlessly. + end = types.NewTopologyToken(0, 0) + } else { + end, err = r.db.EventPositionInTopology( + r.ctx, events[len(events)-1].EventID(), + ) + if err != nil { + err = fmt.Errorf("EventPositionInTopology: for end event %s: %w", events[len(events)-1].EventID(), err) + return + } + if r.backwardOrdering { + // A stream/topological position is a cursor located between two events. + // While they are identified in the code by the event on their right (if + // we consider a left to right chronological order), tokens need to refer + // to them by the event on their left, therefore we need to decrement the + // end position we send in the response if we're going backward. + end.Decrement() + } } - // Generate pagination tokens to send to the client using the positions - // retrieved previously. - start = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, startPos, 0, - ) - end = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, endPos, 0, - ) - - if r.backwardOrdering { - // A stream/topological position is a cursor located between two events. - // While they are identified in the code by the event on their right (if - // we consider a left to right chronological order), tokens need to refer - // to them by the event on their left, therefore we need to decrement the - // end position we send in the response if we're going backward. - end.PDUPosition-- - } - - // The lowest token value is 1, therefore we need to manually set it to that - // value if we're below it. - if end.PDUPosition < types.StreamPosition(1) { - end.PDUPosition = types.StreamPosition(1) - } - - return clientEvents, start, end, err + return } // handleEmptyEventsSlice handles the case where the initial request to the @@ -282,7 +299,7 @@ func (r *messagesReq) handleEmptyEventsSlice() ( // Check if we have backward extremities for this room. if len(backwardExtremities) > 0 { // If so, retrieve as much events as needed through backfilling. - events, err = r.backfill(backwardExtremities, r.limit) + events, err = r.backfill(r.roomID, backwardExtremities, r.limit) if err != nil { return } @@ -312,11 +329,11 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent // The condition in the SQL query is a strict "greater than" so // we need to check against to-1. streamPos := types.StreamPosition(streamEvents[len(streamEvents)-1].StreamPosition) - isSetLargeEnough = (r.to.PDUPosition-1 == streamPos) + isSetLargeEnough = (r.to.PDUPosition()-1 == streamPos) } } else { streamPos := types.StreamPosition(streamEvents[0].StreamPosition) - isSetLargeEnough = (r.from.PDUPosition-1 == streamPos) + isSetLargeEnough = (r.from.PDUPosition()-1 == streamPos) } } @@ -331,7 +348,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent if len(backwardExtremities) > 0 && !isSetLargeEnough && r.backwardOrdering { var pdus []gomatrixserverlib.HeaderedEvent // Only ask the remote server for enough events to reach the limit. - pdus, err = r.backfill(backwardExtremities, r.limit-len(streamEvents)) + pdus, err = r.backfill(r.roomID, backwardExtremities, r.limit-len(streamEvents)) if err != nil { return } @@ -342,10 +359,23 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent // Append the events ve previously retrieved locally. events = append(events, r.db.StreamEventsToEvents(nil, streamEvents)...) + sort.Sort(eventsByDepth(events)) return } +type eventsByDepth []gomatrixserverlib.HeaderedEvent + +func (e eventsByDepth) Len() int { + return len(e) +} +func (e eventsByDepth) Swap(i, j int) { + e[i], e[j] = e[j], e[i] +} +func (e eventsByDepth) Less(i, j int) bool { + return e[i].Depth() < e[j].Depth() +} + // backfill performs a backfill request over the federation on another // homeserver in the room. // See: https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid @@ -355,111 +385,53 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent // event, or if there is no remote homeserver to contact. // Returns an error if there was an issue with retrieving the list of servers in // the room or sending the request. -func (r *messagesReq) backfill(fromEventIDs []string, limit int) ([]gomatrixserverlib.HeaderedEvent, error) { - verReq := api.QueryRoomVersionForRoomRequest{RoomID: r.roomID} - verRes := api.QueryRoomVersionForRoomResponse{} - if err := r.queryAPI.QueryRoomVersionForRoom(r.ctx, &verReq, &verRes); err != nil { - return nil, err - } - - srvToBackfillFrom, err := r.serverToBackfillFrom(fromEventIDs) +func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]string, limit int) ([]gomatrixserverlib.HeaderedEvent, error) { + var res api.PerformBackfillResponse + err := r.rsAPI.PerformBackfill(context.Background(), &api.PerformBackfillRequest{ + RoomID: roomID, + BackwardsExtremities: backwardsExtremities, + Limit: limit, + ServerName: r.cfg.Matrix.ServerName, + }, &res) if err != nil { - return nil, fmt.Errorf("Cannot find server to backfill from: %w", err) + return nil, fmt.Errorf("PerformBackfill failed: %w", err) } + util.GetLogger(r.ctx).WithField("new_events", len(res.Events)).Info("Storing new events from backfill") - headered := make([]gomatrixserverlib.HeaderedEvent, 0) + // TODO: we should only be inserting events into the database from the roomserver's kafka output stream. + // Currently, this can race with live events for the room and cause problems. It's also just a bit unclear + // when you have multiple entry points to write events. - // If the roomserver responded with at least one server that isn't us, - // send it a request for backfill. - util.GetLogger(r.ctx).WithField("server", srvToBackfillFrom).WithField("limit", limit).Info("Backfilling from server") - txn, err := r.federation.Backfill( - r.ctx, srvToBackfillFrom, r.roomID, limit, fromEventIDs, - ) - if err != nil { - return nil, err - } - - for _, p := range txn.PDUs { - event, e := gomatrixserverlib.NewEventFromUntrustedJSON(p, verRes.RoomVersion) - if e != nil { - continue - } - headered = append(headered, event.Headered(verRes.RoomVersion)) - } - util.GetLogger(r.ctx).WithField("server", srvToBackfillFrom).WithField("new_events", len(headered)).Info("Storing new events from backfill") + // we have to order these by depth, starting with the lowest because otherwise the topology tokens + // will skip over events that have the same depth but different stream positions due to the query which is: + // - anything less than the depth OR + // - anything with the same depth and a lower stream position. + sort.Sort(eventsByDepth(res.Events)) // Store the events in the database, while marking them as unfit to show // up in responses to sync requests. - for i := range headered { - if _, err = r.db.WriteEvent( + for i := range res.Events { + _, err = r.db.WriteEvent( r.ctx, - &headered[i], + &res.Events[i], []gomatrixserverlib.HeaderedEvent{}, []string{}, []string{}, nil, true, - ); err != nil { + ) + if err != nil { return nil, err } } - return headered, nil -} - -func (r *messagesReq) serverToBackfillFrom(fromEventIDs []string) (gomatrixserverlib.ServerName, error) { - // Query the list of servers in the room when one of the backward extremities - // was sent. - var serversResponse api.QueryServersInRoomAtEventResponse - serversRequest := api.QueryServersInRoomAtEventRequest{ - RoomID: r.roomID, - EventID: fromEventIDs[0], - } - if err := r.queryAPI.QueryServersInRoomAtEvent(r.ctx, &serversRequest, &serversResponse); err != nil { - util.GetLogger(r.ctx).WithError(err).Warn("Failed to query servers in room at event, falling back to event sender") - // FIXME: We shouldn't be doing this but in situations where we have already backfilled once - // the query API doesn't work as backfilled events do not make it to the room server. - // This means QueryServersInRoomAtEvent returns an error as it doesn't have the event ID in question. - // We need to inject backfilled events into the room server and store them appropriately. - events, err := r.db.Events(r.ctx, fromEventIDs) - if err != nil { - return "", err - } - if len(events) == 0 { - // should be impossible as these event IDs are backwards extremities - return "", fmt.Errorf("backfill: missing backwards extremities, event IDs: %s", fromEventIDs) - } - // The rationale here is that the last event was unlikely to be sent by us, so poke the server who sent it. - // We shouldn't be doing this really, but as a heuristic it should work pretty well for now. - for _, e := range events { - _, srv, srverr := gomatrixserverlib.SplitID('@', e.Sender()) - if srverr != nil { - util.GetLogger(r.ctx).WithError(srverr).Warn("Failed to extract domain from event sender") - continue - } - if srv != r.cfg.Matrix.ServerName { - return srv, nil - } - } - // no valid events which have a remote server, fail. - return "", err + // we may have got more than the requested limit so resize now + events := res.Events + if len(events) > limit { + // last `limit` events + events = events[len(events)-limit:] } - // Use the first server from the response, except if that server is us. - // In that case, use the second one if the roomserver responded with - // enough servers. If not, use an empty string to prevent the backfill - // from happening as there's no server to direct the request towards. - // TODO: Be smarter at selecting the server to direct the request - // towards. - srvToBackfillFrom := serversResponse.Servers[0] - if srvToBackfillFrom == r.cfg.Matrix.ServerName { - if len(serversResponse.Servers) > 1 { - srvToBackfillFrom = serversResponse.Servers[1] - } else { - util.GetLogger(r.ctx).Info("Not enough servers to backfill from") - return "", nil - } - } - return srvToBackfillFrom, nil + return events, nil } // setToDefault returns the default value for the "to" query parameter of a @@ -471,18 +443,13 @@ func (r *messagesReq) serverToBackfillFrom(fromEventIDs []string) (gomatrixserve func setToDefault( ctx context.Context, db storage.Database, backwardOrdering bool, roomID string, -) (to *types.PaginationToken, err error) { +) (to types.TopologyToken, err error) { if backwardOrdering { // go 1 earlier than the first event so we correctly fetch the earliest event - to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) + // this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound. + to = types.NewTopologyToken(0, 0) } else { - var pos types.StreamPosition - pos, err = db.MaxTopologicalPosition(ctx, roomID) - if err != nil { - return - } - - to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, pos, 0) + to, err = db.MaxTopologicalPosition(ctx, roomID) } return diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 9078b87ff..5744de05a 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -18,19 +18,17 @@ import ( "net/http" "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" - "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) -const pathPrefixR0 = "/_matrix/client/r0" +const pathPrefixR0 = "/client/r0" // Setup configures the given mux with sync-server listeners // @@ -38,29 +36,23 @@ const pathPrefixR0 = "/_matrix/client/r0" // applied: // nolint: gocyclo func Setup( - apiMux *mux.Router, srp *sync.RequestPool, syncDB storage.Database, - deviceDB devices.Database, federation *gomatrixserverlib.FederationClient, - queryAPI api.RoomserverQueryAPI, + publicAPIMux *mux.Router, srp *sync.RequestPool, syncDB storage.Database, + userAPI userapi.UserInternalAPI, federation *gomatrixserverlib.FederationClient, + rsAPI api.RoomserverInternalAPI, cfg *config.Dendrite, ) { - r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() - - authData := auth.Data{ - AccountDB: nil, - DeviceDB: deviceDB, - AppServices: nil, - } + r0mux := publicAPIMux.PathPrefix(pathPrefixR0).Subrouter() // TODO: Add AS support for all handlers below. - r0mux.Handle("/sync", common.MakeAuthAPI("sync", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + r0mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return srp.OnIncomingSyncRequest(req, device) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/messages", common.MakeAuthAPI("room_messages", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) + r0mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], federation, queryAPI, cfg) + return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], federation, rsAPI, cfg) })).Methods(http.MethodGet, http.MethodOptions) } diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index a3efd8d58..7b3bd6785 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -18,36 +18,114 @@ import ( "context" "time" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/eduserver/cache" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" ) type Database interface { - common.PartitionStorer + internal.PartitionStorer + // AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) + // Events lookups a list of event by their event ID. + // Returns a list of events matching the requested IDs found in the database. + // If an event is not found in the database then it will be omitted from the list. + // Returns an error if there was a problem talking with the database. + // Does not include any transaction IDs in the returned events. Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.HeaderedEvent, error) - WriteEvent(context.Context, *gomatrixserverlib.HeaderedEvent, []gomatrixserverlib.HeaderedEvent, []string, []string, *api.TransactionID, bool) (types.StreamPosition, error) + // WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races + // when generating the sync stream position for this event. Returns the sync stream position for the inserted event. + // Returns an error if there was a problem inserting this event. + WriteEvent(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, addStateEvents []gomatrixserverlib.HeaderedEvent, + addStateEventIDs []string, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool) (types.StreamPosition, error) + // GetStateEvent returns the Matrix state event of a given type for a given room with a given state key + // If no event could be found, returns nil + // If there was an issue during the retrieval, returns an error GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) + // GetStateEventsForRoom fetches the state events for a given room. + // Returns an empty slice if no state events could be found for this room. + // Returns an error if there was an issue with the retrieval. GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) - SyncPosition(ctx context.Context) (types.PaginationToken, error) - IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.PaginationToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) - CompleteSync(ctx context.Context, userID string, numRecentEventsPerRoom int) (*types.Response, error) - GetAccountDataInRange(ctx context.Context, userID string, oldPos, newPos types.StreamPosition, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, error) + // SyncPosition returns the latest positions for syncing. + SyncPosition(ctx context.Context) (types.StreamingToken, error) + // IncrementalSync returns all the data needed in order to create an incremental + // sync response for the given user. Events returned will include any client + // transaction IDs associated with the given device. These transaction IDs come + // from when the device sent the event via an API that included a transaction + // ID. A response object must be provided for IncrementaSync to populate - it + // will not create one. + IncrementalSync(ctx context.Context, res *types.Response, device userapi.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) + // CompleteSync returns a complete /sync API response for the given user. A response object + // must be provided for CompleteSync to populate - it will not create one. + CompleteSync(ctx context.Context, res *types.Response, device userapi.Device, numRecentEventsPerRoom int) (*types.Response, error) + // GetAccountDataInRange returns all account data for a given user inserted or + // updated between two given positions + // Returns a map following the format data[roomID] = []dataTypes + // If no data is retrieved, returns an empty map + // If there was an issue with the retrieval, returns an error + GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, error) + // UpsertAccountData keeps track of new or updated account data, by saving the type + // of the new/updated data, and the user ID and room ID the data is related to (empty) + // room ID means the data isn't specific to any room) + // If no data with the given type, user ID and room ID exists in the database, + // creates a new row, else update the existing one + // Returns an error if there was an issue with the upsert UpsertAccountData(ctx context.Context, userID, roomID, dataType string) (types.StreamPosition, error) + // AddInviteEvent stores a new invite event for a user. + // If the invite was successfully stored this returns the stream ID it was stored at. + // Returns an error if there was a problem communicating with the database. AddInviteEvent(ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent) (types.StreamPosition, error) + // RetireInviteEvent removes an old invite event from the database. + // Returns an error if there was a problem communicating with the database. RetireInviteEvent(ctx context.Context, inviteEventID string) error + // SetTypingTimeoutCallback sets a callback function that is called right after + // a user is removed from the typing user list due to timeout. SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) + // AddTypingUser adds a typing user to the typing cache. + // Returns the newly calculated sync position for typing notifications. AddTypingUser(userID, roomID string, expireTime *time.Time) types.StreamPosition + // RemoveTypingUser removes a typing user from the typing cache. + // Returns the newly calculated sync position for typing notifications. RemoveTypingUser(userID, roomID string) types.StreamPosition - GetEventsInRange(ctx context.Context, from, to *types.PaginationToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) - EventPositionInTopology(ctx context.Context, eventID string) (types.StreamPosition, error) - EventsAtTopologicalPosition(ctx context.Context, roomID string, pos types.StreamPosition) ([]types.StreamEvent, error) - BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities []string, err error) - MaxTopologicalPosition(ctx context.Context, roomID string) (types.StreamPosition, error) - StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.HeaderedEvent + // GetEventsInStreamingRange retrieves all of the events on a given ordering using the given extremities and limit. + GetEventsInStreamingRange(ctx context.Context, from, to *types.StreamingToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) + // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. + GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) + // EventPositionInTopology returns the depth and stream position of the given event. + EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error) + // BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events. + BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities map[string][]string, err error) + // MaxTopologicalPosition returns the highest topological position for a given room. + MaxTopologicalPosition(ctx context.Context, roomID string) (types.TopologyToken, error) + // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and + // matches the streamevent.transactionID device then the transaction ID gets + // added to the unsigned section of the output event. + StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []gomatrixserverlib.HeaderedEvent + // SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) + // AddSendToDevice increases the EDU position in the cache and returns the stream position. + AddSendToDevice() types.StreamPosition + // SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns three lists: + // - "events": a list of send-to-device events that should be included in the sync + // - "changes": a list of send-to-device events that should be updated in the database by + // CleanSendToDeviceUpdates + // - "deletions": a list of send-to-device events which have been confirmed as sent and + // can be deleted altogether by CleanSendToDeviceUpdates + // The token supplied should be the current requested sync token, e.g. from the "since" + // parameter. + SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) (events []types.SendToDeviceEvent, changes []types.SendToDeviceNID, deletions []types.SendToDeviceNID, err error) + // StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device. + StoreNewSendForDeviceMessage(ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error) + // CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the + // result to a previous call to SendDeviceUpdatesForSync. This is separate as it allows + // SendToDeviceUpdatesForSync to be called multiple times if needed (e.g. before and after + // starting to wait for an incremental sync with timeout). + // The token supplied should be the current requested sync token, e.g. from the "since" + // parameter. + CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error) + // SendToDeviceUpdatesWaiting returns true if there are send-to-device updates waiting to be sent. + SendToDeviceUpdatesWaiting(ctx context.Context, userID, deviceID string) (bool, error) } diff --git a/syncapi/storage/postgres/account_data_table.go b/syncapi/storage/postgres/account_data_table.go index d1e3b527f..67eb1e863 100644 --- a/syncapi/storage/postgres/account_data_table.go +++ b/syncapi/storage/postgres/account_data_table.go @@ -20,7 +20,9 @@ import ( "database/sql" "github.com/lib/pq" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -70,47 +72,41 @@ type accountDataStatements struct { selectMaxAccountDataIDStmt *sql.Stmt } -func (s *accountDataStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(accountDataSchema) +func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountData, error) { + s := &accountDataStatements{} + _, err := db.Exec(accountDataSchema) if err != nil { - return + return nil, err } if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { - return + return nil, err } if s.selectAccountDataInRangeStmt, err = db.Prepare(selectAccountDataInRangeSQL); err != nil { - return + return nil, err } if s.selectMaxAccountDataIDStmt, err = db.Prepare(selectMaxAccountDataIDSQL); err != nil { - return + return nil, err } - return + return s, nil } -func (s *accountDataStatements) insertAccountData( - ctx context.Context, +func (s *accountDataStatements) InsertAccountData( + ctx context.Context, txn *sql.Tx, userID, roomID, dataType string, ) (pos types.StreamPosition, err error) { err = s.insertAccountDataStmt.QueryRowContext(ctx, userID, roomID, dataType).Scan(&pos) return } -func (s *accountDataStatements) selectAccountDataInRange( +func (s *accountDataStatements) SelectAccountDataInRange( ctx context.Context, userID string, - oldPos, newPos types.StreamPosition, + r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter, ) (data map[string][]string, err error) { data = make(map[string][]string) - // If both positions are the same, it means that the data was saved after the - // latest room event. In that case, we need to decrement the old position as - // it would prevent the SQL request from returning anything. - if oldPos == newPos { - oldPos-- - } - - rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, oldPos, newPos, + rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High(), pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.NotTypes)), accountDataEventFilter.Limit, @@ -118,7 +114,7 @@ func (s *accountDataStatements) selectAccountDataInRange( if err != nil { return } - defer common.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed") for rows.Next() { var dataType string @@ -137,11 +133,11 @@ func (s *accountDataStatements) selectAccountDataInRange( return data, rows.Err() } -func (s *accountDataStatements) selectMaxAccountDataID( +func (s *accountDataStatements) SelectMaxAccountDataID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 - stmt := common.TxStmt(txn, s.selectMaxAccountDataIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectMaxAccountDataIDStmt) err = stmt.QueryRowContext(ctx).Scan(&nullableID) if nullableID.Valid { id = nullableID.Int64 diff --git a/syncapi/storage/postgres/backward_extremities_table.go b/syncapi/storage/postgres/backwards_extremities_table.go similarity index 59% rename from syncapi/storage/postgres/backward_extremities_table.go rename to syncapi/storage/postgres/backwards_extremities_table.go index cb3629644..71569a108 100644 --- a/syncapi/storage/postgres/backward_extremities_table.go +++ b/syncapi/storage/postgres/backwards_extremities_table.go @@ -18,28 +18,10 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/syncapi/storage/tables" ) -// The purpose of this table is to keep track of backwards extremities for a room. -// Backwards extremities are the earliest (DAG-wise) known events which we have -// the entire event JSON. These event IDs are used in federation requests to fetch -// even earlier events. -// -// We persist the previous event IDs as well, one per row, so when we do fetch even -// earlier events we can simply delete rows which referenced it. Consider the graph: -// A -// | Event C has 1 prev_event ID: A. -// B C -// |___| Event D has 2 prev_event IDs: B and C. -// | -// D -// The earliest known event we have is D, so this table has 2 rows. -// A backfill request gives us C but not B. We delete rows where prev_event=C. This -// still means that D is a backwards extremity as we do not have event B. However, event -// C is *also* a backwards extremity at this point as we do not have event A. Later, -// when we fetch event B, we delete rows where prev_event=B. This then removes D as -// a backwards extremity because there are no more rows with event_id=B. const backwardExtremitiesSchema = ` -- Stores output room events received from the roomserver. CREATE TABLE IF NOT EXISTS syncapi_backward_extremities ( @@ -49,7 +31,6 @@ CREATE TABLE IF NOT EXISTS syncapi_backward_extremities ( event_id TEXT NOT NULL, -- The prev_events for the last known event. This is used to update extremities. prev_event_id TEXT NOT NULL, - PRIMARY KEY(room_id, event_id, prev_event_id) ); ` @@ -60,7 +41,7 @@ const insertBackwardExtremitySQL = "" + " ON CONFLICT DO NOTHING" const selectBackwardExtremitiesForRoomSQL = "" + - "SELECT DISTINCT event_id FROM syncapi_backward_extremities WHERE room_id = $1" + "SELECT event_id, prev_event_id FROM syncapi_backward_extremities WHERE room_id = $1" const deleteBackwardExtremitySQL = "" + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" @@ -71,52 +52,54 @@ type backwardExtremitiesStatements struct { deleteBackwardExtremityStmt *sql.Stmt } -func (s *backwardExtremitiesStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(backwardExtremitiesSchema) +func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { + s := &backwardExtremitiesStatements{} + _, err := db.Exec(backwardExtremitiesSchema) if err != nil { - return + return nil, err } if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil { - return + return nil, err } if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil { - return + return nil, err } if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { - return + return nil, err } - return + return s, nil } -func (s *backwardExtremitiesStatements) insertsBackwardExtremity( +func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string, ) (err error) { _, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) return } -func (s *backwardExtremitiesStatements) selectBackwardExtremitiesForRoom( +func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( ctx context.Context, roomID string, -) (eventIDs []string, err error) { +) (bwExtrems map[string][]string, err error) { rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID) if err != nil { return } - defer common.CloseAndLogIfError(ctx, rows, "selectBackwardExtremitiesForRoom: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectBackwardExtremitiesForRoom: rows.close() failed") + bwExtrems = make(map[string][]string) for rows.Next() { var eID string - if err = rows.Scan(&eID); err != nil { + var prevEventID string + if err = rows.Scan(&eID, &prevEventID); err != nil { return } - - eventIDs = append(eventIDs, eID) + bwExtrems[eID] = append(bwExtrems[eID], prevEventID) } - return eventIDs, rows.Err() + return bwExtrems, rows.Err() } -func (s *backwardExtremitiesStatements) deleteBackwardExtremity( +func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( ctx context.Context, txn *sql.Tx, roomID, knownEventID string, ) (err error) { _, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index ab8f07b21..25906edb4 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -21,7 +21,9 @@ import ( "encoding/json" "github.com/lib/pq" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -103,44 +105,45 @@ type currentRoomStateStatements struct { selectStateEventStmt *sql.Stmt } -func (s *currentRoomStateStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(currentRoomStateSchema) +func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { + s := ¤tRoomStateStatements{} + _, err := db.Exec(currentRoomStateSchema) if err != nil { - return + return nil, err } if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil { - return + return nil, err } if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { - return + return nil, err } if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { - return + return nil, err } if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil { - return + return nil, err } if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { - return + return nil, err } if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil { - return + return nil, err } if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { - return + return nil, err } - return + return s, nil } -// JoinedMemberLists returns a map of room ID to a list of joined user IDs. -func (s *currentRoomStateStatements) selectJoinedUsers( +// SelectJoinedUsers returns a map of room ID to a list of joined user IDs. +func (s *currentRoomStateStatements) SelectJoinedUsers( ctx context.Context, ) (map[string][]string, error) { rows, err := s.selectJoinedUsersStmt.QueryContext(ctx) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed") result := make(map[string][]string) for rows.Next() { @@ -157,18 +160,18 @@ func (s *currentRoomStateStatements) selectJoinedUsers( } // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. -func (s *currentRoomStateStatements) selectRoomIDsWithMembership( +func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( ctx context.Context, txn *sql.Tx, userID string, membership string, // nolint: unparam ) ([]string, error) { - stmt := common.TxStmt(txn, s.selectRoomIDsWithMembershipStmt) + stmt := sqlutil.TxStmt(txn, s.selectRoomIDsWithMembershipStmt) rows, err := stmt.QueryContext(ctx, userID, membership) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "selectRoomIDsWithMembership: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsWithMembership: rows.close() failed") var result []string for rows.Next() { @@ -181,12 +184,12 @@ func (s *currentRoomStateStatements) selectRoomIDsWithMembership( return result, rows.Err() } -// CurrentState returns all the current state events for the given room. -func (s *currentRoomStateStatements) selectCurrentState( +// SelectCurrentState returns all the current state events for the given room. +func (s *currentRoomStateStatements) SelectCurrentState( ctx context.Context, txn *sql.Tx, roomID string, stateFilter *gomatrixserverlib.StateFilter, ) ([]gomatrixserverlib.HeaderedEvent, error) { - stmt := common.TxStmt(txn, s.selectCurrentStateStmt) + stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt) rows, err := stmt.QueryContext(ctx, roomID, pq.StringArray(stateFilter.Senders), pq.StringArray(stateFilter.NotSenders), @@ -198,20 +201,20 @@ func (s *currentRoomStateStatements) selectCurrentState( if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "selectCurrentState: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectCurrentState: rows.close() failed") return rowsToEvents(rows) } -func (s *currentRoomStateStatements) deleteRoomStateByEventID( +func (s *currentRoomStateStatements) DeleteRoomStateByEventID( ctx context.Context, txn *sql.Tx, eventID string, ) error { - stmt := common.TxStmt(txn, s.deleteRoomStateByEventIDStmt) + stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) _, err := stmt.ExecContext(ctx, eventID) return err } -func (s *currentRoomStateStatements) upsertRoomState( +func (s *currentRoomStateStatements) UpsertRoomState( ctx context.Context, txn *sql.Tx, event gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition, ) error { @@ -229,7 +232,7 @@ func (s *currentRoomStateStatements) upsertRoomState( } // upsert state event - stmt := common.TxStmt(txn, s.upsertRoomStateStmt) + stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) _, err = stmt.ExecContext( ctx, event.RoomID(), @@ -245,15 +248,15 @@ func (s *currentRoomStateStatements) upsertRoomState( return err } -func (s *currentRoomStateStatements) selectEventsWithEventIDs( +func (s *currentRoomStateStatements) SelectEventsWithEventIDs( ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StreamEvent, error) { - stmt := common.TxStmt(txn, s.selectEventsWithEventIDsStmt) + stmt := sqlutil.TxStmt(txn, s.selectEventsWithEventIDsStmt) rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "selectEventsWithEventIDs: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectEventsWithEventIDs: rows.close() failed") return rowsToStreamEvents(rows) } @@ -274,7 +277,7 @@ func rowsToEvents(rows *sql.Rows) ([]gomatrixserverlib.HeaderedEvent, error) { return result, rows.Err() } -func (s *currentRoomStateStatements) selectStateEvent( +func (s *currentRoomStateStatements) SelectStateEvent( ctx context.Context, roomID, evType, stateKey string, ) (*gomatrixserverlib.HeaderedEvent, error) { stmt := s.selectStateEventStmt diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go index ca0c64fb9..5031d64e5 100644 --- a/syncapi/storage/postgres/invites_table.go +++ b/syncapi/storage/postgres/invites_table.go @@ -20,7 +20,9 @@ import ( "database/sql" "encoding/json" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -66,28 +68,29 @@ type inviteEventsStatements struct { selectMaxInviteIDStmt *sql.Stmt } -func (s *inviteEventsStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(inviteEventsSchema) +func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) { + s := &inviteEventsStatements{} + _, err := db.Exec(inviteEventsSchema) if err != nil { - return + return nil, err } if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil { - return + return nil, err } if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil { - return + return nil, err } if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil { - return + return nil, err } if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil { - return + return nil, err } - return + return s, nil } -func (s *inviteEventsStatements) insertInviteEvent( - ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, +func (s *inviteEventsStatements) InsertInviteEvent( + ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent, ) (streamPos types.StreamPosition, err error) { var headeredJSON []byte headeredJSON, err = json.Marshal(inviteEvent) @@ -105,7 +108,7 @@ func (s *inviteEventsStatements) insertInviteEvent( return } -func (s *inviteEventsStatements) deleteInviteEvent( +func (s *inviteEventsStatements) DeleteInviteEvent( ctx context.Context, inviteEventID string, ) error { _, err := s.deleteInviteEventStmt.ExecContext(ctx, inviteEventID) @@ -114,15 +117,15 @@ func (s *inviteEventsStatements) deleteInviteEvent( // selectInviteEventsInRange returns a map of room ID to invite event for the // active invites for the target user ID in the supplied range. -func (s *inviteEventsStatements) selectInviteEventsInRange( - ctx context.Context, txn *sql.Tx, targetUserID string, startPos, endPos types.StreamPosition, +func (s *inviteEventsStatements) SelectInviteEventsInRange( + ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range, ) (map[string]gomatrixserverlib.HeaderedEvent, error) { - stmt := common.TxStmt(txn, s.selectInviteEventsInRangeStmt) - rows, err := stmt.QueryContext(ctx, targetUserID, startPos, endPos) + stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt) + rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High()) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed") result := map[string]gomatrixserverlib.HeaderedEvent{} for rows.Next() { var ( @@ -143,11 +146,11 @@ func (s *inviteEventsStatements) selectInviteEventsInRange( return result, rows.Err() } -func (s *inviteEventsStatements) selectMaxInviteID( +func (s *inviteEventsStatements) SelectMaxInviteID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 - stmt := common.TxStmt(txn, s.selectMaxInviteIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectMaxInviteIDStmt) err = stmt.QueryRowContext(ctx).Scan(&nullableID) if nullableID.Valid { id = nullableID.Int64 diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 0b53dfa9e..f01b2eabd 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -21,11 +21,13 @@ import ( "encoding/json" "sort" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/lib/pq" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -120,46 +122,47 @@ type outputRoomEventsStatements struct { selectStateInRangeStmt *sql.Stmt } -func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(outputRoomEventsSchema) +func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { + s := &outputRoomEventsStatements{} + _, err := db.Exec(outputRoomEventsSchema) if err != nil { - return + return nil, err } if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil { - return + return nil, err } if s.selectEventsStmt, err = db.Prepare(selectEventsSQL); err != nil { - return + return nil, err } if s.selectMaxEventIDStmt, err = db.Prepare(selectMaxEventIDSQL); err != nil { - return + return nil, err } if s.selectRecentEventsStmt, err = db.Prepare(selectRecentEventsSQL); err != nil { - return + return nil, err } if s.selectRecentEventsForSyncStmt, err = db.Prepare(selectRecentEventsForSyncSQL); err != nil { - return + return nil, err } if s.selectEarlyEventsStmt, err = db.Prepare(selectEarlyEventsSQL); err != nil { - return + return nil, err } if s.selectStateInRangeStmt, err = db.Prepare(selectStateInRangeSQL); err != nil { - return + return nil, err } - return + return s, nil } // selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos. // Results are bucketed based on the room ID. If the same state is overwritten multiple times between the // two positions, only the most recent state is returned. -func (s *outputRoomEventsStatements) selectStateInRange( - ctx context.Context, txn *sql.Tx, oldPos, newPos types.StreamPosition, +func (s *outputRoomEventsStatements) SelectStateInRange( + ctx context.Context, txn *sql.Tx, r types.Range, stateFilter *gomatrixserverlib.StateFilter, ) (map[string]map[string]bool, map[string]types.StreamEvent, error) { - stmt := common.TxStmt(txn, s.selectStateInRangeStmt) + stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt) rows, err := stmt.QueryContext( - ctx, oldPos, newPos, + ctx, r.Low(), r.High(), pq.StringArray(stateFilter.Senders), pq.StringArray(stateFilter.NotSenders), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), @@ -170,7 +173,7 @@ func (s *outputRoomEventsStatements) selectStateInRange( if err != nil { return nil, nil, err } - defer common.CloseAndLogIfError(ctx, rows, "selectStateInRange: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectStateInRange: rows.close() failed") // Fetch all the state change events for all rooms between the two positions then loop each event and: // - Keep a cache of the event by ID (99% of state change events are for the event itself) // - For each room ID, build up an array of event IDs which represents cumulative adds/removes @@ -196,8 +199,8 @@ func (s *outputRoomEventsStatements) selectStateInRange( // since it'll just mark the event as not being needed. if len(addIDs) < len(delIDs) { log.WithFields(log.Fields{ - "since": oldPos, - "current": newPos, + "since": r.From, + "current": r.To, "adds": addIDs, "dels": delIDs, }).Warn("StateBetween: ignoring deleted state") @@ -233,11 +236,11 @@ func (s *outputRoomEventsStatements) selectStateInRange( // MaxID returns the ID of the last inserted event in this table. 'txn' is optional. If it is not supplied, // then this function should only ever be used at startup, as it will race with inserting events if it is // done afterwards. If there are no inserted events, 0 is returned. -func (s *outputRoomEventsStatements) selectMaxEventID( +func (s *outputRoomEventsStatements) SelectMaxEventID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 - stmt := common.TxStmt(txn, s.selectMaxEventIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt) err = stmt.QueryRowContext(ctx).Scan(&nullableID) if nullableID.Valid { id = nullableID.Int64 @@ -247,7 +250,7 @@ func (s *outputRoomEventsStatements) selectMaxEventID( // InsertEvent into the output_room_events table. addState and removeState are an optional list of state event IDs. Returns the position // of the inserted event. -func (s *outputRoomEventsStatements) insertEvent( +func (s *outputRoomEventsStatements) InsertEvent( ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, transactionID *api.TransactionID, excludeFromSync bool, @@ -273,7 +276,7 @@ func (s *outputRoomEventsStatements) insertEvent( return } - stmt := common.TxStmt(txn, s.insertEventStmt) + stmt := sqlutil.TxStmt(txn, s.insertEventStmt) err = stmt.QueryRowContext( ctx, event.RoomID(), @@ -294,22 +297,22 @@ func (s *outputRoomEventsStatements) insertEvent( // selectRecentEvents returns the most recent events in the given room, up to a maximum of 'limit'. // If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude // from sync. -func (s *outputRoomEventsStatements) selectRecentEvents( +func (s *outputRoomEventsStatements) SelectRecentEvents( ctx context.Context, txn *sql.Tx, - roomID string, fromPos, toPos types.StreamPosition, limit int, + roomID string, r types.Range, limit int, chronologicalOrder bool, onlySyncEvents bool, ) ([]types.StreamEvent, error) { var stmt *sql.Stmt if onlySyncEvents { - stmt = common.TxStmt(txn, s.selectRecentEventsForSyncStmt) + stmt = sqlutil.TxStmt(txn, s.selectRecentEventsForSyncStmt) } else { - stmt = common.TxStmt(txn, s.selectRecentEventsStmt) + stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt) } - rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) + rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed") events, err := rowsToStreamEvents(rows) if err != nil { return nil, err @@ -327,16 +330,16 @@ func (s *outputRoomEventsStatements) selectRecentEvents( // selectEarlyEvents returns the earliest events in the given room, starting // from a given position, up to a maximum of 'limit'. -func (s *outputRoomEventsStatements) selectEarlyEvents( +func (s *outputRoomEventsStatements) SelectEarlyEvents( ctx context.Context, txn *sql.Tx, - roomID string, fromPos, toPos types.StreamPosition, limit int, + roomID string, r types.Range, limit int, ) ([]types.StreamEvent, error) { - stmt := common.TxStmt(txn, s.selectEarlyEventsStmt) - rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) + stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt) + rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "selectEarlyEvents: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectEarlyEvents: rows.close() failed") events, err := rowsToStreamEvents(rows) if err != nil { return nil, err @@ -352,15 +355,15 @@ func (s *outputRoomEventsStatements) selectEarlyEvents( // selectEvents returns the events for the given event IDs. If an event is // missing from the database, it will be omitted. -func (s *outputRoomEventsStatements) selectEvents( +func (s *outputRoomEventsStatements) SelectEvents( ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StreamEvent, error) { - stmt := common.TxStmt(txn, s.selectEventsStmt) + stmt := sqlutil.TxStmt(txn, s.selectEventsStmt) rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed") return rowsToStreamEvents(rows) } diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index 280d4ec39..1ab3a1dc2 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -18,8 +18,8 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" - + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -32,39 +32,44 @@ CREATE TABLE IF NOT EXISTS syncapi_output_room_events_topology ( -- The place of the event in the room's topology. This can usually be determined -- from the event's depth. topological_position BIGINT NOT NULL, + stream_position BIGINT NOT NULL, -- The 'room_id' key for the event. room_id TEXT NOT NULL ); -- The topological order will be used in events selection and ordering -CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, room_id); +CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, stream_position, room_id); ` const insertEventInTopologySQL = "" + - "INSERT INTO syncapi_output_room_events_topology (event_id, topological_position, room_id)" + - " VALUES ($1, $2, $3)" + - " ON CONFLICT (topological_position, room_id) DO UPDATE SET event_id = $1" + "INSERT INTO syncapi_output_room_events_topology (event_id, topological_position, room_id, stream_position)" + + " VALUES ($1, $2, $3, $4)" + + " ON CONFLICT (topological_position, stream_position, room_id) DO UPDATE SET event_id = $1" const selectEventIDsInRangeASCSQL = "" + "SELECT event_id FROM syncapi_output_room_events_topology" + - " WHERE room_id = $1 AND topological_position > $2 AND topological_position <= $3" + - " ORDER BY topological_position ASC LIMIT $4" + " WHERE room_id = $1 AND (" + + "(topological_position > $2 AND topological_position < $3) OR" + + "(topological_position = $4 AND stream_position <= $5)" + + ") ORDER BY topological_position ASC, stream_position ASC LIMIT $6" const selectEventIDsInRangeDESCSQL = "" + "SELECT event_id FROM syncapi_output_room_events_topology" + - " WHERE room_id = $1 AND topological_position > $2 AND topological_position <= $3" + - " ORDER BY topological_position DESC LIMIT $4" + " WHERE room_id = $1 AND (" + + "(topological_position > $2 AND topological_position < $3) OR" + + "(topological_position = $4 AND stream_position <= $5)" + + ") ORDER BY topological_position DESC, stream_position DESC LIMIT $6" const selectPositionInTopologySQL = "" + - "SELECT topological_position FROM syncapi_output_room_events_topology" + + "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + " WHERE event_id = $1" + // Select the max topological position for the room, then sort by stream position and take the highest, + // returning both topological and stream positions. const selectMaxPositionInTopologySQL = "" + - "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology" + - " WHERE room_id = $1" - -const selectEventIDsFromPositionSQL = "" + - "SELECT event_id FROM syncapi_output_room_events_topology" + - " WHERE room_id = $1 AND topological_position = $2" + "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + + " WHERE topological_position=(" + + "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" + + ") ORDER BY stream_position DESC LIMIT 1" type outputRoomEventsTopologyStatements struct { insertEventInTopologyStmt *sql.Stmt @@ -72,51 +77,48 @@ type outputRoomEventsTopologyStatements struct { selectEventIDsInRangeDESCStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt selectMaxPositionInTopologyStmt *sql.Stmt - selectEventIDsFromPositionStmt *sql.Stmt } -func (s *outputRoomEventsTopologyStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(outputRoomEventsTopologySchema) +func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { + s := &outputRoomEventsTopologyStatements{} + _, err := db.Exec(outputRoomEventsTopologySchema) if err != nil { - return + return nil, err } if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil { - return + return nil, err } if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil { - return + return nil, err } if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil { - return + return nil, err } if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { - return + return nil, err } if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { - return + return nil, err } - if s.selectEventIDsFromPositionStmt, err = db.Prepare(selectEventIDsFromPositionSQL); err != nil { - return - } - return + return s, nil } -// insertEventInTopology inserts the given event in the room's topology, based +// InsertEventInTopology inserts the given event in the room's topology, based // on the event's depth. -func (s *outputRoomEventsTopologyStatements) insertEventInTopology( - ctx context.Context, event *gomatrixserverlib.HeaderedEvent, +func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( + ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition, ) (err error) { _, err = s.insertEventInTopologyStmt.ExecContext( - ctx, event.EventID(), event.Depth(), event.RoomID(), + ctx, event.EventID(), event.Depth(), event.RoomID(), pos, ) return } -// selectEventIDsInRange selects the IDs of events which positions are within a +// SelectEventIDsInRange selects the IDs of events which positions are within a // given range in a given room's topological order. // Returns an empty slice if no events match the given range. -func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange( - ctx context.Context, roomID string, fromPos, toPos types.StreamPosition, +func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( + ctx context.Context, txn *sql.Tx, roomID string, minDepth, maxDepth, maxStreamPos types.StreamPosition, limit int, chronologicalOrder bool, ) (eventIDs []string, err error) { // Decide on the selection's order according to whether chronological order @@ -129,14 +131,14 @@ func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange( } // Query the event IDs. - rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) + rows, err := stmt.QueryContext(ctx, roomID, minDepth, maxDepth, maxDepth, maxStreamPos, limit) if err == sql.ErrNoRows { // If no event matched the request, return an empty slice. return []string{}, nil } else if err != nil { return } - defer common.CloseAndLogIfError(ctx, rows, "selectEventIDsInRange: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectEventIDsInRange: rows.close() failed") // Return the IDs. var eventID string @@ -150,43 +152,18 @@ func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange( return eventIDs, rows.Err() } -// selectPositionInTopology returns the position of a given event in the +// SelectPositionInTopology returns the position of a given event in the // topology of the room it belongs to. -func (s *outputRoomEventsTopologyStatements) selectPositionInTopology( - ctx context.Context, eventID string, -) (pos types.StreamPosition, err error) { - err = s.selectPositionInTopologyStmt.QueryRowContext(ctx, eventID).Scan(&pos) +func (s *outputRoomEventsTopologyStatements) SelectPositionInTopology( + ctx context.Context, txn *sql.Tx, eventID string, +) (pos, spos types.StreamPosition, err error) { + err = s.selectPositionInTopologyStmt.QueryRowContext(ctx, eventID).Scan(&pos, &spos) return } -func (s *outputRoomEventsTopologyStatements) selectMaxPositionInTopology( - ctx context.Context, roomID string, -) (pos types.StreamPosition, err error) { - err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos) +func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( + ctx context.Context, txn *sql.Tx, roomID string, +) (pos types.StreamPosition, spos types.StreamPosition, err error) { + err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) return } - -// selectEventIDsFromPosition returns the IDs of all events that have a given -// position in the topology of a given room. -func (s *outputRoomEventsTopologyStatements) selectEventIDsFromPosition( - ctx context.Context, roomID string, pos types.StreamPosition, -) (eventIDs []string, err error) { - // Query the event IDs. - rows, err := s.selectEventIDsFromPositionStmt.QueryContext(ctx, roomID, pos) - if err == sql.ErrNoRows { - // If no event matched the request, return an empty slice. - return []string{}, nil - } else if err != nil { - return - } - defer common.CloseAndLogIfError(ctx, rows, "selectEventIDsFromPosition: rows.close() failed") - // Return the IDs. - var eventID string - for rows.Next() { - if err = rows.Scan(&eventID); err != nil { - return - } - eventIDs = append(eventIDs, eventID) - } - return eventIDs, rows.Err() -} diff --git a/syncapi/storage/postgres/send_to_device_table.go b/syncapi/storage/postgres/send_to_device_table.go new file mode 100644 index 000000000..07af9ad6b --- /dev/null +++ b/syncapi/storage/postgres/send_to_device_table.go @@ -0,0 +1,172 @@ +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" +) + +const sendToDeviceSchema = ` +CREATE SEQUENCE IF NOT EXISTS syncapi_send_to_device_id; + +-- Stores send-to-device messages. +CREATE TABLE IF NOT EXISTS syncapi_send_to_device ( + -- The ID that uniquely identifies this message. + id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_send_to_device_id'), + -- The user ID to send the message to. + user_id TEXT NOT NULL, + -- The device ID to send the message to. + device_id TEXT NOT NULL, + -- The event content JSON. + content TEXT NOT NULL, + -- The token that was supplied to the /sync at the time that this + -- message was included in a sync response, or NULL if we haven't + -- included it in a /sync response yet. + sent_by_token TEXT +); +` + +const insertSendToDeviceMessageSQL = ` + INSERT INTO syncapi_send_to_device (user_id, device_id, content) + VALUES ($1, $2, $3) +` + +const countSendToDeviceMessagesSQL = ` + SELECT COUNT(*) + FROM syncapi_send_to_device + WHERE user_id = $1 AND device_id = $2 +` + +const selectSendToDeviceMessagesSQL = ` + SELECT id, user_id, device_id, content, sent_by_token + FROM syncapi_send_to_device + WHERE user_id = $1 AND device_id = $2 + ORDER BY id DESC +` + +const updateSentSendToDeviceMessagesSQL = ` + UPDATE syncapi_send_to_device SET sent_by_token = $1 + WHERE id = ANY($2) +` + +const deleteSendToDeviceMessagesSQL = ` + DELETE FROM syncapi_send_to_device WHERE id = ANY($1) +` + +type sendToDeviceStatements struct { + insertSendToDeviceMessageStmt *sql.Stmt + countSendToDeviceMessagesStmt *sql.Stmt + selectSendToDeviceMessagesStmt *sql.Stmt + updateSentSendToDeviceMessagesStmt *sql.Stmt + deleteSendToDeviceMessagesStmt *sql.Stmt +} + +func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { + s := &sendToDeviceStatements{} + _, err := db.Exec(sendToDeviceSchema) + if err != nil { + return nil, err + } + if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { + return nil, err + } + if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil { + return nil, err + } + if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil { + return nil, err + } + if s.updateSentSendToDeviceMessagesStmt, err = db.Prepare(updateSentSendToDeviceMessagesSQL); err != nil { + return nil, err + } + if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *sendToDeviceStatements) InsertSendToDeviceMessage( + ctx context.Context, txn *sql.Tx, userID, deviceID, content string, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) + return +} + +func (s *sendToDeviceStatements) CountSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, userID, deviceID string, +) (count int, err error) { + row := sqlutil.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID) + if err = row.Scan(&count); err != nil { + return + } + return count, nil +} + +func (s *sendToDeviceStatements) SelectSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, userID, deviceID string, +) (events []types.SendToDeviceEvent, err error) { + rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed") + + for rows.Next() { + var id types.SendToDeviceNID + var userID, deviceID, content string + var sentByToken *string + if err = rows.Scan(&id, &userID, &deviceID, &content, &sentByToken); err != nil { + return + } + event := types.SendToDeviceEvent{ + ID: id, + UserID: userID, + DeviceID: deviceID, + } + if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil { + return + } + if sentByToken != nil { + if token, err := types.NewStreamTokenFromString(*sentByToken); err == nil { + event.SentByToken = &token + } + } + events = append(events, event) + } + + return events, rows.Err() +} + +func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID, +) (err error) { + _, err = txn.Stmt(s.updateSentSendToDeviceMessagesStmt).ExecContext(ctx, token, pq.Array(nids)) + return +} + +func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID, +) (err error) { + _, err = txn.Stmt(s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, pq.Array(nids)) + return +} diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 7fd75f066..573586cc7 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -16,1122 +16,72 @@ package postgres import ( - "context" "database/sql" - "encoding/json" - "fmt" - "time" - - "github.com/sirupsen/logrus" - - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/api" // Import the postgres database driver. _ "github.com/lib/pq" - "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/eduserver/cache" - "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/shared" ) -type stateDelta struct { - roomID string - stateEvents []gomatrixserverlib.HeaderedEvent - membership string - // The PDU stream position of the latest membership event for this user, if applicable. - // Can be 0 if there is no membership event in this delta. - membershipPos types.StreamPosition -} - // SyncServerDatasource represents a sync server datasource which manages // both the database for PDUs and caches for EDUs. type SyncServerDatasource struct { + shared.Database db *sql.DB - common.PartitionOffsetStatements - accountData accountDataStatements - events outputRoomEventsStatements - roomstate currentRoomStateStatements - invites inviteEventsStatements - eduCache *cache.EDUCache - topology outputRoomEventsTopologyStatements - backwardExtremities backwardExtremitiesStatements + sqlutil.PartitionOffsetStatements } -// NewSyncServerDatasource creates a new sync server database -func NewSyncServerDatasource(dbDataSourceName string) (*SyncServerDatasource, error) { +// NewDatabase creates a new sync server database +func NewDatabase(dbDataSourceName string, dbProperties sqlutil.DbProperties) (*SyncServerDatasource, error) { var d SyncServerDatasource var err error - if d.db, err = sqlutil.Open("postgres", dbDataSourceName); err != nil { + if d.db, err = sqlutil.Open("postgres", dbDataSourceName, dbProperties); err != nil { return nil, err } if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil { return nil, err } - if err = d.accountData.prepare(d.db); err != nil { + accountData, err := NewPostgresAccountDataTable(d.db) + if err != nil { return nil, err } - if err = d.events.prepare(d.db); err != nil { + events, err := NewPostgresEventsTable(d.db) + if err != nil { return nil, err } - if err := d.roomstate.prepare(d.db); err != nil { + currState, err := NewPostgresCurrentRoomStateTable(d.db) + if err != nil { return nil, err } - if err := d.invites.prepare(d.db); err != nil { + invites, err := NewPostgresInvitesTable(d.db) + if err != nil { return nil, err } - if err := d.topology.prepare(d.db); err != nil { + topology, err := NewPostgresTopologyTable(d.db) + if err != nil { return nil, err } - if err := d.backwardExtremities.prepare(d.db); err != nil { + backwardExtremities, err := NewPostgresBackwardsExtremitiesTable(d.db) + if err != nil { return nil, err } - d.eduCache = cache.New() + sendToDevice, err := NewPostgresSendToDeviceTable(d.db) + if err != nil { + return nil, err + } + d.Database = shared.Database{ + DB: d.db, + Invites: invites, + AccountData: accountData, + OutputEvents: events, + Topology: topology, + CurrentRoomState: currState, + BackwardExtremities: backwardExtremities, + SendToDevice: sendToDevice, + SendToDeviceWriter: sqlutil.NewTransactionWriter(), + EDUCache: cache.New(), + } return &d, nil } - -// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. -func (d *SyncServerDatasource) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { - return d.roomstate.selectJoinedUsers(ctx) -} - -// Events lookups a list of event by their event ID. -// Returns a list of events matching the requested IDs found in the database. -// If an event is not found in the database then it will be omitted from the list. -// Returns an error if there was a problem talking with the database. -// Does not include any transaction IDs in the returned events. -func (d *SyncServerDatasource) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.HeaderedEvent, error) { - streamEvents, err := d.events.selectEvents(ctx, nil, eventIDs) - if err != nil { - return nil, err - } - - // We don't include a device here as we only include transaction IDs in - // incremental syncs. - return d.StreamEventsToEvents(nil, streamEvents), nil -} - -// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of -// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table -// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. -func (d *SyncServerDatasource) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error { - if err := d.backwardExtremities.deleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil { - return err - } - - // Check if we have all of the event's previous events. If an event is - // missing, add it to the room's backward extremities. - prevEvents, err := d.events.selectEvents(ctx, txn, ev.PrevEventIDs()) - if err != nil { - return err - } - var found bool - for _, eID := range ev.PrevEventIDs() { - found = false - for _, prevEv := range prevEvents { - if eID == prevEv.EventID() { - found = true - } - } - - // If the event is missing, consider it a backward extremity. - if !found { - if err = d.backwardExtremities.insertsBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID(), eID); err != nil { - return err - } - } - } - - return nil -} - -// WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races -// when generating the sync stream position for this event. Returns the sync stream position for the inserted event. -// Returns an error if there was a problem inserting this event. -func (d *SyncServerDatasource) WriteEvent( - ctx context.Context, - ev *gomatrixserverlib.HeaderedEvent, - addStateEvents []gomatrixserverlib.HeaderedEvent, - addStateEventIDs, removeStateEventIDs []string, - transactionID *api.TransactionID, excludeFromSync bool, -) (pduPosition types.StreamPosition, returnErr error) { - returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { - var err error - pos, err := d.events.insertEvent( - ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, - ) - if err != nil { - return err - } - pduPosition = pos - - if err = d.topology.insertEventInTopology(ctx, ev); err != nil { - return err - } - - if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil { - return err - } - - if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { - // Nothing to do, the event may have just been a message event. - return nil - } - - return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition) - }) - - return pduPosition, returnErr -} - -func (d *SyncServerDatasource) updateRoomState( - ctx context.Context, txn *sql.Tx, - removedEventIDs []string, - addedEvents []gomatrixserverlib.HeaderedEvent, - pduPosition types.StreamPosition, -) error { - // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. - for _, eventID := range removedEventIDs { - if err := d.roomstate.deleteRoomStateByEventID(ctx, txn, eventID); err != nil { - return err - } - } - - for _, event := range addedEvents { - if event.StateKey() == nil { - // ignore non state events - continue - } - var membership *string - if event.Type() == "m.room.member" { - value, err := event.Membership() - if err != nil { - return err - } - membership = &value - } - if err := d.roomstate.upsertRoomState(ctx, txn, event, membership, pduPosition); err != nil { - return err - } - } - - return nil -} - -// GetStateEvent returns the Matrix state event of a given type for a given room with a given state key -// If no event could be found, returns nil -// If there was an issue during the retrieval, returns an error -func (d *SyncServerDatasource) GetStateEvent( - ctx context.Context, roomID, evType, stateKey string, -) (*gomatrixserverlib.HeaderedEvent, error) { - return d.roomstate.selectStateEvent(ctx, roomID, evType, stateKey) -} - -// GetStateEventsForRoom fetches the state events for a given room. -// Returns an empty slice if no state events could be found for this room. -// Returns an error if there was an issue with the retrieval. -func (d *SyncServerDatasource) GetStateEventsForRoom( - ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter, -) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilter) - return err - }) - return -} - -// GetEventsInRange retrieves all of the events on a given ordering using the -// given extremities and limit. -func (d *SyncServerDatasource) GetEventsInRange( - ctx context.Context, - from, to *types.PaginationToken, - roomID string, limit int, - backwardOrdering bool, -) (events []types.StreamEvent, err error) { - // If the pagination token's type is types.PaginationTokenTypeTopology, the - // events must be retrieved from the rooms' topology table rather than the - // table contaning the syncapi server's whole stream of events. - if from.Type == types.PaginationTokenTypeTopology { - // Determine the backward and forward limit, i.e. the upper and lower - // limits to the selection in the room's topology, from the direction. - var backwardLimit, forwardLimit types.StreamPosition - if backwardOrdering { - // Backward ordering is antichronological (latest event to oldest - // one). - backwardLimit = to.PDUPosition - forwardLimit = from.PDUPosition - } else { - // Forward ordering is chronological (oldest event to latest one). - backwardLimit = from.PDUPosition - forwardLimit = to.PDUPosition - } - - // Select the event IDs from the defined range. - var eIDs []string - eIDs, err = d.topology.selectEventIDsInRange( - ctx, roomID, backwardLimit, forwardLimit, limit, !backwardOrdering, - ) - if err != nil { - return - } - - // Retrieve the events' contents using their IDs. - events, err = d.events.selectEvents(ctx, nil, eIDs) - return - } - - // If the pagination token's type is types.PaginationTokenTypeStream, the - // events must be retrieved from the table contaning the syncapi server's - // whole stream of events. - - if backwardOrdering { - // When using backward ordering, we want the most recent events first. - if events, err = d.events.selectRecentEvents( - ctx, nil, roomID, to.PDUPosition, from.PDUPosition, limit, false, false, - ); err != nil { - return - } - } else { - // When using forward ordering, we want the least recent events first. - if events, err = d.events.selectEarlyEvents( - ctx, nil, roomID, from.PDUPosition, to.PDUPosition, limit, - ); err != nil { - return - } - } - - return -} - -// SyncPosition returns the latest positions for syncing. -func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.PaginationToken, error) { - return d.syncPositionTx(ctx, nil) -} - -// BackwardExtremitiesForRoom returns the event IDs of all of the backward -// extremities we know of for a given room. -func (d *SyncServerDatasource) BackwardExtremitiesForRoom( - ctx context.Context, roomID string, -) (backwardExtremities []string, err error) { - return d.backwardExtremities.selectBackwardExtremitiesForRoom(ctx, roomID) -} - -// MaxTopologicalPosition returns the highest topological position for a given -// room. -func (d *SyncServerDatasource) MaxTopologicalPosition( - ctx context.Context, roomID string, -) (types.StreamPosition, error) { - return d.topology.selectMaxPositionInTopology(ctx, roomID) -} - -// EventsAtTopologicalPosition returns all of the events matching a given -// position in the topology of a given room. -func (d *SyncServerDatasource) EventsAtTopologicalPosition( - ctx context.Context, roomID string, pos types.StreamPosition, -) ([]types.StreamEvent, error) { - eIDs, err := d.topology.selectEventIDsFromPosition(ctx, roomID, pos) - if err != nil { - return nil, err - } - - return d.events.selectEvents(ctx, nil, eIDs) -} - -func (d *SyncServerDatasource) EventPositionInTopology( - ctx context.Context, eventID string, -) (types.StreamPosition, error) { - return d.topology.selectPositionInTopology(ctx, eventID) -} - -// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. -func (d *SyncServerDatasource) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) { - return d.syncStreamPositionTx(ctx, nil) -} - -func (d *SyncServerDatasource) syncStreamPositionTx( - ctx context.Context, txn *sql.Tx, -) (types.StreamPosition, error) { - maxID, err := d.events.selectMaxEventID(ctx, txn) - if err != nil { - return 0, err - } - maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) - if err != nil { - return 0, err - } - if maxAccountDataID > maxID { - maxID = maxAccountDataID - } - maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn) - if err != nil { - return 0, err - } - if maxInviteID > maxID { - maxID = maxInviteID - } - return types.StreamPosition(maxID), nil -} - -func (d *SyncServerDatasource) syncPositionTx( - ctx context.Context, txn *sql.Tx, -) (sp types.PaginationToken, err error) { - - maxEventID, err := d.events.selectMaxEventID(ctx, txn) - if err != nil { - return sp, err - } - maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) - if err != nil { - return sp, err - } - if maxAccountDataID > maxEventID { - maxEventID = maxAccountDataID - } - maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn) - if err != nil { - return sp, err - } - if maxInviteID > maxEventID { - maxEventID = maxInviteID - } - sp.PDUPosition = types.StreamPosition(maxEventID) - sp.EDUTypingPosition = types.StreamPosition(d.eduCache.GetLatestSyncPosition()) - return -} - -// addPDUDeltaToResponse adds all PDU deltas to a sync response. -// IDs of all rooms the user joined are returned so EDU deltas can be added for them. -func (d *SyncServerDatasource) addPDUDeltaToResponse( - ctx context.Context, - device authtypes.Device, - fromPos, toPos types.StreamPosition, - numRecentEventsPerRoom int, - wantFullState bool, - res *types.Response, -) (joinedRoomIDs []string, err error) { - txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot) - if err != nil { - return nil, err - } - var succeeded bool - defer func() { - txerr := common.EndTransaction(txn, &succeeded) - if err == nil && txerr != nil { - err = txerr - } - }() - - stateFilter := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request - - // Work out which rooms to return in the response. This is done by getting not only the currently - // joined rooms, but also which rooms have membership transitions for this user between the 2 PDU stream positions. - // This works out what the 'state' key should be for each room as well as which membership block - // to put the room into. - var deltas []stateDelta - if !wantFullState { - deltas, joinedRoomIDs, err = d.getStateDeltas( - ctx, &device, txn, fromPos, toPos, device.UserID, &stateFilter, - ) - } else { - deltas, joinedRoomIDs, err = d.getStateDeltasForFullStateSync( - ctx, &device, txn, fromPos, toPos, device.UserID, &stateFilter, - ) - } - if err != nil { - return nil, err - } - - for _, delta := range deltas { - err = d.addRoomDeltaToResponse(ctx, &device, txn, fromPos, toPos, delta, numRecentEventsPerRoom, res) - if err != nil { - return nil, err - } - } - - // TODO: This should be done in getStateDeltas - if err = d.addInvitesToResponse(ctx, txn, device.UserID, fromPos, toPos, res); err != nil { - return nil, err - } - - succeeded = true - return joinedRoomIDs, nil -} - -// addTypingDeltaToResponse adds all typing notifications to a sync response -// since the specified position. -func (d *SyncServerDatasource) addTypingDeltaToResponse( - since types.PaginationToken, - joinedRoomIDs []string, - res *types.Response, -) error { - var jr types.JoinResponse - var ok bool - var err error - for _, roomID := range joinedRoomIDs { - if typingUsers, updated := d.eduCache.GetTypingUsersIfUpdatedAfter( - roomID, int64(since.EDUTypingPosition), - ); updated { - ev := gomatrixserverlib.ClientEvent{ - Type: gomatrixserverlib.MTyping, - } - ev.Content, err = json.Marshal(map[string]interface{}{ - "user_ids": typingUsers, - }) - if err != nil { - return err - } - - if jr, ok = res.Rooms.Join[roomID]; !ok { - jr = *types.NewJoinResponse() - } - jr.Ephemeral.Events = append(jr.Ephemeral.Events, ev) - res.Rooms.Join[roomID] = jr - } - } - return nil -} - -// addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if -// the positions of that type are not equal in fromPos and toPos. -func (d *SyncServerDatasource) addEDUDeltaToResponse( - fromPos, toPos types.PaginationToken, - joinedRoomIDs []string, - res *types.Response, -) (err error) { - - if fromPos.EDUTypingPosition != toPos.EDUTypingPosition { - err = d.addTypingDeltaToResponse( - fromPos, joinedRoomIDs, res, - ) - } - - return -} - -// IncrementalSync returns all the data needed in order to create an incremental -// sync response for the given user. Events returned will include any client -// transaction IDs associated with the given device. These transaction IDs come -// from when the device sent the event via an API that included a transaction -// ID. -func (d *SyncServerDatasource) IncrementalSync( - ctx context.Context, - device authtypes.Device, - fromPos, toPos types.PaginationToken, - numRecentEventsPerRoom int, - wantFullState bool, -) (*types.Response, error) { - nextBatchPos := fromPos.WithUpdates(toPos) - res := types.NewResponse(nextBatchPos) - - var joinedRoomIDs []string - var err error - if fromPos.PDUPosition != toPos.PDUPosition || wantFullState { - joinedRoomIDs, err = d.addPDUDeltaToResponse( - ctx, device, fromPos.PDUPosition, toPos.PDUPosition, numRecentEventsPerRoom, wantFullState, res, - ) - } else { - joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership( - ctx, nil, device.UserID, gomatrixserverlib.Join, - ) - } - if err != nil { - return nil, err - } - - err = d.addEDUDeltaToResponse( - fromPos, toPos, joinedRoomIDs, res, - ) - if err != nil { - return nil, err - } - - return res, nil -} - -// getResponseWithPDUsForCompleteSync creates a response and adds all PDUs needed -// to it. It returns toPos and joinedRoomIDs for use of adding EDUs. -func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( - ctx context.Context, - userID string, - numRecentEventsPerRoom int, -) ( - res *types.Response, - toPos types.PaginationToken, - joinedRoomIDs []string, - err error, -) { - // This needs to be all done in a transaction as we need to do multiple SELECTs, and we need to have - // a consistent view of the database throughout. This includes extracting the sync position. - // This does have the unfortunate side-effect that all the matrixy logic resides in this function, - // but it's better to not hide the fact that this is being done in a transaction. - txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot) - if err != nil { - return - } - var succeeded bool - defer func() { - txerr := common.EndTransaction(txn, &succeeded) - if err == nil && txerr != nil { - err = txerr - } - }() - - // Get the current sync position which we will base the sync response on. - toPos, err = d.syncPositionTx(ctx, txn) - if err != nil { - return - } - - res = types.NewResponse(toPos) - - // Extract room state and recent events for all rooms the user is joined to. - joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) - if err != nil { - return - } - - stateFilter := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request - - // Build up a /sync response. Add joined rooms. - for _, roomID := range joinedRoomIDs { - var stateEvents []gomatrixserverlib.HeaderedEvent - stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, &stateFilter) - if err != nil { - return - } - // TODO: When filters are added, we may need to call this multiple times to get enough events. - // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 - var recentStreamEvents []types.StreamEvent - recentStreamEvents, err = d.events.selectRecentEvents( - ctx, txn, roomID, types.StreamPosition(0), toPos.PDUPosition, - numRecentEventsPerRoom, true, true, - ) - if err != nil { - return - } - - // Retrieve the backward topology position, i.e. the position of the - // oldest event in the room's topology. - var backwardTopologyPos types.StreamPosition - backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, recentStreamEvents[0].EventID()) - if backwardTopologyPos-1 <= 0 { - backwardTopologyPos = types.StreamPosition(1) - } else { - backwardTopologyPos-- - } - - // We don't include a device here as we don't need to send down - // transaction IDs for complete syncs - recentEvents := d.StreamEventsToEvents(nil, recentStreamEvents) - stateEvents = removeDuplicates(stateEvents, recentEvents) - jr := types.NewJoinResponse() - jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, 0, - ).String() - jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) - jr.Timeline.Limited = true - jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync) - res.Rooms.Join[roomID] = *jr - } - - if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition, res); err != nil { - return - } - - succeeded = true - return res, toPos, joinedRoomIDs, err -} - -// CompleteSync returns a complete /sync API response for the given user. -func (d *SyncServerDatasource) CompleteSync( - ctx context.Context, userID string, numRecentEventsPerRoom int, -) (*types.Response, error) { - res, toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync( - ctx, userID, numRecentEventsPerRoom, - ) - if err != nil { - return nil, err - } - - // Use a zero value SyncPosition for fromPos so all EDU states are added. - err = d.addEDUDeltaToResponse( - types.PaginationToken{}, toPos, joinedRoomIDs, res, - ) - if err != nil { - return nil, err - } - - return res, nil -} - -var txReadOnlySnapshot = sql.TxOptions{ - // Set the isolation level so that we see a snapshot of the database. - // In PostgreSQL repeatable read transactions will see a snapshot taken - // at the first query, and since the transaction is read-only it can't - // run into any serialisation errors. - // https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ - Isolation: sql.LevelRepeatableRead, - ReadOnly: true, -} - -// GetAccountDataInRange returns all account data for a given user inserted or -// updated between two given positions -// Returns a map following the format data[roomID] = []dataTypes -// If no data is retrieved, returns an empty map -// If there was an issue with the retrieval, returns an error -func (d *SyncServerDatasource) GetAccountDataInRange( - ctx context.Context, userID string, oldPos, newPos types.StreamPosition, - accountDataFilterPart *gomatrixserverlib.EventFilter, -) (map[string][]string, error) { - return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart) -} - -// UpsertAccountData keeps track of new or updated account data, by saving the type -// of the new/updated data, and the user ID and room ID the data is related to (empty) -// room ID means the data isn't specific to any room) -// If no data with the given type, user ID and room ID exists in the database, -// creates a new row, else update the existing one -// Returns an error if there was an issue with the upsert -func (d *SyncServerDatasource) UpsertAccountData( - ctx context.Context, userID, roomID, dataType string, -) (types.StreamPosition, error) { - return d.accountData.insertAccountData(ctx, userID, roomID, dataType) -} - -// AddInviteEvent stores a new invite event for a user. -// If the invite was successfully stored this returns the stream ID it was stored at. -// Returns an error if there was a problem communicating with the database. -func (d *SyncServerDatasource) AddInviteEvent( - ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, -) (types.StreamPosition, error) { - return d.invites.insertInviteEvent(ctx, inviteEvent) -} - -// RetireInviteEvent removes an old invite event from the database. -// Returns an error if there was a problem communicating with the database. -func (d *SyncServerDatasource) RetireInviteEvent( - ctx context.Context, inviteEventID string, -) error { - // TODO: Record that invite has been retired in a stream so that we can - // notify the user in an incremental sync. - err := d.invites.deleteInviteEvent(ctx, inviteEventID) - return err -} - -func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) { - d.eduCache.SetTimeoutCallback(fn) -} - -// AddTypingUser adds a typing user to the typing cache. -// Returns the newly calculated sync position for typing notifications. -func (d *SyncServerDatasource) AddTypingUser( - userID, roomID string, expireTime *time.Time, -) types.StreamPosition { - return types.StreamPosition(d.eduCache.AddTypingUser(userID, roomID, expireTime)) -} - -// RemoveTypingUser removes a typing user from the typing cache. -// Returns the newly calculated sync position for typing notifications. -func (d *SyncServerDatasource) RemoveTypingUser( - userID, roomID string, -) types.StreamPosition { - return types.StreamPosition(d.eduCache.RemoveUser(userID, roomID)) -} - -func (d *SyncServerDatasource) addInvitesToResponse( - ctx context.Context, txn *sql.Tx, - userID string, - fromPos, toPos types.StreamPosition, - res *types.Response, -) error { - invites, err := d.invites.selectInviteEventsInRange( - ctx, txn, userID, fromPos, toPos, - ) - if err != nil { - return err - } - for roomID, inviteEvent := range invites { - ir := types.NewInviteResponse() - ir.InviteState.Events = gomatrixserverlib.ToClientEvents( - []gomatrixserverlib.Event{inviteEvent.Event}, gomatrixserverlib.FormatSync, - ) - // TODO: add the invite state from the invite event. - res.Rooms.Invite[roomID] = *ir - } - return nil -} - -// Retrieve the backward topology position, i.e. the position of the -// oldest event in the room's topology. -func (d *SyncServerDatasource) getBackwardTopologyPos( - ctx context.Context, - events []types.StreamEvent, -) (pos types.StreamPosition) { - if len(events) > 0 { - pos, _ = d.topology.selectPositionInTopology(ctx, events[0].EventID()) - } - if pos-1 <= 0 { - pos = types.StreamPosition(1) - } else { - pos = pos - 1 - } - return -} - -// addRoomDeltaToResponse adds a room state delta to a sync response -func (d *SyncServerDatasource) addRoomDeltaToResponse( - ctx context.Context, - device *authtypes.Device, - txn *sql.Tx, - fromPos, toPos types.StreamPosition, - delta stateDelta, - numRecentEventsPerRoom int, - res *types.Response, -) error { - endPos := toPos - if delta.membershipPos > 0 && delta.membership == gomatrixserverlib.Leave { - // make sure we don't leak recent events after the leave event. - // TODO: History visibility makes this somewhat complex to handle correctly. For example: - // TODO: This doesn't work for join -> leave in a single /sync request (see events prior to join). - // TODO: This will fail on join -> leave -> sensitive msg -> join -> leave - // in a single /sync request - // This is all "okay" assuming history_visibility == "shared" which it is by default. - endPos = delta.membershipPos - } - recentStreamEvents, err := d.events.selectRecentEvents( - ctx, txn, delta.roomID, types.StreamPosition(fromPos), types.StreamPosition(endPos), - numRecentEventsPerRoom, true, true, - ) - if err != nil { - return err - } - recentEvents := d.StreamEventsToEvents(device, recentStreamEvents) - delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back - backwardTopologyPos := d.getBackwardTopologyPos(ctx, recentStreamEvents) - - switch delta.membership { - case gomatrixserverlib.Join: - jr := types.NewJoinResponse() - - jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, 0, - ).String() - jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) - jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true - jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) - res.Rooms.Join[delta.roomID] = *jr - case gomatrixserverlib.Leave: - fallthrough // transitions to leave are the same as ban - case gomatrixserverlib.Ban: - // TODO: recentEvents may contain events that this user is not allowed to see because they are - // no longer in the room. - lr := types.NewLeaveResponse() - lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, 0, - ).String() - lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) - lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true - lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) - res.Rooms.Leave[delta.roomID] = *lr - } - - return nil -} - -// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database. -// Returns a map of room ID to list of events. -func (d *SyncServerDatasource) fetchStateEvents( - ctx context.Context, txn *sql.Tx, - roomIDToEventIDSet map[string]map[string]bool, - eventIDToEvent map[string]types.StreamEvent, -) (map[string][]types.StreamEvent, error) { - stateBetween := make(map[string][]types.StreamEvent) - missingEvents := make(map[string][]string) - for roomID, ids := range roomIDToEventIDSet { - events := stateBetween[roomID] - for id, need := range ids { - if !need { - continue // deleted state - } - e, ok := eventIDToEvent[id] - if ok { - events = append(events, e) - } else { - m := missingEvents[roomID] - m = append(m, id) - missingEvents[roomID] = m - } - } - stateBetween[roomID] = events - } - - if len(missingEvents) > 0 { - // This happens when add_state_ids has an event ID which is not in the provided range. - // We need to explicitly fetch them. - allMissingEventIDs := []string{} - for _, missingEvIDs := range missingEvents { - allMissingEventIDs = append(allMissingEventIDs, missingEvIDs...) - } - evs, err := d.fetchMissingStateEvents(ctx, txn, allMissingEventIDs) - if err != nil { - return nil, err - } - // we know we got them all otherwise an error would've been returned, so just loop the events - for _, ev := range evs { - roomID := ev.RoomID() - stateBetween[roomID] = append(stateBetween[roomID], ev) - } - } - return stateBetween, nil -} - -func (d *SyncServerDatasource) fetchMissingStateEvents( - ctx context.Context, txn *sql.Tx, eventIDs []string, -) ([]types.StreamEvent, error) { - // Fetch from the events table first so we pick up the stream ID for the - // event. - events, err := d.events.selectEvents(ctx, txn, eventIDs) - if err != nil { - return nil, err - } - - have := map[string]bool{} - for _, event := range events { - have[event.EventID()] = true - } - var missing []string - for _, eventID := range eventIDs { - if !have[eventID] { - missing = append(missing, eventID) - } - } - if len(missing) == 0 { - return events, nil - } - - // If they are missing from the events table then they should be state - // events that we received from outside the main event stream. - // These should be in the room state table. - stateEvents, err := d.roomstate.selectEventsWithEventIDs(ctx, txn, missing) - - if err != nil { - return nil, err - } - if len(stateEvents) != len(missing) { - return nil, fmt.Errorf("failed to map all event IDs to events: (got %d, wanted %d)", len(stateEvents), len(missing)) - } - events = append(events, stateEvents...) - return events, nil -} - -// getStateDeltas returns the state deltas between fromPos and toPos, -// exclusive of oldPos, inclusive of newPos, for the rooms in which -// the user has new membership events. -// A list of joined room IDs is also returned in case the caller needs it. -func (d *SyncServerDatasource) getStateDeltas( - ctx context.Context, device *authtypes.Device, txn *sql.Tx, - fromPos, toPos types.StreamPosition, userID string, - stateFilter *gomatrixserverlib.StateFilter, -) ([]stateDelta, []string, error) { - // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 - // - Get membership list changes for this user in this sync response - // - For each room which has membership list changes: - // * Check if the room is 'newly joined' (insufficient to just check for a join event because we allow dupe joins TODO). - // If it is, then we need to send the full room state down (and 'limited' is always true). - // * Check if user is still CURRENTLY invited to the room. If so, add room to 'invited' block. - // * Check if the user is CURRENTLY (TODO) left/banned. If so, add room to 'archived' block. - // - Get all CURRENTLY joined rooms, and add them to 'joined' block. - var deltas []stateDelta - - // get all the state events ever between these two positions - stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos, stateFilter) - if err != nil { - return nil, nil, err - } - state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) - if err != nil { - return nil, nil, err - } - - for roomID, stateStreamEvents := range state { - for _, ev := range stateStreamEvents { - // TODO: Currently this will incorrectly add rooms which were ALREADY joined but they sent another no-op join event. - // We should be checking if the user was already joined at fromPos and not proceed if so. As a result of this, - // dupe join events will result in the entire room state coming down to the client again. This is added in - // the 'state' part of the response though, so is transparent modulo bandwidth concerns as it is not added to - // the timeline. - if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" { - if membership == gomatrixserverlib.Join { - // send full room state down instead of a delta - var s []types.StreamEvent - s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter) - if err != nil { - return nil, nil, err - } - state[roomID] = s - continue // we'll add this room in when we do joined rooms - } - - deltas = append(deltas, stateDelta{ - membership: membership, - membershipPos: ev.StreamPosition, - stateEvents: d.StreamEventsToEvents(device, stateStreamEvents), - roomID: roomID, - }) - break - } - } - } - - // Add in currently joined rooms - joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) - if err != nil { - return nil, nil, err - } - for _, joinedRoomID := range joinedRoomIDs { - deltas = append(deltas, stateDelta{ - membership: gomatrixserverlib.Join, - stateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]), - roomID: joinedRoomID, - }) - } - - return deltas, joinedRoomIDs, nil -} - -// getStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync -// requests with full_state=true. -// Fetches full state for all joined rooms and uses selectStateInRange to get -// updates for other rooms. -func (d *SyncServerDatasource) getStateDeltasForFullStateSync( - ctx context.Context, device *authtypes.Device, txn *sql.Tx, - fromPos, toPos types.StreamPosition, userID string, - stateFilter *gomatrixserverlib.StateFilter, -) ([]stateDelta, []string, error) { - joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) - if err != nil { - return nil, nil, err - } - - // Use a reasonable initial capacity - deltas := make([]stateDelta, 0, len(joinedRoomIDs)) - - // Add full states for all joined rooms - for _, joinedRoomID := range joinedRoomIDs { - s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilter) - if stateErr != nil { - return nil, nil, stateErr - } - deltas = append(deltas, stateDelta{ - membership: gomatrixserverlib.Join, - stateEvents: d.StreamEventsToEvents(device, s), - roomID: joinedRoomID, - }) - } - - // Get all the state events ever between these two positions - stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos, stateFilter) - if err != nil { - return nil, nil, err - } - state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) - if err != nil { - return nil, nil, err - } - - for roomID, stateStreamEvents := range state { - for _, ev := range stateStreamEvents { - if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" { - if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above. - deltas = append(deltas, stateDelta{ - membership: membership, - membershipPos: ev.StreamPosition, - stateEvents: d.StreamEventsToEvents(device, stateStreamEvents), - roomID: roomID, - }) - } - - break - } - } - } - - return deltas, joinedRoomIDs, nil -} - -func (d *SyncServerDatasource) currentStateStreamEventsForRoom( - ctx context.Context, txn *sql.Tx, roomID string, - stateFilter *gomatrixserverlib.StateFilter, -) ([]types.StreamEvent, error) { - allState, err := d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilter) - if err != nil { - return nil, err - } - s := make([]types.StreamEvent, len(allState)) - for i := 0; i < len(s); i++ { - s[i] = types.StreamEvent{HeaderedEvent: allState[i], StreamPosition: 0} - } - return s, nil -} - -// StreamEventsToEvents converts streamEvent to Event. If device is non-nil and -// matches the streamevent.transactionID device then the transaction ID gets -// added to the unsigned section of the output event. -func (d *SyncServerDatasource) StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.HeaderedEvent { - out := make([]gomatrixserverlib.HeaderedEvent, len(in)) - for i := 0; i < len(in); i++ { - out[i] = in[i].HeaderedEvent - if device != nil && in[i].TransactionID != nil { - if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID { - err := out[i].SetUnsignedField( - "transaction_id", in[i].TransactionID.TransactionID, - ) - if err != nil { - logrus.WithFields(logrus.Fields{ - "event_id": out[i].EventID(), - }).WithError(err).Warnf("Failed to add transaction ID to event") - } - } - } - } - return out -} - -// There may be some overlap where events in stateEvents are already in recentEvents, so filter -// them out so we don't include them twice in the /sync response. They should be in recentEvents -// only, so clients get to the correct state once they have rolled forward. -func removeDuplicates(stateEvents, recentEvents []gomatrixserverlib.HeaderedEvent) []gomatrixserverlib.HeaderedEvent { - for _, recentEv := range recentEvents { - if recentEv.StateKey() == nil { - continue // not a state event - } - // TODO: This is a linear scan over all the current state events in this room. This will - // be slow for big rooms. We should instead sort the state events by event ID (ORDER BY) - // then do a binary search to find matching events, similar to what roomserver does. - for j := 0; j < len(stateEvents); j++ { - if stateEvents[j].EventID() == recentEv.EventID() { - // overwrite the element to remove with the last element then pop the last element. - // This is orders of magnitude faster than re-slicing, but doesn't preserve ordering - // (we don't care about the order of stateEvents) - stateEvents[j] = stateEvents[len(stateEvents)-1] - stateEvents = stateEvents[:len(stateEvents)-1] - break // there shouldn't be multiple events with the same event ID - } - } - } - return stateEvents -} - -// getMembershipFromEvent returns the value of content.membership iff the event is a state event -// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned. -func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) string { - if ev.Type() == "m.room.member" && ev.StateKeyEquals(userID) { - membership, err := ev.Membership() - if err != nil { - return "" - } - return membership - } - return "" -} diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go new file mode 100644 index 000000000..74ae3eabd --- /dev/null +++ b/syncapi/storage/shared/syncserver.go @@ -0,0 +1,1207 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shared + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "time" + + userapi "github.com/matrix-org/dendrite/userapi/api" + + "github.com/matrix-org/dendrite/eduserver/cache" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite +// For now this contains the shared functions +type Database struct { + DB *sql.DB + Invites tables.Invites + AccountData tables.AccountData + OutputEvents tables.Events + Topology tables.Topology + CurrentRoomState tables.CurrentRoomState + BackwardExtremities tables.BackwardsExtremities + SendToDevice tables.SendToDevice + SendToDeviceWriter *sqlutil.TransactionWriter + EDUCache *cache.EDUCache +} + +// Events lookups a list of event by their event ID. +// Returns a list of events matching the requested IDs found in the database. +// If an event is not found in the database then it will be omitted from the list. +// Returns an error if there was a problem talking with the database. +// Does not include any transaction IDs in the returned events. +func (d *Database) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.HeaderedEvent, error) { + streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs) + if err != nil { + return nil, err + } + + // We don't include a device here as we only include transaction IDs in + // incremental syncs. + return d.StreamEventsToEvents(nil, streamEvents), nil +} + +// GetEventsInStreamingRange retrieves all of the events on a given ordering using the +// given extremities and limit. +func (d *Database) GetEventsInStreamingRange( + ctx context.Context, + from, to *types.StreamingToken, + roomID string, limit int, + backwardOrdering bool, +) (events []types.StreamEvent, err error) { + r := types.Range{ + From: from.PDUPosition(), + To: to.PDUPosition(), + Backwards: backwardOrdering, + } + if backwardOrdering { + // When using backward ordering, we want the most recent events first. + if events, err = d.OutputEvents.SelectRecentEvents( + ctx, nil, roomID, r, limit, false, false, + ); err != nil { + return + } + } else { + // When using forward ordering, we want the least recent events first. + if events, err = d.OutputEvents.SelectEarlyEvents( + ctx, nil, roomID, r, limit, + ); err != nil { + return + } + } + return events, err +} + +func (d *Database) AddTypingUser( + userID, roomID string, expireTime *time.Time, +) types.StreamPosition { + return types.StreamPosition(d.EDUCache.AddTypingUser(userID, roomID, expireTime)) +} + +func (d *Database) RemoveTypingUser( + userID, roomID string, +) types.StreamPosition { + return types.StreamPosition(d.EDUCache.RemoveUser(userID, roomID)) +} + +func (d *Database) AddSendToDevice() types.StreamPosition { + return types.StreamPosition(d.EDUCache.AddSendToDeviceMessage()) +} + +func (d *Database) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) { + d.EDUCache.SetTimeoutCallback(fn) +} + +func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { + return d.CurrentRoomState.SelectJoinedUsers(ctx) +} + +func (d *Database) GetStateEvent( + ctx context.Context, roomID, evType, stateKey string, +) (*gomatrixserverlib.HeaderedEvent, error) { + return d.CurrentRoomState.SelectStateEvent(ctx, roomID, evType, stateKey) +} + +func (d *Database) GetStateEventsForRoom( + ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter, +) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) { + err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, stateFilter) + return err + }) + return +} + +func (d *Database) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) { + var maxID int64 + var err error + err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + maxID, err = d.OutputEvents.SelectMaxEventID(ctx, txn) + if err != nil { + return err + } + var maxAccountDataID int64 + maxAccountDataID, err = d.AccountData.SelectMaxAccountDataID(ctx, txn) + if err != nil { + return err + } + if maxAccountDataID > maxID { + maxID = maxAccountDataID + } + var maxInviteID int64 + maxInviteID, err = d.Invites.SelectMaxInviteID(ctx, txn) + if err != nil { + return err + } + if maxInviteID > maxID { + maxID = maxInviteID + } + return nil + }) + return types.StreamPosition(maxID), err +} + +// AddInviteEvent stores a new invite event for a user. +// If the invite was successfully stored this returns the stream ID it was stored at. +// Returns an error if there was a problem communicating with the database. +func (d *Database) AddInviteEvent( + ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, +) (sp types.StreamPosition, err error) { + err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent) + return err + }) + return +} + +// RetireInviteEvent removes an old invite event from the database. +// Returns an error if there was a problem communicating with the database. +func (d *Database) RetireInviteEvent( + ctx context.Context, inviteEventID string, +) error { + // TODO: Record that invite has been retired in a stream so that we can + // notify the user in an incremental sync. + err := d.Invites.DeleteInviteEvent(ctx, inviteEventID) + return err +} + +// GetAccountDataInRange returns all account data for a given user inserted or +// updated between two given positions +// Returns a map following the format data[roomID] = []dataTypes +// If no data is retrieved, returns an empty map +// If there was an issue with the retrieval, returns an error +func (d *Database) GetAccountDataInRange( + ctx context.Context, userID string, r types.Range, + accountDataFilterPart *gomatrixserverlib.EventFilter, +) (map[string][]string, error) { + return d.AccountData.SelectAccountDataInRange(ctx, userID, r, accountDataFilterPart) +} + +// UpsertAccountData keeps track of new or updated account data, by saving the type +// of the new/updated data, and the user ID and room ID the data is related to (empty) +// room ID means the data isn't specific to any room) +// If no data with the given type, user ID and room ID exists in the database, +// creates a new row, else update the existing one +// Returns an error if there was an issue with the upsert +func (d *Database) UpsertAccountData( + ctx context.Context, userID, roomID, dataType string, +) (sp types.StreamPosition, err error) { + err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType) + return err + }) + return +} + +func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []gomatrixserverlib.HeaderedEvent { + out := make([]gomatrixserverlib.HeaderedEvent, len(in)) + for i := 0; i < len(in); i++ { + out[i] = in[i].HeaderedEvent + if device != nil && in[i].TransactionID != nil { + if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID { + err := out[i].SetUnsignedField( + "transaction_id", in[i].TransactionID.TransactionID, + ) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Failed to add transaction ID to event") + } + } + } + } + return out +} + +// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of +// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table +// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. +func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error { + if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil { + return err + } + + // Check if we have all of the event's previous events. If an event is + // missing, add it to the room's backward extremities. + prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs()) + if err != nil { + return err + } + var found bool + for _, eID := range ev.PrevEventIDs() { + found = false + for _, prevEv := range prevEvents { + if eID == prevEv.EventID() { + found = true + } + } + + // If the event is missing, consider it a backward extremity. + if !found { + if err = d.BackwardExtremities.InsertsBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID(), eID); err != nil { + return err + } + } + } + + return nil +} + +func (d *Database) WriteEvent( + ctx context.Context, + ev *gomatrixserverlib.HeaderedEvent, + addStateEvents []gomatrixserverlib.HeaderedEvent, + addStateEventIDs, removeStateEventIDs []string, + transactionID *api.TransactionID, excludeFromSync bool, +) (pduPosition types.StreamPosition, returnErr error) { + returnErr = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + var err error + pos, err := d.OutputEvents.InsertEvent( + ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, + ) + if err != nil { + return err + } + pduPosition = pos + + if err = d.Topology.InsertEventInTopology(ctx, txn, ev, pos); err != nil { + return err + } + + if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil { + return err + } + + if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { + // Nothing to do, the event may have just been a message event. + return nil + } + + return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition) + }) + + return pduPosition, returnErr +} + +func (d *Database) updateRoomState( + ctx context.Context, txn *sql.Tx, + removedEventIDs []string, + addedEvents []gomatrixserverlib.HeaderedEvent, + pduPosition types.StreamPosition, +) error { + // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. + for _, eventID := range removedEventIDs { + if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil { + return err + } + } + + for _, event := range addedEvents { + if event.StateKey() == nil { + // ignore non state events + continue + } + var membership *string + if event.Type() == "m.room.member" { + value, err := event.Membership() + if err != nil { + return err + } + membership = &value + } + + if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil { + return err + } + } + + return nil +} + +func (d *Database) GetEventsInTopologicalRange( + ctx context.Context, + from, to *types.TopologyToken, + roomID string, limit int, + backwardOrdering bool, +) (events []types.StreamEvent, err error) { + var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition + if backwardOrdering { + // Backward ordering means the 'from' token has a higher depth than the 'to' token + minDepth = to.Depth() + maxDepth = from.Depth() + // for cases where we have say 5 events with the same depth, the TopologyToken needs to + // know which of the 5 the client has seen. This is done by using the PDU position. + // Events with the same maxDepth but less than this PDU position will be returned. + maxStreamPosForMaxDepth = from.PDUPosition() + } else { + // Forward ordering means the 'from' token has a lower depth than the 'to' token. + minDepth = from.Depth() + maxDepth = to.Depth() + } + + // Select the event IDs from the defined range. + var eIDs []string + eIDs, err = d.Topology.SelectEventIDsInRange( + ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, limit, !backwardOrdering, + ) + if err != nil { + return + } + + // Retrieve the events' contents using their IDs. + events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs) + return +} + +func (d *Database) SyncPosition(ctx context.Context) (tok types.StreamingToken, err error) { + err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + pos, err := d.syncPositionTx(ctx, txn) + if err != nil { + return err + } + tok = pos + return nil + }) + return +} + +func (d *Database) BackwardExtremitiesForRoom( + ctx context.Context, roomID string, +) (backwardExtremities map[string][]string, err error) { + return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, roomID) +} + +func (d *Database) MaxTopologicalPosition( + ctx context.Context, roomID string, +) (types.TopologyToken, error) { + depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, nil, roomID) + if err != nil { + return types.NewTopologyToken(0, 0), err + } + return types.NewTopologyToken(depth, streamPos), nil +} + +func (d *Database) EventPositionInTopology( + ctx context.Context, eventID string, +) (types.TopologyToken, error) { + depth, stream, err := d.Topology.SelectPositionInTopology(ctx, nil, eventID) + if err != nil { + return types.NewTopologyToken(0, 0), err + } + return types.NewTopologyToken(depth, stream), nil +} + +func (d *Database) syncPositionTx( + ctx context.Context, txn *sql.Tx, +) (sp types.StreamingToken, err error) { + + maxEventID, err := d.OutputEvents.SelectMaxEventID(ctx, txn) + if err != nil { + return sp, err + } + maxAccountDataID, err := d.AccountData.SelectMaxAccountDataID(ctx, txn) + if err != nil { + return sp, err + } + if maxAccountDataID > maxEventID { + maxEventID = maxAccountDataID + } + maxInviteID, err := d.Invites.SelectMaxInviteID(ctx, txn) + if err != nil { + return sp, err + } + if maxInviteID > maxEventID { + maxEventID = maxInviteID + } + sp = types.NewStreamToken(types.StreamPosition(maxEventID), types.StreamPosition(d.EDUCache.GetLatestSyncPosition())) + return +} + +// addPDUDeltaToResponse adds all PDU deltas to a sync response. +// IDs of all rooms the user joined are returned so EDU deltas can be added for them. +func (d *Database) addPDUDeltaToResponse( + ctx context.Context, + device userapi.Device, + r types.Range, + numRecentEventsPerRoom int, + wantFullState bool, + res *types.Response, +) (joinedRoomIDs []string, err error) { + txn, err := d.DB.BeginTx(ctx, &txReadOnlySnapshot) + if err != nil { + return nil, err + } + var succeeded bool + defer func() { + txerr := sqlutil.EndTransaction(txn, &succeeded) + if err == nil && txerr != nil { + err = txerr + } + }() + + stateFilter := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request + + // Work out which rooms to return in the response. This is done by getting not only the currently + // joined rooms, but also which rooms have membership transitions for this user between the 2 PDU stream positions. + // This works out what the 'state' key should be for each room as well as which membership block + // to put the room into. + var deltas []stateDelta + if !wantFullState { + deltas, joinedRoomIDs, err = d.getStateDeltas( + ctx, &device, txn, r, device.UserID, &stateFilter, + ) + } else { + deltas, joinedRoomIDs, err = d.getStateDeltasForFullStateSync( + ctx, &device, txn, r, device.UserID, &stateFilter, + ) + } + if err != nil { + return nil, err + } + + for _, delta := range deltas { + err = d.addRoomDeltaToResponse(ctx, &device, txn, r, delta, numRecentEventsPerRoom, res) + if err != nil { + return nil, err + } + } + + // TODO: This should be done in getStateDeltas + if err = d.addInvitesToResponse(ctx, txn, device.UserID, r, res); err != nil { + return nil, err + } + + succeeded = true + return joinedRoomIDs, nil +} + +// addTypingDeltaToResponse adds all typing notifications to a sync response +// since the specified position. +func (d *Database) addTypingDeltaToResponse( + since types.StreamingToken, + joinedRoomIDs []string, + res *types.Response, +) error { + var jr types.JoinResponse + var ok bool + var err error + for _, roomID := range joinedRoomIDs { + if typingUsers, updated := d.EDUCache.GetTypingUsersIfUpdatedAfter( + roomID, int64(since.EDUPosition()), + ); updated { + ev := gomatrixserverlib.ClientEvent{ + Type: gomatrixserverlib.MTyping, + } + ev.Content, err = json.Marshal(map[string]interface{}{ + "user_ids": typingUsers, + }) + if err != nil { + return err + } + + if jr, ok = res.Rooms.Join[roomID]; !ok { + jr = *types.NewJoinResponse() + } + jr.Ephemeral.Events = append(jr.Ephemeral.Events, ev) + res.Rooms.Join[roomID] = jr + } + } + return nil +} + +// addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if +// the positions of that type are not equal in fromPos and toPos. +func (d *Database) addEDUDeltaToResponse( + fromPos, toPos types.StreamingToken, + joinedRoomIDs []string, + res *types.Response, +) (err error) { + + if fromPos.EDUPosition() != toPos.EDUPosition() { + err = d.addTypingDeltaToResponse( + fromPos, joinedRoomIDs, res, + ) + } + + return +} + +func (d *Database) IncrementalSync( + ctx context.Context, res *types.Response, + device userapi.Device, + fromPos, toPos types.StreamingToken, + numRecentEventsPerRoom int, + wantFullState bool, +) (*types.Response, error) { + nextBatchPos := fromPos.WithUpdates(toPos) + res.NextBatch = nextBatchPos.String() + + var joinedRoomIDs []string + var err error + if fromPos.PDUPosition() != toPos.PDUPosition() || wantFullState { + r := types.Range{ + From: fromPos.PDUPosition(), + To: toPos.PDUPosition(), + } + joinedRoomIDs, err = d.addPDUDeltaToResponse( + ctx, device, r, numRecentEventsPerRoom, wantFullState, res, + ) + } else { + joinedRoomIDs, err = d.CurrentRoomState.SelectRoomIDsWithMembership( + ctx, nil, device.UserID, gomatrixserverlib.Join, + ) + } + if err != nil { + return nil, err + } + + err = d.addEDUDeltaToResponse( + fromPos, toPos, joinedRoomIDs, res, + ) + if err != nil { + return nil, err + } + + return res, nil +} + +// getResponseWithPDUsForCompleteSync creates a response and adds all PDUs needed +// to it. It returns toPos and joinedRoomIDs for use of adding EDUs. +// nolint:nakedret +func (d *Database) getResponseWithPDUsForCompleteSync( + ctx context.Context, res *types.Response, + userID string, + numRecentEventsPerRoom int, +) ( + toPos types.StreamingToken, + joinedRoomIDs []string, + err error, +) { + // This needs to be all done in a transaction as we need to do multiple SELECTs, and we need to have + // a consistent view of the database throughout. This includes extracting the sync position. + // This does have the unfortunate side-effect that all the matrixy logic resides in this function, + // but it's better to not hide the fact that this is being done in a transaction. + txn, err := d.DB.BeginTx(ctx, &txReadOnlySnapshot) + if err != nil { + return + } + var succeeded bool + defer func() { + txerr := sqlutil.EndTransaction(txn, &succeeded) + if err == nil && txerr != nil { + err = txerr + } + }() + + // Get the current sync position which we will base the sync response on. + toPos, err = d.syncPositionTx(ctx, txn) + if err != nil { + return + } + r := types.Range{ + From: 0, + To: toPos.PDUPosition(), + } + + res.NextBatch = toPos.String() + + // Extract room state and recent events for all rooms the user is joined to. + joinedRoomIDs, err = d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) + if err != nil { + return + } + + stateFilter := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request + + // Build up a /sync response. Add joined rooms. + for _, roomID := range joinedRoomIDs { + var stateEvents []gomatrixserverlib.HeaderedEvent + stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, &stateFilter) + if err != nil { + return + } + // TODO: When filters are added, we may need to call this multiple times to get enough events. + // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 + var recentStreamEvents []types.StreamEvent + recentStreamEvents, err = d.OutputEvents.SelectRecentEvents( + ctx, txn, roomID, r, numRecentEventsPerRoom, true, true, + ) + if err != nil { + return + } + + // Retrieve the backward topology position, i.e. the position of the + // oldest event in the room's topology. + var prevBatchStr string + if len(recentStreamEvents) > 0 { + var backwardTopologyPos, backwardStreamPos types.StreamPosition + backwardTopologyPos, backwardStreamPos, err = d.Topology.SelectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID()) + if err != nil { + return + } + prevBatch := types.NewTopologyToken(backwardTopologyPos, backwardStreamPos) + prevBatch.Decrement() + prevBatchStr = prevBatch.String() + } + + // We don't include a device here as we don't need to send down + // transaction IDs for complete syncs + recentEvents := d.StreamEventsToEvents(nil, recentStreamEvents) + stateEvents = removeDuplicates(stateEvents, recentEvents) + jr := types.NewJoinResponse() + jr.Timeline.PrevBatch = prevBatchStr + jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) + jr.Timeline.Limited = true + jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync) + res.Rooms.Join[roomID] = *jr + } + + if err = d.addInvitesToResponse(ctx, txn, userID, r, res); err != nil { + return + } + + succeeded = true + return //res, toPos, joinedRoomIDs, err +} + +func (d *Database) CompleteSync( + ctx context.Context, res *types.Response, + device userapi.Device, numRecentEventsPerRoom int, +) (*types.Response, error) { + toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync( + ctx, res, device.UserID, numRecentEventsPerRoom, + ) + if err != nil { + return nil, err + } + + // Use a zero value SyncPosition for fromPos so all EDU states are added. + err = d.addEDUDeltaToResponse( + types.NewStreamToken(0, 0), toPos, joinedRoomIDs, res, + ) + if err != nil { + return nil, err + } + + return res, nil +} + +var txReadOnlySnapshot = sql.TxOptions{ + // Set the isolation level so that we see a snapshot of the database. + // In PostgreSQL repeatable read transactions will see a snapshot taken + // at the first query, and since the transaction is read-only it can't + // run into any serialisation errors. + // https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ + Isolation: sql.LevelRepeatableRead, + ReadOnly: true, +} + +func (d *Database) addInvitesToResponse( + ctx context.Context, txn *sql.Tx, + userID string, + r types.Range, + res *types.Response, +) error { + invites, err := d.Invites.SelectInviteEventsInRange( + ctx, txn, userID, r, + ) + if err != nil { + return err + } + for roomID, inviteEvent := range invites { + ir := types.NewInviteResponse(inviteEvent) + res.Rooms.Invite[roomID] = *ir + } + return nil +} + +// Retrieve the backward topology position, i.e. the position of the +// oldest event in the room's topology. +func (d *Database) getBackwardTopologyPos( + ctx context.Context, txn *sql.Tx, + events []types.StreamEvent, +) (types.TopologyToken, error) { + zeroToken := types.NewTopologyToken(0, 0) + if len(events) == 0 { + return zeroToken, nil + } + pos, spos, err := d.Topology.SelectPositionInTopology(ctx, txn, events[0].EventID()) + if err != nil { + return zeroToken, err + } + tok := types.NewTopologyToken(pos, spos) + tok.Decrement() + return tok, nil +} + +// addRoomDeltaToResponse adds a room state delta to a sync response +func (d *Database) addRoomDeltaToResponse( + ctx context.Context, + device *userapi.Device, + txn *sql.Tx, + r types.Range, + delta stateDelta, + numRecentEventsPerRoom int, + res *types.Response, +) error { + if delta.membershipPos > 0 && delta.membership == gomatrixserverlib.Leave { + // make sure we don't leak recent events after the leave event. + // TODO: History visibility makes this somewhat complex to handle correctly. For example: + // TODO: This doesn't work for join -> leave in a single /sync request (see events prior to join). + // TODO: This will fail on join -> leave -> sensitive msg -> join -> leave + // in a single /sync request + // This is all "okay" assuming history_visibility == "shared" which it is by default. + r.To = delta.membershipPos + } + recentStreamEvents, err := d.OutputEvents.SelectRecentEvents( + ctx, txn, delta.roomID, r, + numRecentEventsPerRoom, true, true, + ) + if err != nil { + return err + } + recentEvents := d.StreamEventsToEvents(device, recentStreamEvents) + delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back + prevBatch, err := d.getBackwardTopologyPos(ctx, txn, recentStreamEvents) + if err != nil { + return err + } + + switch delta.membership { + case gomatrixserverlib.Join: + jr := types.NewJoinResponse() + + jr.Timeline.PrevBatch = prevBatch.String() + jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) + jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true + jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) + res.Rooms.Join[delta.roomID] = *jr + case gomatrixserverlib.Leave: + fallthrough // transitions to leave are the same as ban + case gomatrixserverlib.Ban: + // TODO: recentEvents may contain events that this user is not allowed to see because they are + // no longer in the room. + lr := types.NewLeaveResponse() + lr.Timeline.PrevBatch = prevBatch.String() + lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) + lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true + lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) + res.Rooms.Leave[delta.roomID] = *lr + } + + return nil +} + +// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database. +// Returns a map of room ID to list of events. +func (d *Database) fetchStateEvents( + ctx context.Context, txn *sql.Tx, + roomIDToEventIDSet map[string]map[string]bool, + eventIDToEvent map[string]types.StreamEvent, +) (map[string][]types.StreamEvent, error) { + stateBetween := make(map[string][]types.StreamEvent) + missingEvents := make(map[string][]string) + for roomID, ids := range roomIDToEventIDSet { + events := stateBetween[roomID] + for id, need := range ids { + if !need { + continue // deleted state + } + e, ok := eventIDToEvent[id] + if ok { + events = append(events, e) + } else { + m := missingEvents[roomID] + m = append(m, id) + missingEvents[roomID] = m + } + } + stateBetween[roomID] = events + } + + if len(missingEvents) > 0 { + // This happens when add_state_ids has an event ID which is not in the provided range. + // We need to explicitly fetch them. + allMissingEventIDs := []string{} + for _, missingEvIDs := range missingEvents { + allMissingEventIDs = append(allMissingEventIDs, missingEvIDs...) + } + evs, err := d.fetchMissingStateEvents(ctx, txn, allMissingEventIDs) + if err != nil { + return nil, err + } + // we know we got them all otherwise an error would've been returned, so just loop the events + for _, ev := range evs { + roomID := ev.RoomID() + stateBetween[roomID] = append(stateBetween[roomID], ev) + } + } + return stateBetween, nil +} + +func (d *Database) fetchMissingStateEvents( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) ([]types.StreamEvent, error) { + // Fetch from the events table first so we pick up the stream ID for the + // event. + events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs) + if err != nil { + return nil, err + } + + have := map[string]bool{} + for _, event := range events { + have[event.EventID()] = true + } + var missing []string + for _, eventID := range eventIDs { + if !have[eventID] { + missing = append(missing, eventID) + } + } + if len(missing) == 0 { + return events, nil + } + + // If they are missing from the events table then they should be state + // events that we received from outside the main event stream. + // These should be in the room state table. + stateEvents, err := d.CurrentRoomState.SelectEventsWithEventIDs(ctx, txn, missing) + + if err != nil { + return nil, err + } + if len(stateEvents) != len(missing) { + return nil, fmt.Errorf("failed to map all event IDs to events: (got %d, wanted %d)", len(stateEvents), len(missing)) + } + events = append(events, stateEvents...) + return events, nil +} + +// getStateDeltas returns the state deltas between fromPos and toPos, +// exclusive of oldPos, inclusive of newPos, for the rooms in which +// the user has new membership events. +// A list of joined room IDs is also returned in case the caller needs it. +func (d *Database) getStateDeltas( + ctx context.Context, device *userapi.Device, txn *sql.Tx, + r types.Range, userID string, + stateFilter *gomatrixserverlib.StateFilter, +) ([]stateDelta, []string, error) { + // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 + // - Get membership list changes for this user in this sync response + // - For each room which has membership list changes: + // * Check if the room is 'newly joined' (insufficient to just check for a join event because we allow dupe joins TODO). + // If it is, then we need to send the full room state down (and 'limited' is always true). + // * Check if user is still CURRENTLY invited to the room. If so, add room to 'invited' block. + // * Check if the user is CURRENTLY (TODO) left/banned. If so, add room to 'archived' block. + // - Get all CURRENTLY joined rooms, and add them to 'joined' block. + var deltas []stateDelta + + // get all the state events ever between these two positions + stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter) + if err != nil { + return nil, nil, err + } + state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) + if err != nil { + return nil, nil, err + } + + for roomID, stateStreamEvents := range state { + for _, ev := range stateStreamEvents { + // TODO: Currently this will incorrectly add rooms which were ALREADY joined but they sent another no-op join event. + // We should be checking if the user was already joined at fromPos and not proceed if so. As a result of this, + // dupe join events will result in the entire room state coming down to the client again. This is added in + // the 'state' part of the response though, so is transparent modulo bandwidth concerns as it is not added to + // the timeline. + if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" { + if membership == gomatrixserverlib.Join { + // send full room state down instead of a delta + var s []types.StreamEvent + s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter) + if err != nil { + return nil, nil, err + } + state[roomID] = s + continue // we'll add this room in when we do joined rooms + } + + deltas = append(deltas, stateDelta{ + membership: membership, + membershipPos: ev.StreamPosition, + stateEvents: d.StreamEventsToEvents(device, stateStreamEvents), + roomID: roomID, + }) + break + } + } + } + + // Add in currently joined rooms + joinedRoomIDs, err := d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) + if err != nil { + return nil, nil, err + } + for _, joinedRoomID := range joinedRoomIDs { + deltas = append(deltas, stateDelta{ + membership: gomatrixserverlib.Join, + stateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]), + roomID: joinedRoomID, + }) + } + + return deltas, joinedRoomIDs, nil +} + +// getStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync +// requests with full_state=true. +// Fetches full state for all joined rooms and uses selectStateInRange to get +// updates for other rooms. +func (d *Database) getStateDeltasForFullStateSync( + ctx context.Context, device *userapi.Device, txn *sql.Tx, + r types.Range, userID string, + stateFilter *gomatrixserverlib.StateFilter, +) ([]stateDelta, []string, error) { + joinedRoomIDs, err := d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) + if err != nil { + return nil, nil, err + } + + // Use a reasonable initial capacity + deltas := make([]stateDelta, 0, len(joinedRoomIDs)) + + // Add full states for all joined rooms + for _, joinedRoomID := range joinedRoomIDs { + s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilter) + if stateErr != nil { + return nil, nil, stateErr + } + deltas = append(deltas, stateDelta{ + membership: gomatrixserverlib.Join, + stateEvents: d.StreamEventsToEvents(device, s), + roomID: joinedRoomID, + }) + } + + // Get all the state events ever between these two positions + stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter) + if err != nil { + return nil, nil, err + } + state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) + if err != nil { + return nil, nil, err + } + + for roomID, stateStreamEvents := range state { + for _, ev := range stateStreamEvents { + if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" { + if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above. + deltas = append(deltas, stateDelta{ + membership: membership, + membershipPos: ev.StreamPosition, + stateEvents: d.StreamEventsToEvents(device, stateStreamEvents), + roomID: roomID, + }) + } + + break + } + } + } + + return deltas, joinedRoomIDs, nil +} + +func (d *Database) currentStateStreamEventsForRoom( + ctx context.Context, txn *sql.Tx, roomID string, + stateFilter *gomatrixserverlib.StateFilter, +) ([]types.StreamEvent, error) { + allState, err := d.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, stateFilter) + if err != nil { + return nil, err + } + s := make([]types.StreamEvent, len(allState)) + for i := 0; i < len(s); i++ { + s[i] = types.StreamEvent{HeaderedEvent: allState[i], StreamPosition: 0} + } + return s, nil +} + +func (d *Database) SendToDeviceUpdatesWaiting( + ctx context.Context, userID, deviceID string, +) (bool, error) { + count, err := d.SendToDevice.CountSendToDeviceMessages(ctx, nil, userID, deviceID) + if err != nil { + return false, err + } + return count > 0, nil +} + +func (d *Database) AddSendToDeviceEvent( + ctx context.Context, txn *sql.Tx, + userID, deviceID, content string, +) error { + return d.SendToDevice.InsertSendToDeviceMessage( + ctx, txn, userID, deviceID, content, + ) +} + +func (d *Database) StoreNewSendForDeviceMessage( + ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent, +) (types.StreamPosition, error) { + j, err := json.Marshal(event) + if err != nil { + return streamPos, err + } + // Delegate the database write task to the SendToDeviceWriter. It'll guarantee + // that we don't lock the table for writes in more than one place. + err = d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) error { + return d.AddSendToDeviceEvent( + ctx, txn, userID, deviceID, string(j), + ) + }) + if err != nil { + return streamPos, err + } + return streamPos, nil +} + +func (d *Database) SendToDeviceUpdatesForSync( + ctx context.Context, + userID, deviceID string, + token types.StreamingToken, +) ([]types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error) { + // First of all, get our send-to-device updates for this user. + events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID) + if err != nil { + return nil, nil, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err) + } + + // If there's nothing to do then stop here. + if len(events) == 0 { + return nil, nil, nil, nil + } + + // Work out whether we need to update any of the database entries. + toReturn := []types.SendToDeviceEvent{} + toUpdate := []types.SendToDeviceNID{} + toDelete := []types.SendToDeviceNID{} + for _, event := range events { + if event.SentByToken == nil { + // If the event has no sent-by token yet then we haven't attempted to send + // it. Record the current requested sync token in the database. + toUpdate = append(toUpdate, event.ID) + toReturn = append(toReturn, event) + event.SentByToken = &token + } else if token.IsAfter(*event.SentByToken) { + // The event had a sync token, therefore we've sent it before. The current + // sync token is now after the stored one so we can assume that the client + // successfully completed the previous sync (it would re-request it otherwise) + // so we can remove the entry from the database. + toDelete = append(toDelete, event.ID) + } else { + // It looks like the sync is being re-requested, maybe it timed out or + // failed. Re-send any that should have been acknowledged by now. + toReturn = append(toReturn, event) + } + } + + return toReturn, toUpdate, toDelete, nil +} + +func (d *Database) CleanSendToDeviceUpdates( + ctx context.Context, + toUpdate, toDelete []types.SendToDeviceNID, + token types.StreamingToken, +) (err error) { + if len(toUpdate) == 0 && len(toDelete) == 0 { + return nil + } + // If we need to write to the database then we'll ask the SendToDeviceWriter to + // do that for us. It'll guarantee that we don't lock the table for writes in + // more than one place. + err = d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) error { + // Delete any send-to-device messages marked for deletion. + if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil { + return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e) + } + + // Now update any outstanding send-to-device messages with the new sync token. + if e := d.SendToDevice.UpdateSentSendToDeviceMessages(ctx, txn, token.String(), toUpdate); e != nil { + return fmt.Errorf("d.SendToDevice.UpdateSentSendToDeviceMessages: %w", err) + } + + return nil + }) + return +} + +// There may be some overlap where events in stateEvents are already in recentEvents, so filter +// them out so we don't include them twice in the /sync response. They should be in recentEvents +// only, so clients get to the correct state once they have rolled forward. +func removeDuplicates(stateEvents, recentEvents []gomatrixserverlib.HeaderedEvent) []gomatrixserverlib.HeaderedEvent { + for _, recentEv := range recentEvents { + if recentEv.StateKey() == nil { + continue // not a state event + } + // TODO: This is a linear scan over all the current state events in this room. This will + // be slow for big rooms. We should instead sort the state events by event ID (ORDER BY) + // then do a binary search to find matching events, similar to what roomserver does. + for j := 0; j < len(stateEvents); j++ { + if stateEvents[j].EventID() == recentEv.EventID() { + // overwrite the element to remove with the last element then pop the last element. + // This is orders of magnitude faster than re-slicing, but doesn't preserve ordering + // (we don't care about the order of stateEvents) + stateEvents[j] = stateEvents[len(stateEvents)-1] + stateEvents = stateEvents[:len(stateEvents)-1] + break // there shouldn't be multiple events with the same event ID + } + } + } + return stateEvents +} + +// getMembershipFromEvent returns the value of content.membership iff the event is a state event +// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned. +func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) string { + if ev.Type() == "m.room.member" && ev.StateKeyEquals(userID) { + membership, err := ev.Membership() + if err != nil { + return "" + } + return membership + } + return "" +} + +type stateDelta struct { + roomID string + stateEvents []gomatrixserverlib.HeaderedEvent + membership string + // The PDU stream position of the latest membership event for this user, if applicable. + // Can be 0 if there is no membership event in this delta. + membershipPos types.StreamPosition +} diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index 3dbf961b4..ae5caa4e5 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -19,8 +19,8 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" - + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -55,25 +55,27 @@ type accountDataStatements struct { selectAccountDataInRangeStmt *sql.Stmt } -func (s *accountDataStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) { - s.streamIDStatements = streamID - _, err = db.Exec(accountDataSchema) +func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) { + s := &accountDataStatements{ + streamIDStatements: streamID, + } + _, err := db.Exec(accountDataSchema) if err != nil { - return + return nil, err } if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { - return + return nil, err } if s.selectMaxAccountDataIDStmt, err = db.Prepare(selectMaxAccountDataIDSQL); err != nil { - return + return nil, err } if s.selectAccountDataInRangeStmt, err = db.Prepare(selectAccountDataInRangeSQL); err != nil { - return + return nil, err } - return + return s, nil } -func (s *accountDataStatements) insertAccountData( +func (s *accountDataStatements) InsertAccountData( ctx context.Context, txn *sql.Tx, userID, roomID, dataType string, ) (pos types.StreamPosition, err error) { @@ -85,26 +87,19 @@ func (s *accountDataStatements) insertAccountData( return } -func (s *accountDataStatements) selectAccountDataInRange( +func (s *accountDataStatements) SelectAccountDataInRange( ctx context.Context, userID string, - oldPos, newPos types.StreamPosition, + r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter, ) (data map[string][]string, err error) { data = make(map[string][]string) - // If both positions are the same, it means that the data was saved after the - // latest room event. In that case, we need to decrement the old position as - // it would prevent the SQL request from returning anything. - if oldPos == newPos { - oldPos-- - } - - rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, oldPos, newPos) + rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High()) if err != nil { return } - defer common.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed") var entries int @@ -146,7 +141,7 @@ func (s *accountDataStatements) selectAccountDataInRange( return data, nil } -func (s *accountDataStatements) selectMaxAccountDataID( +func (s *accountDataStatements) SelectMaxAccountDataID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 diff --git a/syncapi/storage/sqlite3/backward_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go similarity index 59% rename from syncapi/storage/sqlite3/backward_extremities_table.go rename to syncapi/storage/sqlite3/backwards_extremities_table.go index 3d8cb91fc..e16e54a6f 100644 --- a/syncapi/storage/sqlite3/backward_extremities_table.go +++ b/syncapi/storage/sqlite3/backwards_extremities_table.go @@ -18,28 +18,10 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/syncapi/storage/tables" ) -// The purpose of this table is to keep track of backwards extremities for a room. -// Backwards extremities are the earliest (DAG-wise) known events which we have -// the entire event JSON. These event IDs are used in federation requests to fetch -// even earlier events. -// -// We persist the previous event IDs as well, one per row, so when we do fetch even -// earlier events we can simply delete rows which referenced it. Consider the graph: -// A -// | Event C has 1 prev_event ID: A. -// B C -// |___| Event D has 2 prev_event IDs: B and C. -// | -// D -// The earliest known event we have is D, so this table has 2 rows. -// A backfill request gives us C but not B. We delete rows where prev_event=C. This -// still means that D is a backwards extremity as we do not have event B. However, event -// C is *also* a backwards extremity at this point as we do not have event A. Later, -// when we fetch event B, we delete rows where prev_event=B. This then removes D as -// a backwards extremity because there are no more rows with event_id=B. const backwardExtremitiesSchema = ` -- Stores output room events received from the roomserver. CREATE TABLE IF NOT EXISTS syncapi_backward_extremities ( @@ -49,7 +31,6 @@ CREATE TABLE IF NOT EXISTS syncapi_backward_extremities ( event_id TEXT NOT NULL, -- The prev_events for the last known event. This is used to update extremities. prev_event_id TEXT NOT NULL, - PRIMARY KEY(room_id, event_id, prev_event_id) ); ` @@ -60,7 +41,7 @@ const insertBackwardExtremitySQL = "" + " ON CONFLICT (room_id, event_id, prev_event_id) DO NOTHING" const selectBackwardExtremitiesForRoomSQL = "" + - "SELECT DISTINCT event_id FROM syncapi_backward_extremities WHERE room_id = $1" + "SELECT event_id, prev_event_id FROM syncapi_backward_extremities WHERE room_id = $1" const deleteBackwardExtremitySQL = "" + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" @@ -71,52 +52,54 @@ type backwardExtremitiesStatements struct { deleteBackwardExtremityStmt *sql.Stmt } -func (s *backwardExtremitiesStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(backwardExtremitiesSchema) +func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { + s := &backwardExtremitiesStatements{} + _, err := db.Exec(backwardExtremitiesSchema) if err != nil { - return + return nil, err } if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil { - return + return nil, err } if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil { - return + return nil, err } if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { - return + return nil, err } - return + return s, nil } -func (s *backwardExtremitiesStatements) insertsBackwardExtremity( +func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string, ) (err error) { _, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) return } -func (s *backwardExtremitiesStatements) selectBackwardExtremitiesForRoom( +func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( ctx context.Context, roomID string, -) (eventIDs []string, err error) { +) (bwExtrems map[string][]string, err error) { rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID) if err != nil { return } - defer common.CloseAndLogIfError(ctx, rows, "selectBackwardExtremitiesForRoom: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectBackwardExtremitiesForRoom: rows.close() failed") + bwExtrems = make(map[string][]string) for rows.Next() { var eID string - if err = rows.Scan(&eID); err != nil { + var prevEventID string + if err = rows.Scan(&eID, &prevEventID); err != nil { return } - - eventIDs = append(eventIDs, eID) + bwExtrems[eID] = append(bwExtrems[eID], prevEventID) } - return eventIDs, rows.Err() + return bwExtrems, rows.Err() } -func (s *backwardExtremitiesStatements) deleteBackwardExtremity( +func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( ctx context.Context, txn *sql.Tx, roomID, knownEventID string, ) (err error) { _, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 9fafdbede..85f212ad8 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -21,7 +21,9 @@ import ( "encoding/json" "strings" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -91,42 +93,44 @@ type currentRoomStateStatements struct { selectStateEventStmt *sql.Stmt } -func (s *currentRoomStateStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) { - s.streamIDStatements = streamID - _, err = db.Exec(currentRoomStateSchema) +func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) { + s := ¤tRoomStateStatements{ + streamIDStatements: streamID, + } + _, err := db.Exec(currentRoomStateSchema) if err != nil { - return + return nil, err } if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil { - return + return nil, err } if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { - return + return nil, err } if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { - return + return nil, err } if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil { - return + return nil, err } if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { - return + return nil, err } if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { - return + return nil, err } - return + return s, nil } // JoinedMemberLists returns a map of room ID to a list of joined user IDs. -func (s *currentRoomStateStatements) selectJoinedUsers( +func (s *currentRoomStateStatements) SelectJoinedUsers( ctx context.Context, ) (map[string][]string, error) { rows, err := s.selectJoinedUsersStmt.QueryContext(ctx) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed") result := make(map[string][]string) for rows.Next() { @@ -143,18 +147,18 @@ func (s *currentRoomStateStatements) selectJoinedUsers( } // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. -func (s *currentRoomStateStatements) selectRoomIDsWithMembership( +func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( ctx context.Context, txn *sql.Tx, userID string, membership string, // nolint: unparam ) ([]string, error) { - stmt := common.TxStmt(txn, s.selectRoomIDsWithMembershipStmt) + stmt := sqlutil.TxStmt(txn, s.selectRoomIDsWithMembershipStmt) rows, err := stmt.QueryContext(ctx, userID, membership) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "selectRoomIDsWithMembership: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsWithMembership: rows.close() failed") var result []string for rows.Next() { @@ -168,11 +172,11 @@ func (s *currentRoomStateStatements) selectRoomIDsWithMembership( } // CurrentState returns all the current state events for the given room. -func (s *currentRoomStateStatements) selectCurrentState( +func (s *currentRoomStateStatements) SelectCurrentState( ctx context.Context, txn *sql.Tx, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, ) ([]gomatrixserverlib.HeaderedEvent, error) { - stmt := common.TxStmt(txn, s.selectCurrentStateStmt) + stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt) rows, err := stmt.QueryContext(ctx, roomID, nil, // FIXME: pq.StringArray(stateFilterPart.Senders), nil, // FIXME: pq.StringArray(stateFilterPart.NotSenders), @@ -184,20 +188,20 @@ func (s *currentRoomStateStatements) selectCurrentState( if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "selectCurrentState: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectCurrentState: rows.close() failed") return rowsToEvents(rows) } -func (s *currentRoomStateStatements) deleteRoomStateByEventID( +func (s *currentRoomStateStatements) DeleteRoomStateByEventID( ctx context.Context, txn *sql.Tx, eventID string, ) error { - stmt := common.TxStmt(txn, s.deleteRoomStateByEventIDStmt) + stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) _, err := stmt.ExecContext(ctx, eventID) return err } -func (s *currentRoomStateStatements) upsertRoomState( +func (s *currentRoomStateStatements) UpsertRoomState( ctx context.Context, txn *sql.Tx, event gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition, ) error { @@ -215,7 +219,7 @@ func (s *currentRoomStateStatements) upsertRoomState( } // upsert state event - stmt := common.TxStmt(txn, s.upsertRoomStateStmt) + stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) _, err = stmt.ExecContext( ctx, event.RoomID(), @@ -231,19 +235,19 @@ func (s *currentRoomStateStatements) upsertRoomState( return err } -func (s *currentRoomStateStatements) selectEventsWithEventIDs( +func (s *currentRoomStateStatements) SelectEventsWithEventIDs( ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StreamEvent, error) { iEventIDs := make([]interface{}, len(eventIDs)) for k, v := range eventIDs { iEventIDs[k] = v } - query := strings.Replace(selectEventsWithEventIDsSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1) + query := strings.Replace(selectEventsWithEventIDsSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) rows, err := txn.QueryContext(ctx, query, iEventIDs...) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "selectEventsWithEventIDs: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectEventsWithEventIDs: rows.close() failed") return rowsToStreamEvents(rows) } @@ -264,7 +268,7 @@ func rowsToEvents(rows *sql.Rows) ([]gomatrixserverlib.HeaderedEvent, error) { return result, nil } -func (s *currentRoomStateStatements) selectStateEvent( +func (s *currentRoomStateStatements) SelectStateEvent( ctx context.Context, roomID, evType, stateKey string, ) (*gomatrixserverlib.HeaderedEvent, error) { stmt := s.selectStateEventStmt diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index 22efeaeb0..bb58e3456 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -20,7 +20,9 @@ import ( "database/sql" "encoding/json" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -62,30 +64,37 @@ type inviteEventsStatements struct { selectMaxInviteIDStmt *sql.Stmt } -func (s *inviteEventsStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) { - s.streamIDStatements = streamID - _, err = db.Exec(inviteEventsSchema) +func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) { + s := &inviteEventsStatements{ + streamIDStatements: streamID, + } + _, err := db.Exec(inviteEventsSchema) + if err != nil { + return nil, err + } + if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil { + return nil, err + } + if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil { + return nil, err + } + if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil { + return nil, err + } + if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *inviteEventsStatements) InsertInviteEvent( + ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent, +) (streamPos types.StreamPosition, err error) { + streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) if err != nil { return } - if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil { - return - } - if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil { - return - } - if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil { - return - } - if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil { - return - } - return -} -func (s *inviteEventsStatements) insertInviteEvent( - ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent, streamPos types.StreamPosition, -) (err error) { var headeredJSON []byte headeredJSON, err = json.Marshal(inviteEvent) if err != nil { @@ -103,7 +112,7 @@ func (s *inviteEventsStatements) insertInviteEvent( return } -func (s *inviteEventsStatements) deleteInviteEvent( +func (s *inviteEventsStatements) DeleteInviteEvent( ctx context.Context, inviteEventID string, ) error { _, err := s.deleteInviteEventStmt.ExecContext(ctx, inviteEventID) @@ -112,15 +121,15 @@ func (s *inviteEventsStatements) deleteInviteEvent( // selectInviteEventsInRange returns a map of room ID to invite event for the // active invites for the target user ID in the supplied range. -func (s *inviteEventsStatements) selectInviteEventsInRange( - ctx context.Context, txn *sql.Tx, targetUserID string, startPos, endPos types.StreamPosition, +func (s *inviteEventsStatements) SelectInviteEventsInRange( + ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range, ) (map[string]gomatrixserverlib.HeaderedEvent, error) { - stmt := common.TxStmt(txn, s.selectInviteEventsInRangeStmt) - rows, err := stmt.QueryContext(ctx, targetUserID, startPos, endPos) + stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt) + rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High()) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed") result := map[string]gomatrixserverlib.HeaderedEvent{} for rows.Next() { var ( @@ -141,11 +150,11 @@ func (s *inviteEventsStatements) selectInviteEventsInRange( return result, nil } -func (s *inviteEventsStatements) selectMaxInviteID( +func (s *inviteEventsStatements) SelectMaxInviteID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 - stmt := common.TxStmt(txn, s.selectMaxInviteIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectMaxInviteIDStmt) err = stmt.QueryRowContext(ctx).Scan(&nullableID) if nullableID.Valid { id = nullableID.Int64 diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 08299f64b..367ab3c9a 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -21,10 +21,12 @@ import ( "encoding/json" "sort" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -109,47 +111,49 @@ type outputRoomEventsStatements struct { selectStateInRangeStmt *sql.Stmt } -func (s *outputRoomEventsStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) { - s.streamIDStatements = streamID - _, err = db.Exec(outputRoomEventsSchema) +func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) { + s := &outputRoomEventsStatements{ + streamIDStatements: streamID, + } + _, err := db.Exec(outputRoomEventsSchema) if err != nil { - return + return nil, err } if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil { - return + return nil, err } if s.selectEventsStmt, err = db.Prepare(selectEventsSQL); err != nil { - return + return nil, err } if s.selectMaxEventIDStmt, err = db.Prepare(selectMaxEventIDSQL); err != nil { - return + return nil, err } if s.selectRecentEventsStmt, err = db.Prepare(selectRecentEventsSQL); err != nil { - return + return nil, err } if s.selectRecentEventsForSyncStmt, err = db.Prepare(selectRecentEventsForSyncSQL); err != nil { - return + return nil, err } if s.selectEarlyEventsStmt, err = db.Prepare(selectEarlyEventsSQL); err != nil { - return + return nil, err } if s.selectStateInRangeStmt, err = db.Prepare(selectStateInRangeSQL); err != nil { - return + return nil, err } - return + return s, nil } // selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos. // Results are bucketed based on the room ID. If the same state is overwritten multiple times between the // two positions, only the most recent state is returned. -func (s *outputRoomEventsStatements) selectStateInRange( - ctx context.Context, txn *sql.Tx, oldPos, newPos types.StreamPosition, +func (s *outputRoomEventsStatements) SelectStateInRange( + ctx context.Context, txn *sql.Tx, r types.Range, stateFilterPart *gomatrixserverlib.StateFilter, ) (map[string]map[string]bool, map[string]types.StreamEvent, error) { - stmt := common.TxStmt(txn, s.selectStateInRangeStmt) + stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt) rows, err := stmt.QueryContext( - ctx, oldPos, newPos, + ctx, r.Low(), r.High(), /*pq.StringArray(stateFilterPart.Senders), pq.StringArray(stateFilterPart.NotSenders), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)), @@ -192,8 +196,8 @@ func (s *outputRoomEventsStatements) selectStateInRange( // since it'll just mark the event as not being needed. if len(addIDs) < len(delIDs) { log.WithFields(log.Fields{ - "since": oldPos, - "current": newPos, + "since": r.From, + "current": r.To, "adds": addIDsJSON, "dels": delIDsJSON, }).Warn("StateBetween: ignoring deleted state") @@ -229,11 +233,11 @@ func (s *outputRoomEventsStatements) selectStateInRange( // MaxID returns the ID of the last inserted event in this table. 'txn' is optional. If it is not supplied, // then this function should only ever be used at startup, as it will race with inserting events if it is // done afterwards. If there are no inserted events, 0 is returned. -func (s *outputRoomEventsStatements) selectMaxEventID( +func (s *outputRoomEventsStatements) SelectMaxEventID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 - stmt := common.TxStmt(txn, s.selectMaxEventIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt) err = stmt.QueryRowContext(ctx).Scan(&nullableID) if nullableID.Valid { id = nullableID.Int64 @@ -243,7 +247,7 @@ func (s *outputRoomEventsStatements) selectMaxEventID( // InsertEvent into the output_room_events table. addState and removeState are an optional list of state event IDs. Returns the position // of the inserted event. -func (s *outputRoomEventsStatements) insertEvent( +func (s *outputRoomEventsStatements) InsertEvent( ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, transactionID *api.TransactionID, excludeFromSync bool, @@ -283,7 +287,7 @@ func (s *outputRoomEventsStatements) insertEvent( return } - insertStmt := common.TxStmt(txn, s.insertEventStmt) + insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) _, err = insertStmt.ExecContext( ctx, streamPos, @@ -303,26 +307,23 @@ func (s *outputRoomEventsStatements) insertEvent( return } -// selectRecentEvents returns the most recent events in the given room, up to a maximum of 'limit'. -// If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude -// from sync. -func (s *outputRoomEventsStatements) selectRecentEvents( +func (s *outputRoomEventsStatements) SelectRecentEvents( ctx context.Context, txn *sql.Tx, - roomID string, fromPos, toPos types.StreamPosition, limit int, + roomID string, r types.Range, limit int, chronologicalOrder bool, onlySyncEvents bool, ) ([]types.StreamEvent, error) { var stmt *sql.Stmt if onlySyncEvents { - stmt = common.TxStmt(txn, s.selectRecentEventsForSyncStmt) + stmt = sqlutil.TxStmt(txn, s.selectRecentEventsForSyncStmt) } else { - stmt = common.TxStmt(txn, s.selectRecentEventsStmt) + stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt) } - rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) + rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed") events, err := rowsToStreamEvents(rows) if err != nil { return nil, err @@ -338,18 +339,16 @@ func (s *outputRoomEventsStatements) selectRecentEvents( return events, nil } -// selectEarlyEvents returns the earliest events in the given room, starting -// from a given position, up to a maximum of 'limit'. -func (s *outputRoomEventsStatements) selectEarlyEvents( +func (s *outputRoomEventsStatements) SelectEarlyEvents( ctx context.Context, txn *sql.Tx, - roomID string, fromPos, toPos types.StreamPosition, limit int, + roomID string, r types.Range, limit int, ) ([]types.StreamEvent, error) { - stmt := common.TxStmt(txn, s.selectEarlyEventsStmt) - rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) + stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt) + rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit) if err != nil { return nil, err } - defer common.CloseAndLogIfError(ctx, rows, "selectEarlyEvents: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectEarlyEvents: rows.close() failed") events, err := rowsToStreamEvents(rows) if err != nil { return nil, err @@ -365,11 +364,11 @@ func (s *outputRoomEventsStatements) selectEarlyEvents( // selectEvents returns the events for the given event IDs. If an event is // missing from the database, it will be omitted. -func (s *outputRoomEventsStatements) selectEvents( +func (s *outputRoomEventsStatements) SelectEvents( ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StreamEvent, error) { var returnEvents []types.StreamEvent - stmt := common.TxStmt(txn, s.selectEventsStmt) + stmt := sqlutil.TxStmt(txn, s.selectEventsStmt) for _, eventID := range eventIDs { rows, err := stmt.QueryContext(ctx, eventID) if err != nil { @@ -378,7 +377,7 @@ func (s *outputRoomEventsStatements) selectEvents( if streamEvents, err := rowsToStreamEvents(rows); err == nil { returnEvents = append(returnEvents, streamEvents...) } - common.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed") + internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed") } return returnEvents, nil } diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index a2944c2f9..811dfa4f3 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -18,7 +18,8 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -27,41 +28,42 @@ const outputRoomEventsTopologySchema = ` -- Stores output room events received from the roomserver. CREATE TABLE IF NOT EXISTS syncapi_output_room_events_topology ( event_id TEXT PRIMARY KEY, - topological_position BIGINT NOT NULL, + topological_position BIGINT NOT NULL, + stream_position BIGINT NOT NULL, room_id TEXT NOT NULL, - UNIQUE(topological_position, room_id) + UNIQUE(topological_position, room_id, stream_position) ); -- The topological order will be used in events selection and ordering --- CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, room_id); +-- CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, stream_position, room_id); ` const insertEventInTopologySQL = "" + - "INSERT INTO syncapi_output_room_events_topology (event_id, topological_position, room_id)" + - " VALUES ($1, $2, $3)" + - " ON CONFLICT (topological_position, room_id) DO UPDATE SET event_id = $1" + "INSERT INTO syncapi_output_room_events_topology (event_id, topological_position, room_id, stream_position)" + + " VALUES ($1, $2, $3, $4)" + + " ON CONFLICT DO NOTHING" const selectEventIDsInRangeASCSQL = "" + "SELECT event_id FROM syncapi_output_room_events_topology" + - " WHERE room_id = $1 AND topological_position > $2 AND topological_position <= $3" + - " ORDER BY topological_position ASC LIMIT $4" + " WHERE room_id = $1 AND (" + + "(topological_position > $2 AND topological_position < $3) OR" + + "(topological_position = $4 AND stream_position <= $5)" + + ") ORDER BY topological_position ASC, stream_position ASC LIMIT $6" const selectEventIDsInRangeDESCSQL = "" + - "SELECT event_id FROM syncapi_output_room_events_topology" + - " WHERE room_id = $1 AND topological_position > $2 AND topological_position <= $3" + - " ORDER BY topological_position DESC LIMIT $4" + "SELECT event_id FROM syncapi_output_room_events_topology" + + " WHERE room_id = $1 AND (" + + "(topological_position > $2 AND topological_position < $3) OR" + + "(topological_position = $4 AND stream_position <= $5)" + + ") ORDER BY topological_position DESC, stream_position DESC LIMIT $6" const selectPositionInTopologySQL = "" + - "SELECT topological_position FROM syncapi_output_room_events_topology" + + "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + " WHERE event_id = $1" const selectMaxPositionInTopologySQL = "" + - "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology" + - " WHERE room_id = $1" - -const selectEventIDsFromPositionSQL = "" + - "SELECT event_id FROM syncapi_output_room_events_topology" + - " WHERE room_id = $1 AND topological_position = $2" + "SELECT MAX(topological_position), stream_position FROM syncapi_output_room_events_topology" + + " WHERE room_id = $1 ORDER BY stream_position DESC" type outputRoomEventsTopologyStatements struct { insertEventInTopologyStmt *sql.Stmt @@ -69,66 +71,60 @@ type outputRoomEventsTopologyStatements struct { selectEventIDsInRangeDESCStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt selectMaxPositionInTopologyStmt *sql.Stmt - selectEventIDsFromPositionStmt *sql.Stmt } -func (s *outputRoomEventsTopologyStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(outputRoomEventsTopologySchema) +func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { + s := &outputRoomEventsTopologyStatements{} + _, err := db.Exec(outputRoomEventsTopologySchema) if err != nil { - return + return nil, err } if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil { - return + return nil, err } if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil { - return + return nil, err } if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil { - return + return nil, err } if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { - return + return nil, err } if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { - return + return nil, err } - if s.selectEventIDsFromPositionStmt, err = db.Prepare(selectEventIDsFromPositionSQL); err != nil { - return - } - return + return s, nil } // insertEventInTopology inserts the given event in the room's topology, based // on the event's depth. -func (s *outputRoomEventsTopologyStatements) insertEventInTopology( - ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, +func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( + ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition, ) (err error) { - stmt := common.TxStmt(txn, s.insertEventInTopologyStmt) + stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt) _, err = stmt.ExecContext( - ctx, event.EventID(), event.Depth(), event.RoomID(), + ctx, event.EventID(), event.Depth(), event.RoomID(), pos, ) return } -// selectEventIDsInRange selects the IDs of events which positions are within a -// given range in a given room's topological order. -// Returns an empty slice if no events match the given range. -func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange( +func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( ctx context.Context, txn *sql.Tx, roomID string, - fromPos, toPos types.StreamPosition, + minDepth, maxDepth, maxStreamPos types.StreamPosition, limit int, chronologicalOrder bool, ) (eventIDs []string, err error) { // Decide on the selection's order according to whether chronological order // is requested or not. var stmt *sql.Stmt if chronologicalOrder { - stmt = common.TxStmt(txn, s.selectEventIDsInRangeASCStmt) + stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeASCStmt) } else { - stmt = common.TxStmt(txn, s.selectEventIDsInRangeDESCStmt) + stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeDESCStmt) } // Query the event IDs. - rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) + rows, err := stmt.QueryContext(ctx, roomID, minDepth, maxDepth, maxDepth, maxStreamPos, limit) if err == sql.ErrNoRows { // If no event matched the request, return an empty slice. return []string{}, nil @@ -150,43 +146,18 @@ func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange( // selectPositionInTopology returns the position of a given event in the // topology of the room it belongs to. -func (s *outputRoomEventsTopologyStatements) selectPositionInTopology( +func (s *outputRoomEventsTopologyStatements) SelectPositionInTopology( ctx context.Context, txn *sql.Tx, eventID string, -) (pos types.StreamPosition, err error) { - stmt := common.TxStmt(txn, s.selectPositionInTopologyStmt) - err = stmt.QueryRowContext(ctx, eventID).Scan(&pos) +) (pos types.StreamPosition, spos types.StreamPosition, err error) { + stmt := sqlutil.TxStmt(txn, s.selectPositionInTopologyStmt) + err = stmt.QueryRowContext(ctx, eventID).Scan(&pos, &spos) return } -func (s *outputRoomEventsTopologyStatements) selectMaxPositionInTopology( +func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( ctx context.Context, txn *sql.Tx, roomID string, -) (pos types.StreamPosition, err error) { - stmt := common.TxStmt(txn, s.selectMaxPositionInTopologyStmt) - err = stmt.QueryRowContext(ctx, roomID).Scan(&pos) - return -} - -// selectEventIDsFromPosition returns the IDs of all events that have a given -// position in the topology of a given room. -func (s *outputRoomEventsTopologyStatements) selectEventIDsFromPosition( - ctx context.Context, txn *sql.Tx, roomID string, pos types.StreamPosition, -) (eventIDs []string, err error) { - // Query the event IDs. - stmt := common.TxStmt(txn, s.selectEventIDsFromPositionStmt) - rows, err := stmt.QueryContext(ctx, roomID, pos) - if err == sql.ErrNoRows { - // If no event matched the request, return an empty slice. - return []string{}, nil - } else if err != nil { - return - } - // Return the IDs. - var eventID string - for rows.Next() { - if err = rows.Scan(&eventID); err != nil { - return - } - eventIDs = append(eventIDs, eventID) - } +) (pos types.StreamPosition, spos types.StreamPosition, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt) + err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) return } diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go new file mode 100644 index 000000000..42bd3c19a --- /dev/null +++ b/syncapi/storage/sqlite3/send_to_device_table.go @@ -0,0 +1,173 @@ +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" +) + +const sendToDeviceSchema = ` +-- Stores send-to-device messages. +CREATE TABLE IF NOT EXISTS syncapi_send_to_device ( + -- The ID that uniquely identifies this message. + id INTEGER PRIMARY KEY AUTOINCREMENT, + -- The user ID to send the message to. + user_id TEXT NOT NULL, + -- The device ID to send the message to. + device_id TEXT NOT NULL, + -- The event content JSON. + content TEXT NOT NULL, + -- The token that was supplied to the /sync at the time that this + -- message was included in a sync response, or NULL if we haven't + -- included it in a /sync response yet. + sent_by_token TEXT +); +` + +const insertSendToDeviceMessageSQL = ` + INSERT INTO syncapi_send_to_device (user_id, device_id, content) + VALUES ($1, $2, $3) +` + +const countSendToDeviceMessagesSQL = ` + SELECT COUNT(*) + FROM syncapi_send_to_device + WHERE user_id = $1 AND device_id = $2 +` + +const selectSendToDeviceMessagesSQL = ` + SELECT id, user_id, device_id, content, sent_by_token + FROM syncapi_send_to_device + WHERE user_id = $1 AND device_id = $2 + ORDER BY id DESC +` + +const updateSentSendToDeviceMessagesSQL = ` + UPDATE syncapi_send_to_device SET sent_by_token = $1 + WHERE id IN ($2) +` + +const deleteSendToDeviceMessagesSQL = ` + DELETE FROM syncapi_send_to_device WHERE id IN ($1) +` + +type sendToDeviceStatements struct { + insertSendToDeviceMessageStmt *sql.Stmt + selectSendToDeviceMessagesStmt *sql.Stmt + countSendToDeviceMessagesStmt *sql.Stmt +} + +func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { + s := &sendToDeviceStatements{} + _, err := db.Exec(sendToDeviceSchema) + if err != nil { + return nil, err + } + if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil { + return nil, err + } + if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { + return nil, err + } + if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *sendToDeviceStatements) InsertSendToDeviceMessage( + ctx context.Context, txn *sql.Tx, userID, deviceID, content string, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) + return +} + +func (s *sendToDeviceStatements) CountSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, userID, deviceID string, +) (count int, err error) { + row := sqlutil.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID) + if err = row.Scan(&count); err != nil { + return + } + return count, nil +} + +func (s *sendToDeviceStatements) SelectSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, userID, deviceID string, +) (events []types.SendToDeviceEvent, err error) { + rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed") + + for rows.Next() { + var id types.SendToDeviceNID + var userID, deviceID, content string + var sentByToken *string + if err = rows.Scan(&id, &userID, &deviceID, &content, &sentByToken); err != nil { + return + } + event := types.SendToDeviceEvent{ + ID: id, + UserID: userID, + DeviceID: deviceID, + } + if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil { + return + } + if sentByToken != nil { + if token, err := types.NewStreamTokenFromString(*sentByToken); err == nil { + event.SentByToken = &token + } + } + events = append(events, event) + } + + return events, rows.Err() +} + +func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID, +) (err error) { + query := strings.Replace(updateSentSendToDeviceMessagesSQL, "($2)", sqlutil.QueryVariadic(1+len(nids)), 1) + params := make([]interface{}, 1+len(nids)) + params[0] = token + for k, v := range nids { + params[k+1] = v + } + _, err = txn.ExecContext(ctx, query, params...) + return +} + +func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID, +) (err error) { + query := strings.Replace(deleteSendToDeviceMessagesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) + params := make([]interface{}, 1+len(nids)) + for k, v := range nids { + params[k] = v + } + _, err = txn.ExecContext(ctx, query, params...) + return +} diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go index 260f7a95d..57abd9c44 100644 --- a/syncapi/storage/sqlite3/stream_id_table.go +++ b/syncapi/storage/sqlite3/stream_id_table.go @@ -4,7 +4,7 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -46,8 +46,8 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) { } func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { - increaseStmt := common.TxStmt(txn, s.increaseStreamIDStmt) - selectStmt := common.TxStmt(txn, s.selectStreamIDStmt) + increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) + selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil { return } diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 29051cd06..51cdbe325 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -16,76 +16,39 @@ package sqlite3 import ( - "context" "database/sql" - "encoding/json" - "errors" - "fmt" - "net/url" - "time" - - "github.com/sirupsen/logrus" - - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/api" // Import the sqlite3 package _ "github.com/mattn/go-sqlite3" - "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/eduserver/cache" - "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/shared" ) -type stateDelta struct { - roomID string - stateEvents []gomatrixserverlib.HeaderedEvent - membership string - // The PDU stream position of the latest membership event for this user, if applicable. - // Can be 0 if there is no membership event in this delta. - membershipPos types.StreamPosition -} - // SyncServerDatasource represents a sync server datasource which manages // both the database for PDUs and caches for EDUs. type SyncServerDatasource struct { + shared.Database db *sql.DB - common.PartitionOffsetStatements - streamID streamIDStatements - accountData accountDataStatements - events outputRoomEventsStatements - roomstate currentRoomStateStatements - invites inviteEventsStatements - eduCache *cache.EDUCache - topology outputRoomEventsTopologyStatements - backwardExtremities backwardExtremitiesStatements + sqlutil.PartitionOffsetStatements + streamID streamIDStatements } -// NewSyncServerDatasource creates a new sync server database +// NewDatabase creates a new sync server database // nolint: gocyclo -func NewSyncServerDatasource(dataSourceName string) (*SyncServerDatasource, error) { +func NewDatabase(dataSourceName string) (*SyncServerDatasource, error) { var d SyncServerDatasource - uri, err := url.Parse(dataSourceName) + cs, err := sqlutil.ParseFileURI(dataSourceName) if err != nil { return nil, err } - var cs string - if uri.Opaque != "" { // file:filename.db - cs = uri.Opaque - } else if uri.Path != "" { // file:///path/to/filename.db - cs = uri.Path - } else { - return nil, errors.New("no filename or path in connect string") - } - if d.db, err = sqlutil.Open(common.SQLiteDriverName(), cs); err != nil { + if d.db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil); err != nil { return nil, err } if err = d.prepare(); err != nil { return nil, err } - d.eduCache = cache.New() return &d, nil } @@ -96,1089 +59,45 @@ func (d *SyncServerDatasource) prepare() (err error) { if err = d.streamID.prepare(d.db); err != nil { return err } - if err = d.accountData.prepare(d.db, &d.streamID); err != nil { + accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID) + if err != nil { return err } - if err = d.events.prepare(d.db, &d.streamID); err != nil { + events, err := NewSqliteEventsTable(d.db, &d.streamID) + if err != nil { return err } - if err := d.roomstate.prepare(d.db, &d.streamID); err != nil { + roomState, err := NewSqliteCurrentRoomStateTable(d.db, &d.streamID) + if err != nil { return err } - if err := d.invites.prepare(d.db, &d.streamID); err != nil { + invites, err := NewSqliteInvitesTable(d.db, &d.streamID) + if err != nil { return err } - if err := d.topology.prepare(d.db); err != nil { + topology, err := NewSqliteTopologyTable(d.db) + if err != nil { return err } - if err := d.backwardExtremities.prepare(d.db); err != nil { + bwExtrem, err := NewSqliteBackwardsExtremitiesTable(d.db) + if err != nil { return err } + sendToDevice, err := NewSqliteSendToDeviceTable(d.db) + if err != nil { + return err + } + d.Database = shared.Database{ + DB: d.db, + Invites: invites, + AccountData: accountData, + OutputEvents: events, + BackwardExtremities: bwExtrem, + CurrentRoomState: roomState, + Topology: topology, + SendToDevice: sendToDevice, + SendToDeviceWriter: sqlutil.NewTransactionWriter(), + EDUCache: cache.New(), + } return nil } - -// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. -func (d *SyncServerDatasource) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { - return d.roomstate.selectJoinedUsers(ctx) -} - -// Events lookups a list of event by their event ID. -// Returns a list of events matching the requested IDs found in the database. -// If an event is not found in the database then it will be omitted from the list. -// Returns an error if there was a problem talking with the database. -// Does not include any transaction IDs in the returned events. -func (d *SyncServerDatasource) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.HeaderedEvent, error) { - streamEvents, err := d.events.selectEvents(ctx, nil, eventIDs) - if err != nil { - return nil, err - } - - // We don't include a device here as we only include transaction IDs in - // incremental syncs. - return d.StreamEventsToEvents(nil, streamEvents), nil -} - -// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of -// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table -// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. -func (d *SyncServerDatasource) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error { - if err := d.backwardExtremities.deleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil { - return err - } - - // Check if we have all of the event's previous events. If an event is - // missing, add it to the room's backward extremities. - prevEvents, err := d.events.selectEvents(ctx, txn, ev.PrevEventIDs()) - if err != nil { - return err - } - var found bool - for _, eID := range ev.PrevEventIDs() { - found = false - for _, prevEv := range prevEvents { - if eID == prevEv.EventID() { - found = true - } - } - - // If the event is missing, consider it a backward extremity. - if !found { - if err = d.backwardExtremities.insertsBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID(), eID); err != nil { - return err - } - } - } - - return nil -} - -// WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races -// when generating the sync stream position for this event. Returns the sync stream position for the inserted event. -// Returns an error if there was a problem inserting this event. -func (d *SyncServerDatasource) WriteEvent( - ctx context.Context, - ev *gomatrixserverlib.HeaderedEvent, - addStateEvents []gomatrixserverlib.HeaderedEvent, - addStateEventIDs, removeStateEventIDs []string, - transactionID *api.TransactionID, excludeFromSync bool, -) (pduPosition types.StreamPosition, returnErr error) { - returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { - var err error - pos, err := d.events.insertEvent( - ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, - ) - if err != nil { - return err - } - pduPosition = pos - - if err = d.topology.insertEventInTopology(ctx, txn, ev); err != nil { - return err - } - - if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil { - return err - } - - if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { - // Nothing to do, the event may have just been a message event. - return nil - } - - return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition) - }) - - return pduPosition, returnErr -} - -func (d *SyncServerDatasource) updateRoomState( - ctx context.Context, txn *sql.Tx, - removedEventIDs []string, - addedEvents []gomatrixserverlib.HeaderedEvent, - pduPosition types.StreamPosition, -) error { - // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. - for _, eventID := range removedEventIDs { - if err := d.roomstate.deleteRoomStateByEventID(ctx, txn, eventID); err != nil { - return err - } - } - - for _, event := range addedEvents { - if event.StateKey() == nil { - // ignore non state events - continue - } - var membership *string - if event.Type() == "m.room.member" { - value, err := event.Membership() - if err != nil { - return err - } - membership = &value - } - if err := d.roomstate.upsertRoomState(ctx, txn, event, membership, pduPosition); err != nil { - return err - } - } - - return nil -} - -// GetStateEvent returns the Matrix state event of a given type for a given room with a given state key -// If no event could be found, returns nil -// If there was an issue during the retrieval, returns an error -func (d *SyncServerDatasource) GetStateEvent( - ctx context.Context, roomID, evType, stateKey string, -) (*gomatrixserverlib.HeaderedEvent, error) { - return d.roomstate.selectStateEvent(ctx, roomID, evType, stateKey) -} - -// GetStateEventsForRoom fetches the state events for a given room. -// Returns an empty slice if no state events could be found for this room. -// Returns an error if there was an issue with the retrieval. -func (d *SyncServerDatasource) GetStateEventsForRoom( - ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, -) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart) - return err - }) - return -} - -// GetEventsInRange retrieves all of the events on a given ordering using the -// given extremities and limit. -func (d *SyncServerDatasource) GetEventsInRange( - ctx context.Context, - from, to *types.PaginationToken, - roomID string, limit int, - backwardOrdering bool, -) (events []types.StreamEvent, err error) { - // If the pagination token's type is types.PaginationTokenTypeTopology, the - // events must be retrieved from the rooms' topology table rather than the - // table contaning the syncapi server's whole stream of events. - if from.Type == types.PaginationTokenTypeTopology { - // Determine the backward and forward limit, i.e. the upper and lower - // limits to the selection in the room's topology, from the direction. - var backwardLimit, forwardLimit types.StreamPosition - if backwardOrdering { - // Backward ordering is antichronological (latest event to oldest - // one). - backwardLimit = to.PDUPosition - forwardLimit = from.PDUPosition - } else { - // Forward ordering is chronological (oldest event to latest one). - backwardLimit = from.PDUPosition - forwardLimit = to.PDUPosition - } - - // Select the event IDs from the defined range. - var eIDs []string - eIDs, err = d.topology.selectEventIDsInRange( - ctx, nil, roomID, backwardLimit, forwardLimit, limit, !backwardOrdering, - ) - if err != nil { - return - } - - // Retrieve the events' contents using their IDs. - events, err = d.events.selectEvents(ctx, nil, eIDs) - return - } - - // If the pagination token's type is types.PaginationTokenTypeStream, the - // events must be retrieved from the table contaning the syncapi server's - // whole stream of events. - - if backwardOrdering { - // When using backward ordering, we want the most recent events first. - if events, err = d.events.selectRecentEvents( - ctx, nil, roomID, to.PDUPosition, from.PDUPosition, limit, false, false, - ); err != nil { - return - } - } else { - // When using forward ordering, we want the least recent events first. - if events, err = d.events.selectEarlyEvents( - ctx, nil, roomID, from.PDUPosition, to.PDUPosition, limit, - ); err != nil { - return - } - } - - return -} - -// SyncPosition returns the latest positions for syncing. -func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (tok types.PaginationToken, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - tok, err = d.syncPositionTx(ctx, txn) - return err - }) - return -} - -// BackwardExtremitiesForRoom returns the event IDs of all of the backward -// extremities we know of for a given room. -func (d *SyncServerDatasource) BackwardExtremitiesForRoom( - ctx context.Context, roomID string, -) (backwardExtremities []string, err error) { - return d.backwardExtremities.selectBackwardExtremitiesForRoom(ctx, roomID) -} - -// MaxTopologicalPosition returns the highest topological position for a given -// room. -func (d *SyncServerDatasource) MaxTopologicalPosition( - ctx context.Context, roomID string, -) (types.StreamPosition, error) { - return d.topology.selectMaxPositionInTopology(ctx, nil, roomID) -} - -// EventsAtTopologicalPosition returns all of the events matching a given -// position in the topology of a given room. -func (d *SyncServerDatasource) EventsAtTopologicalPosition( - ctx context.Context, roomID string, pos types.StreamPosition, -) ([]types.StreamEvent, error) { - eIDs, err := d.topology.selectEventIDsFromPosition(ctx, nil, roomID, pos) - if err != nil { - return nil, err - } - - return d.events.selectEvents(ctx, nil, eIDs) -} - -func (d *SyncServerDatasource) EventPositionInTopology( - ctx context.Context, eventID string, -) (types.StreamPosition, error) { - return d.topology.selectPositionInTopology(ctx, nil, eventID) -} - -// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. -func (d *SyncServerDatasource) SyncStreamPosition(ctx context.Context) (pos types.StreamPosition, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - pos, err = d.syncStreamPositionTx(ctx, txn) - return err - }) - return -} - -func (d *SyncServerDatasource) syncStreamPositionTx( - ctx context.Context, txn *sql.Tx, -) (types.StreamPosition, error) { - maxID, err := d.events.selectMaxEventID(ctx, txn) - if err != nil { - return 0, err - } - maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) - if err != nil { - return 0, err - } - if maxAccountDataID > maxID { - maxID = maxAccountDataID - } - maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn) - if err != nil { - return 0, err - } - if maxInviteID > maxID { - maxID = maxInviteID - } - return types.StreamPosition(maxID), nil -} - -func (d *SyncServerDatasource) syncPositionTx( - ctx context.Context, txn *sql.Tx, -) (sp types.PaginationToken, err error) { - - maxEventID, err := d.events.selectMaxEventID(ctx, txn) - if err != nil { - return sp, err - } - maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) - if err != nil { - return sp, err - } - if maxAccountDataID > maxEventID { - maxEventID = maxAccountDataID - } - maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn) - if err != nil { - return sp, err - } - if maxInviteID > maxEventID { - maxEventID = maxInviteID - } - sp.PDUPosition = types.StreamPosition(maxEventID) - sp.EDUTypingPosition = types.StreamPosition(d.eduCache.GetLatestSyncPosition()) - return -} - -// addPDUDeltaToResponse adds all PDU deltas to a sync response. -// IDs of all rooms the user joined are returned so EDU deltas can be added for them. -func (d *SyncServerDatasource) addPDUDeltaToResponse( - ctx context.Context, - device authtypes.Device, - fromPos, toPos types.StreamPosition, - numRecentEventsPerRoom int, - wantFullState bool, - res *types.Response, -) (joinedRoomIDs []string, err error) { - txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot) - if err != nil { - return nil, err - } - var succeeded bool - defer func() { - txerr := common.EndTransaction(txn, &succeeded) - if err == nil && txerr != nil { - err = txerr - } - }() - - stateFilterPart := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request - - // Work out which rooms to return in the response. This is done by getting not only the currently - // joined rooms, but also which rooms have membership transitions for this user between the 2 PDU stream positions. - // This works out what the 'state' key should be for each room as well as which membership block - // to put the room into. - var deltas []stateDelta - if !wantFullState { - deltas, joinedRoomIDs, err = d.getStateDeltas( - ctx, &device, txn, fromPos, toPos, device.UserID, &stateFilterPart, - ) - } else { - deltas, joinedRoomIDs, err = d.getStateDeltasForFullStateSync( - ctx, &device, txn, fromPos, toPos, device.UserID, &stateFilterPart, - ) - } - if err != nil { - return nil, err - } - - for _, delta := range deltas { - err = d.addRoomDeltaToResponse(ctx, &device, txn, fromPos, toPos, delta, numRecentEventsPerRoom, res) - if err != nil { - return nil, err - } - } - - // TODO: This should be done in getStateDeltas - if err = d.addInvitesToResponse(ctx, txn, device.UserID, fromPos, toPos, res); err != nil { - return nil, err - } - - succeeded = true - return joinedRoomIDs, nil -} - -// addTypingDeltaToResponse adds all typing notifications to a sync response -// since the specified position. -func (d *SyncServerDatasource) addTypingDeltaToResponse( - since types.PaginationToken, - joinedRoomIDs []string, - res *types.Response, -) error { - var jr types.JoinResponse - var ok bool - var err error - for _, roomID := range joinedRoomIDs { - if typingUsers, updated := d.eduCache.GetTypingUsersIfUpdatedAfter( - roomID, int64(since.EDUTypingPosition), - ); updated { - ev := gomatrixserverlib.ClientEvent{ - Type: gomatrixserverlib.MTyping, - } - ev.Content, err = json.Marshal(map[string]interface{}{ - "user_ids": typingUsers, - }) - if err != nil { - return err - } - - if jr, ok = res.Rooms.Join[roomID]; !ok { - jr = *types.NewJoinResponse() - } - jr.Ephemeral.Events = append(jr.Ephemeral.Events, ev) - res.Rooms.Join[roomID] = jr - } - } - return nil -} - -// addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if -// the positions of that type are not equal in fromPos and toPos. -func (d *SyncServerDatasource) addEDUDeltaToResponse( - fromPos, toPos types.PaginationToken, - joinedRoomIDs []string, - res *types.Response, -) (err error) { - - if fromPos.EDUTypingPosition != toPos.EDUTypingPosition { - err = d.addTypingDeltaToResponse( - fromPos, joinedRoomIDs, res, - ) - } - - return -} - -// IncrementalSync returns all the data needed in order to create an incremental -// sync response for the given user. Events returned will include any client -// transaction IDs associated with the given device. These transaction IDs come -// from when the device sent the event via an API that included a transaction -// ID. -func (d *SyncServerDatasource) IncrementalSync( - ctx context.Context, - device authtypes.Device, - fromPos, toPos types.PaginationToken, - numRecentEventsPerRoom int, - wantFullState bool, -) (*types.Response, error) { - nextBatchPos := fromPos.WithUpdates(toPos) - res := types.NewResponse(nextBatchPos) - - var joinedRoomIDs []string - var err error - if fromPos.PDUPosition != toPos.PDUPosition || wantFullState { - joinedRoomIDs, err = d.addPDUDeltaToResponse( - ctx, device, fromPos.PDUPosition, toPos.PDUPosition, numRecentEventsPerRoom, wantFullState, res, - ) - } else { - joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership( - ctx, nil, device.UserID, gomatrixserverlib.Join, - ) - } - if err != nil { - return nil, err - } - - err = d.addEDUDeltaToResponse( - fromPos, toPos, joinedRoomIDs, res, - ) - if err != nil { - return nil, err - } - - return res, nil -} - -// getResponseWithPDUsForCompleteSync creates a response and adds all PDUs needed -// to it. It returns toPos and joinedRoomIDs for use of adding EDUs. -func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( - ctx context.Context, - userID string, - numRecentEventsPerRoom int, -) ( - res *types.Response, - toPos types.PaginationToken, - joinedRoomIDs []string, - err error, -) { - // This needs to be all done in a transaction as we need to do multiple SELECTs, and we need to have - // a consistent view of the database throughout. This includes extracting the sync position. - // This does have the unfortunate side-effect that all the matrixy logic resides in this function, - // but it's better to not hide the fact that this is being done in a transaction. - txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot) - if err != nil { - return - } - var succeeded bool - defer func() { - txerr := common.EndTransaction(txn, &succeeded) - if err == nil && txerr != nil { - err = txerr - } - }() - - // Get the current sync position which we will base the sync response on. - toPos, err = d.syncPositionTx(ctx, txn) - if err != nil { - return - } - - res = types.NewResponse(toPos) - - // Extract room state and recent events for all rooms the user is joined to. - joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) - if err != nil { - return - } - - stateFilterPart := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request - - // Build up a /sync response. Add joined rooms. - for _, roomID := range joinedRoomIDs { - var stateEvents []gomatrixserverlib.HeaderedEvent - stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, &stateFilterPart) - if err != nil { - return - } - //fmt.Println("State events:", stateEvents) - // TODO: When filters are added, we may need to call this multiple times to get enough events. - // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 - var recentStreamEvents []types.StreamEvent - recentStreamEvents, err = d.events.selectRecentEvents( - ctx, txn, roomID, types.StreamPosition(0), toPos.PDUPosition, - numRecentEventsPerRoom, true, true, - ) - if err != nil { - return - } - //fmt.Println("Recent stream events:", recentStreamEvents) - - // Retrieve the backward topology position, i.e. the position of the - // oldest event in the room's topology. - var backwardTopologyPos types.StreamPosition - backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID()) - if backwardTopologyPos-1 <= 0 { - backwardTopologyPos = types.StreamPosition(1) - } else { - backwardTopologyPos-- - } - - // We don't include a device here as we don't need to send down - // transaction IDs for complete syncs - recentEvents := d.StreamEventsToEvents(nil, recentStreamEvents) - stateEvents = removeDuplicates(stateEvents, recentEvents) - jr := types.NewJoinResponse() - jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, 0, - ).String() - jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) - jr.Timeline.Limited = true - jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync) - res.Rooms.Join[roomID] = *jr - } - - if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition, res); err != nil { - return - } - - succeeded = true - return res, toPos, joinedRoomIDs, err -} - -// CompleteSync returns a complete /sync API response for the given user. -func (d *SyncServerDatasource) CompleteSync( - ctx context.Context, userID string, numRecentEventsPerRoom int, -) (*types.Response, error) { - res, toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync( - ctx, userID, numRecentEventsPerRoom, - ) - if err != nil { - return nil, err - } - - // Use a zero value SyncPosition for fromPos so all EDU states are added. - err = d.addEDUDeltaToResponse( - types.PaginationToken{}, toPos, joinedRoomIDs, res, - ) - if err != nil { - return nil, err - } - - return res, nil -} - -var txReadOnlySnapshot = sql.TxOptions{ - // Set the isolation level so that we see a snapshot of the database. - // In PostgreSQL repeatable read transactions will see a snapshot taken - // at the first query, and since the transaction is read-only it can't - // run into any serialisation errors. - // https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ - Isolation: sql.LevelRepeatableRead, - ReadOnly: true, -} - -// GetAccountDataInRange returns all account data for a given user inserted or -// updated between two given positions -// Returns a map following the format data[roomID] = []dataTypes -// If no data is retrieved, returns an empty map -// If there was an issue with the retrieval, returns an error -func (d *SyncServerDatasource) GetAccountDataInRange( - ctx context.Context, userID string, oldPos, newPos types.StreamPosition, - accountDataFilterPart *gomatrixserverlib.EventFilter, -) (map[string][]string, error) { - return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart) -} - -// UpsertAccountData keeps track of new or updated account data, by saving the type -// of the new/updated data, and the user ID and room ID the data is related to (empty) -// room ID means the data isn't specific to any room) -// If no data with the given type, user ID and room ID exists in the database, -// creates a new row, else update the existing one -// Returns an error if there was an issue with the upsert -func (d *SyncServerDatasource) UpsertAccountData( - ctx context.Context, userID, roomID, dataType string, -) (sp types.StreamPosition, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - sp, err = d.accountData.insertAccountData(ctx, txn, userID, roomID, dataType) - return err - }) - return -} - -// AddInviteEvent stores a new invite event for a user. -// If the invite was successfully stored this returns the stream ID it was stored at. -// Returns an error if there was a problem communicating with the database. -func (d *SyncServerDatasource) AddInviteEvent( - ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, -) (streamPos types.StreamPosition, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - streamPos, err = d.streamID.nextStreamID(ctx, txn) - if err != nil { - return err - } - return d.invites.insertInviteEvent(ctx, txn, inviteEvent, streamPos) - }) - return -} - -// RetireInviteEvent removes an old invite event from the database. -// Returns an error if there was a problem communicating with the database. -func (d *SyncServerDatasource) RetireInviteEvent( - ctx context.Context, inviteEventID string, -) error { - // TODO: Record that invite has been retired in a stream so that we can - // notify the user in an incremental sync. - err := d.invites.deleteInviteEvent(ctx, inviteEventID) - return err -} - -func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) { - d.eduCache.SetTimeoutCallback(fn) -} - -// AddTypingUser adds a typing user to the typing cache. -// Returns the newly calculated sync position for typing notifications. -func (d *SyncServerDatasource) AddTypingUser( - userID, roomID string, expireTime *time.Time, -) types.StreamPosition { - return types.StreamPosition(d.eduCache.AddTypingUser(userID, roomID, expireTime)) -} - -// RemoveTypingUser removes a typing user from the typing cache. -// Returns the newly calculated sync position for typing notifications. -func (d *SyncServerDatasource) RemoveTypingUser( - userID, roomID string, -) types.StreamPosition { - return types.StreamPosition(d.eduCache.RemoveUser(userID, roomID)) -} - -func (d *SyncServerDatasource) addInvitesToResponse( - ctx context.Context, txn *sql.Tx, - userID string, - fromPos, toPos types.StreamPosition, - res *types.Response, -) error { - invites, err := d.invites.selectInviteEventsInRange( - ctx, txn, userID, fromPos, toPos, - ) - if err != nil { - return err - } - for roomID, inviteEvent := range invites { - ir := types.NewInviteResponse() - ir.InviteState.Events = gomatrixserverlib.HeaderedToClientEvents( - []gomatrixserverlib.HeaderedEvent{inviteEvent}, gomatrixserverlib.FormatSync, - ) - // TODO: add the invite state from the invite event. - res.Rooms.Invite[roomID] = *ir - } - return nil -} - -// Retrieve the backward topology position, i.e. the position of the -// oldest event in the room's topology. -func (d *SyncServerDatasource) getBackwardTopologyPos( - ctx context.Context, txn *sql.Tx, - events []types.StreamEvent, -) (pos types.StreamPosition) { - if len(events) > 0 { - pos, _ = d.topology.selectPositionInTopology(ctx, txn, events[0].EventID()) - } - if pos-1 <= 0 { - pos = types.StreamPosition(1) - } else { - pos = pos - 1 - } - return -} - -// addRoomDeltaToResponse adds a room state delta to a sync response -func (d *SyncServerDatasource) addRoomDeltaToResponse( - ctx context.Context, - device *authtypes.Device, - txn *sql.Tx, - fromPos, toPos types.StreamPosition, - delta stateDelta, - numRecentEventsPerRoom int, - res *types.Response, -) error { - endPos := toPos - if delta.membershipPos > 0 && delta.membership == gomatrixserverlib.Leave { - // make sure we don't leak recent events after the leave event. - // TODO: History visibility makes this somewhat complex to handle correctly. For example: - // TODO: This doesn't work for join -> leave in a single /sync request (see events prior to join). - // TODO: This will fail on join -> leave -> sensitive msg -> join -> leave - // in a single /sync request - // This is all "okay" assuming history_visibility == "shared" which it is by default. - endPos = delta.membershipPos - } - recentStreamEvents, err := d.events.selectRecentEvents( - ctx, txn, delta.roomID, types.StreamPosition(fromPos), types.StreamPosition(endPos), - numRecentEventsPerRoom, true, true, - ) - if err != nil { - return err - } - recentEvents := d.StreamEventsToEvents(device, recentStreamEvents) - delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) - backwardTopologyPos := d.getBackwardTopologyPos(ctx, txn, recentStreamEvents) - - switch delta.membership { - case gomatrixserverlib.Join: - jr := types.NewJoinResponse() - - jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, 0, - ).String() - jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) - jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true - jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) - res.Rooms.Join[delta.roomID] = *jr - case gomatrixserverlib.Leave: - fallthrough // transitions to leave are the same as ban - case gomatrixserverlib.Ban: - // TODO: recentEvents may contain events that this user is not allowed to see because they are - // no longer in the room. - lr := types.NewLeaveResponse() - lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, 0, - ).String() - lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) - lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true - lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) - res.Rooms.Leave[delta.roomID] = *lr - } - - return nil -} - -// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database. -// Returns a map of room ID to list of events. -func (d *SyncServerDatasource) fetchStateEvents( - ctx context.Context, txn *sql.Tx, - roomIDToEventIDSet map[string]map[string]bool, - eventIDToEvent map[string]types.StreamEvent, -) (map[string][]types.StreamEvent, error) { - stateBetween := make(map[string][]types.StreamEvent) - missingEvents := make(map[string][]string) - for roomID, ids := range roomIDToEventIDSet { - events := stateBetween[roomID] - for id, need := range ids { - if !need { - continue // deleted state - } - e, ok := eventIDToEvent[id] - if ok { - events = append(events, e) - } else { - m := missingEvents[roomID] - m = append(m, id) - missingEvents[roomID] = m - } - } - stateBetween[roomID] = events - } - - if len(missingEvents) > 0 { - // This happens when add_state_ids has an event ID which is not in the provided range. - // We need to explicitly fetch them. - allMissingEventIDs := []string{} - for _, missingEvIDs := range missingEvents { - allMissingEventIDs = append(allMissingEventIDs, missingEvIDs...) - } - evs, err := d.fetchMissingStateEvents(ctx, txn, allMissingEventIDs) - if err != nil { - return nil, err - } - // we know we got them all otherwise an error would've been returned, so just loop the events - for _, ev := range evs { - roomID := ev.RoomID() - stateBetween[roomID] = append(stateBetween[roomID], ev) - } - } - return stateBetween, nil -} - -func (d *SyncServerDatasource) fetchMissingStateEvents( - ctx context.Context, txn *sql.Tx, eventIDs []string, -) ([]types.StreamEvent, error) { - // Fetch from the events table first so we pick up the stream ID for the - // event. - events, err := d.events.selectEvents(ctx, txn, eventIDs) - if err != nil { - return nil, err - } - - have := map[string]bool{} - for _, event := range events { - have[event.EventID()] = true - } - var missing []string - for _, eventID := range eventIDs { - if !have[eventID] { - missing = append(missing, eventID) - } - } - if len(missing) == 0 { - return events, nil - } - - // If they are missing from the events table then they should be state - // events that we received from outside the main event stream. - // These should be in the room state table. - stateEvents, err := d.roomstate.selectEventsWithEventIDs(ctx, txn, missing) - - if err != nil { - return nil, err - } - if len(stateEvents) != len(missing) { - return nil, fmt.Errorf("failed to map all event IDs to events: (got %d, wanted %d)", len(stateEvents), len(missing)) - } - events = append(events, stateEvents...) - return events, nil -} - -// getStateDeltas returns the state deltas between fromPos and toPos, -// exclusive of oldPos, inclusive of newPos, for the rooms in which -// the user has new membership events. -// A list of joined room IDs is also returned in case the caller needs it. -func (d *SyncServerDatasource) getStateDeltas( - ctx context.Context, device *authtypes.Device, txn *sql.Tx, - fromPos, toPos types.StreamPosition, userID string, - stateFilterPart *gomatrixserverlib.StateFilter, -) ([]stateDelta, []string, error) { - // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 - // - Get membership list changes for this user in this sync response - // - For each room which has membership list changes: - // * Check if the room is 'newly joined' (insufficient to just check for a join event because we allow dupe joins TODO). - // If it is, then we need to send the full room state down (and 'limited' is always true). - // * Check if user is still CURRENTLY invited to the room. If so, add room to 'invited' block. - // * Check if the user is CURRENTLY (TODO) left/banned. If so, add room to 'archived' block. - // - Get all CURRENTLY joined rooms, and add them to 'joined' block. - var deltas []stateDelta - - // get all the state events ever between these two positions - stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos, stateFilterPart) - if err != nil { - return nil, nil, err - } - state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) - if err != nil { - return nil, nil, err - } - - for roomID, stateStreamEvents := range state { - for _, ev := range stateStreamEvents { - // TODO: Currently this will incorrectly add rooms which were ALREADY joined but they sent another no-op join event. - // We should be checking if the user was already joined at fromPos and not proceed if so. As a result of this, - // dupe join events will result in the entire room state coming down to the client again. This is added in - // the 'state' part of the response though, so is transparent modulo bandwidth concerns as it is not added to - // the timeline. - if membership := getMembershipFromEvent(&ev.HeaderedEvent, userID); membership != "" { - if membership == gomatrixserverlib.Join { - // send full room state down instead of a delta - var s []types.StreamEvent - s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilterPart) - if err != nil { - return nil, nil, err - } - state[roomID] = s - continue // we'll add this room in when we do joined rooms - } - - deltas = append(deltas, stateDelta{ - membership: membership, - membershipPos: ev.StreamPosition, - stateEvents: d.StreamEventsToEvents(device, stateStreamEvents), - roomID: roomID, - }) - break - } - } - } - - // Add in currently joined rooms - joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) - if err != nil { - return nil, nil, err - } - for _, joinedRoomID := range joinedRoomIDs { - deltas = append(deltas, stateDelta{ - membership: gomatrixserverlib.Join, - stateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]), - roomID: joinedRoomID, - }) - } - - return deltas, joinedRoomIDs, nil -} - -// getStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync -// requests with full_state=true. -// Fetches full state for all joined rooms and uses selectStateInRange to get -// updates for other rooms. -func (d *SyncServerDatasource) getStateDeltasForFullStateSync( - ctx context.Context, device *authtypes.Device, txn *sql.Tx, - fromPos, toPos types.StreamPosition, userID string, - stateFilterPart *gomatrixserverlib.StateFilter, -) ([]stateDelta, []string, error) { - joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) - if err != nil { - return nil, nil, err - } - - // Use a reasonable initial capacity - deltas := make([]stateDelta, 0, len(joinedRoomIDs)) - - // Add full states for all joined rooms - for _, joinedRoomID := range joinedRoomIDs { - s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilterPart) - if stateErr != nil { - return nil, nil, stateErr - } - deltas = append(deltas, stateDelta{ - membership: gomatrixserverlib.Join, - stateEvents: d.StreamEventsToEvents(device, s), - roomID: joinedRoomID, - }) - } - - // Get all the state events ever between these two positions - stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos, stateFilterPart) - if err != nil { - return nil, nil, err - } - state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) - if err != nil { - return nil, nil, err - } - - for roomID, stateStreamEvents := range state { - for _, ev := range stateStreamEvents { - if membership := getMembershipFromEvent(&ev.HeaderedEvent, userID); membership != "" { - if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above. - deltas = append(deltas, stateDelta{ - membership: membership, - membershipPos: ev.StreamPosition, - stateEvents: d.StreamEventsToEvents(device, stateStreamEvents), - roomID: roomID, - }) - } - - break - } - } - } - - return deltas, joinedRoomIDs, nil -} - -func (d *SyncServerDatasource) currentStateStreamEventsForRoom( - ctx context.Context, txn *sql.Tx, roomID string, - stateFilterPart *gomatrixserverlib.StateFilter, -) ([]types.StreamEvent, error) { - allState, err := d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart) - if err != nil { - return nil, err - } - s := make([]types.StreamEvent, len(allState)) - for i := 0; i < len(s); i++ { - s[i] = types.StreamEvent{HeaderedEvent: allState[i], StreamPosition: 0} - } - return s, nil -} - -// StreamEventsToEvents converts streamEvent to Event. If device is non-nil and -// matches the streamevent.transactionID device then the transaction ID gets -// added to the unsigned section of the output event. -func (d *SyncServerDatasource) StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.HeaderedEvent { - out := make([]gomatrixserverlib.HeaderedEvent, len(in)) - for i := 0; i < len(in); i++ { - out[i] = in[i].HeaderedEvent - if device != nil && in[i].TransactionID != nil { - if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID { - err := out[i].SetUnsignedField( - "transaction_id", in[i].TransactionID.TransactionID, - ) - if err != nil { - logrus.WithFields(logrus.Fields{ - "event_id": out[i].EventID(), - }).WithError(err).Warnf("Failed to add transaction ID to event") - } - } - } - } - return out -} - -// There may be some overlap where events in stateEvents are already in recentEvents, so filter -// them out so we don't include them twice in the /sync response. They should be in recentEvents -// only, so clients get to the correct state once they have rolled forward. -func removeDuplicates(stateEvents, recentEvents []gomatrixserverlib.HeaderedEvent) []gomatrixserverlib.HeaderedEvent { - for _, recentEv := range recentEvents { - if recentEv.StateKey() == nil { - continue // not a state event - } - // TODO: This is a linear scan over all the current state events in this room. This will - // be slow for big rooms. We should instead sort the state events by event ID (ORDER BY) - // then do a binary search to find matching events, similar to what roomserver does. - for j := 0; j < len(stateEvents); j++ { - if stateEvents[j].EventID() == recentEv.EventID() { - // overwrite the element to remove with the last element then pop the last element. - // This is orders of magnitude faster than re-slicing, but doesn't preserve ordering - // (we don't care about the order of stateEvents) - stateEvents[j] = stateEvents[len(stateEvents)-1] - stateEvents = stateEvents[:len(stateEvents)-1] - break // there shouldn't be multiple events with the same event ID - } - } - } - return stateEvents -} - -// getMembershipFromEvent returns the value of content.membership iff the event is a state event -// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned. -func getMembershipFromEvent(ev *gomatrixserverlib.HeaderedEvent, userID string) string { - if ev.Type() == "m.room.member" && ev.StateKeyEquals(userID) { - membership, err := ev.Membership() - if err != nil { - return "" - } - return membership - } - return "" -} diff --git a/syncapi/storage/storage.go b/syncapi/storage/storage.go index c56db0635..ea69da3bc 100644 --- a/syncapi/storage/storage.go +++ b/syncapi/storage/storage.go @@ -19,22 +19,23 @@ package storage import ( "net/url" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/postgres" "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" ) -// NewPublicRoomsServerDatabase opens a database connection. -func NewSyncServerDatasource(dataSourceName string) (Database, error) { +// NewSyncServerDatasource opens a database connection. +func NewSyncServerDatasource(dataSourceName string, dbProperties sqlutil.DbProperties) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { - return postgres.NewSyncServerDatasource(dataSourceName) + return postgres.NewDatabase(dataSourceName, dbProperties) } switch uri.Scheme { case "postgres": - return postgres.NewSyncServerDatasource(dataSourceName) + return postgres.NewDatabase(dataSourceName, dbProperties) case "file": - return sqlite3.NewSyncServerDatasource(dataSourceName) + return sqlite3.NewDatabase(dataSourceName) default: - return postgres.NewSyncServerDatasource(dataSourceName) + return postgres.NewDatabase(dataSourceName, dbProperties) } } diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go new file mode 100644 index 000000000..85084facb --- /dev/null +++ b/syncapi/storage/storage_test.go @@ -0,0 +1,659 @@ +package storage_test + +import ( + "context" + "crypto/ed25519" + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" + "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" +) + +var ( + ctx = context.Background() + emptyStateKey = "" + testOrigin = gomatrixserverlib.ServerName("hollow.knight") + testRoomID = fmt.Sprintf("!hallownest:%s", testOrigin) + testUserIDA = fmt.Sprintf("@hornet:%s", testOrigin) + testUserIDB = fmt.Sprintf("@paleking:%s", testOrigin) + testUserDeviceA = userapi.Device{ + UserID: testUserIDA, + ID: "device_id_A", + DisplayName: "Device A", + } + testRoomVersion = gomatrixserverlib.RoomVersionV4 + testKeyID = gomatrixserverlib.KeyID("ed25519:storage_test") + testPrivateKey = ed25519.NewKeyFromSeed([]byte{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + }) +) + +func MustCreateEvent(t *testing.T, roomID string, prevs []gomatrixserverlib.HeaderedEvent, b *gomatrixserverlib.EventBuilder) gomatrixserverlib.HeaderedEvent { + b.RoomID = roomID + if prevs != nil { + prevIDs := make([]string, len(prevs)) + for i := range prevs { + prevIDs[i] = prevs[i].EventID() + } + b.PrevEvents = prevIDs + } + e, err := b.Build(time.Now(), testOrigin, testKeyID, testPrivateKey, testRoomVersion) + if err != nil { + t.Fatalf("failed to build event: %s", err) + } + return e.Headered(testRoomVersion) +} + +func MustCreateDatabase(t *testing.T) storage.Database { + db, err := sqlite3.NewDatabase("file::memory:") + if err != nil { + t.Fatalf("NewSyncServerDatasource returned %s", err) + } + return db +} + +// Create a list of events which include a create event, join event and some messages. +func SimpleRoom(t *testing.T, roomID, userA, userB string) (msgs []gomatrixserverlib.HeaderedEvent, state []gomatrixserverlib.HeaderedEvent) { + var events []gomatrixserverlib.HeaderedEvent + events = append(events, MustCreateEvent(t, roomID, nil, &gomatrixserverlib.EventBuilder{ + Content: []byte(fmt.Sprintf(`{"room_version":"4","creator":"%s"}`, userA)), + Type: "m.room.create", + StateKey: &emptyStateKey, + Sender: userA, + Depth: int64(len(events) + 1), + })) + state = append(state, events[len(events)-1]) + events = append(events, MustCreateEvent(t, roomID, []gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ + Content: []byte(fmt.Sprintf(`{"membership":"join"}`)), + Type: "m.room.member", + StateKey: &userA, + Sender: userA, + Depth: int64(len(events) + 1), + })) + state = append(state, events[len(events)-1]) + for i := 0; i < 10; i++ { + events = append(events, MustCreateEvent(t, roomID, []gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ + Content: []byte(fmt.Sprintf(`{"body":"Message A %d"}`, i+1)), + Type: "m.room.message", + Sender: userA, + Depth: int64(len(events) + 1), + })) + } + events = append(events, MustCreateEvent(t, roomID, []gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ + Content: []byte(fmt.Sprintf(`{"membership":"join"}`)), + Type: "m.room.member", + StateKey: &userB, + Sender: userB, + Depth: int64(len(events) + 1), + })) + state = append(state, events[len(events)-1]) + for i := 0; i < 10; i++ { + events = append(events, MustCreateEvent(t, roomID, []gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ + Content: []byte(fmt.Sprintf(`{"body":"Message B %d"}`, i+1)), + Type: "m.room.message", + Sender: userB, + Depth: int64(len(events) + 1), + })) + } + + return events, state +} + +func MustWriteEvents(t *testing.T, db storage.Database, events []gomatrixserverlib.HeaderedEvent) (positions []types.StreamPosition) { + for _, ev := range events { + var addStateEvents []gomatrixserverlib.HeaderedEvent + var addStateEventIDs []string + var removeStateEventIDs []string + if ev.StateKey() != nil { + addStateEvents = append(addStateEvents, ev) + addStateEventIDs = append(addStateEventIDs, ev.EventID()) + } + pos, err := db.WriteEvent(ctx, &ev, addStateEvents, addStateEventIDs, removeStateEventIDs, nil, false) + if err != nil { + t.Fatalf("WriteEvent failed: %s", err) + } + fmt.Println("Event ID", ev.EventID(), " spos=", pos, "depth=", ev.Depth()) + positions = append(positions, pos) + } + return +} + +func TestWriteEvents(t *testing.T) { + t.Parallel() + db := MustCreateDatabase(t) + events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) + MustWriteEvents(t, db, events) +} + +// These tests assert basic functionality of the IncrementalSync and CompleteSync functions. +func TestSyncResponse(t *testing.T) { + t.Parallel() + db := MustCreateDatabase(t) + events, state := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) + positions := MustWriteEvents(t, db, events) + latest, err := db.SyncPosition(ctx) + if err != nil { + t.Fatalf("failed to get SyncPosition: %s", err) + } + + testCases := []struct { + Name string + DoSync func() (*types.Response, error) + WantTimeline []gomatrixserverlib.HeaderedEvent + WantState []gomatrixserverlib.HeaderedEvent + }{ + // The purpose of this test is to make sure that incremental syncs are including up to the latest events. + // It's a basic sanity test that sync works. It creates a `since` token that is on the penultimate event. + // It makes sure the response includes the final event. + { + Name: "IncrementalSync penultimate", + DoSync: func() (*types.Response, error) { + from := types.NewStreamToken( // pretend we are at the penultimate event + positions[len(positions)-2], types.StreamPosition(0), + ) + res := types.NewResponse() + return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) + }, + WantTimeline: events[len(events)-1:], + }, + // The purpose of this test is to check that passing a `numRecentEventsPerRoom` correctly limits the + // number of returned events. This is critical for big rooms hence the test here. + { + Name: "IncrementalSync limited", + DoSync: func() (*types.Response, error) { + from := types.NewStreamToken( // pretend we are 10 events behind + positions[len(positions)-11], types.StreamPosition(0), + ) + res := types.NewResponse() + // limit is set to 5 + return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) + }, + // want the last 5 events, NOT the last 10. + WantTimeline: events[len(events)-5:], + }, + // The purpose of this test is to check that CompleteSync returns all the current state as well as + // honouring the `numRecentEventsPerRoom` value + { + Name: "CompleteSync limited", + DoSync: func() (*types.Response, error) { + res := types.NewResponse() + // limit set to 5 + return db.CompleteSync(ctx, res, testUserDeviceA, 5) + }, + // want the last 5 events + WantTimeline: events[len(events)-5:], + // want all state for the room + WantState: state, + }, + // The purpose of this test is to check that CompleteSync can return everything with a high enough + // `numRecentEventsPerRoom`. + { + Name: "CompleteSync", + DoSync: func() (*types.Response, error) { + res := types.NewResponse() + return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1) + }, + WantTimeline: events, + // We want no state at all as that field in /sync is the delta between the token (beginning of time) + // and the START of the timeline. + }, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(st *testing.T) { + res, err := tc.DoSync() + if err != nil { + st.Fatalf("failed to do sync: %s", err) + } + next := types.NewStreamToken(latest.PDUPosition(), latest.EDUPosition()) + if res.NextBatch != next.String() { + st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String()) + } + roomRes, ok := res.Rooms.Join[testRoomID] + if !ok { + st.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res) + } + assertEventsEqual(st, "state for "+testRoomID, false, roomRes.State.Events, tc.WantState) + assertEventsEqual(st, "timeline for "+testRoomID, false, roomRes.Timeline.Events, tc.WantTimeline) + }) + } +} + +func TestGetEventsInRangeWithPrevBatch(t *testing.T) { + t.Parallel() + db := MustCreateDatabase(t) + events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) + positions := MustWriteEvents(t, db, events) + latest, err := db.SyncPosition(ctx) + if err != nil { + t.Fatalf("failed to get SyncPosition: %s", err) + } + from := types.NewStreamToken( + positions[len(positions)-2], types.StreamPosition(0), + ) + + res := types.NewResponse() + res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) + if err != nil { + t.Fatalf("failed to IncrementalSync with latest token") + } + roomRes, ok := res.Rooms.Join[testRoomID] + if !ok { + t.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res) + } + // returns the last event "Message 10" + assertEventsEqual(t, "IncrementalSync Timeline", false, roomRes.Timeline.Events, reversed(events[len(events)-1:])) + + prev := roomRes.Timeline.PrevBatch + if prev == "" { + t.Fatalf("IncrementalSync expected prev_batch token") + } + prevBatchToken, err := types.NewTopologyTokenFromString(prev) + if err != nil { + t.Fatalf("failed to NewTopologyTokenFromString : %s", err) + } + // backpaginate 5 messages starting at the latest position. + // head towards the beginning of time + to := types.NewTopologyToken(0, 0) + paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &prevBatchToken, &to, testRoomID, 5, true) + if err != nil { + t.Fatalf("GetEventsInRange returned an error: %s", err) + } + gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll) + assertEventsEqual(t, "", true, gots, reversed(events[len(events)-6:len(events)-1])) +} + +// The purpose of this test is to ensure that backfill does indeed go backwards, using a stream token. +func TestGetEventsInRangeWithStreamToken(t *testing.T) { + t.Parallel() + db := MustCreateDatabase(t) + events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) + MustWriteEvents(t, db, events) + latest, err := db.SyncPosition(ctx) + if err != nil { + t.Fatalf("failed to get SyncPosition: %s", err) + } + // head towards the beginning of time + to := types.NewStreamToken(0, 0) + + // backpaginate 5 messages starting at the latest position. + paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true) + if err != nil { + t.Fatalf("GetEventsInRange returned an error: %s", err) + } + gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll) + assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:])) +} + +// The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token +func TestGetEventsInRangeWithTopologyToken(t *testing.T) { + t.Parallel() + db := MustCreateDatabase(t) + events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) + MustWriteEvents(t, db, events) + from, err := db.MaxTopologicalPosition(ctx, testRoomID) + if err != nil { + t.Fatalf("failed to get MaxTopologicalPosition: %s", err) + } + // head towards the beginning of time + to := types.NewTopologyToken(0, 0) + + // backpaginate 5 messages starting at the latest position. + paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, testRoomID, 5, true) + if err != nil { + t.Fatalf("GetEventsInRange returned an error: %s", err) + } + gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll) + assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:])) +} + +// The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth. +// For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent +// will appear FIRST when going backwards. This test creates a DAG like: +// .-----> Message ---. +// Create -> Membership --------> Message -------> Message +// `-----> Message ---` +// depth 1 2 3 4 +// +// With a total depth of 4. It tests that: +// - Backpagination over the whole fork should include all messages and not leave any out. +// - Backpagination from the middle of the fork should not return duplicates (things later than the token). +func TestGetEventsInRangeWithEventsSameDepth(t *testing.T) { + t.Parallel() + db := MustCreateDatabase(t) + + var events []gomatrixserverlib.HeaderedEvent + events = append(events, MustCreateEvent(t, testRoomID, nil, &gomatrixserverlib.EventBuilder{ + Content: []byte(fmt.Sprintf(`{"room_version":"4","creator":"%s"}`, testUserIDA)), + Type: "m.room.create", + StateKey: &emptyStateKey, + Sender: testUserIDA, + Depth: int64(len(events) + 1), + })) + events = append(events, MustCreateEvent(t, testRoomID, []gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ + Content: []byte(fmt.Sprintf(`{"membership":"join"}`)), + Type: "m.room.member", + StateKey: &testUserIDA, + Sender: testUserIDA, + Depth: int64(len(events) + 1), + })) + // fork the dag into three, same prev_events and depth + parent := []gomatrixserverlib.HeaderedEvent{events[len(events)-1]} + depth := int64(len(events) + 1) + for i := 0; i < 3; i++ { + events = append(events, MustCreateEvent(t, testRoomID, parent, &gomatrixserverlib.EventBuilder{ + Content: []byte(fmt.Sprintf(`{"body":"Message A %d"}`, i+1)), + Type: "m.room.message", + Sender: testUserIDA, + Depth: depth, + })) + } + // merge the fork, prev_events are all 3 messages, depth is increased by 1. + events = append(events, MustCreateEvent(t, testRoomID, events[len(events)-3:], &gomatrixserverlib.EventBuilder{ + Content: []byte(fmt.Sprintf(`{"body":"Message merge"}`)), + Type: "m.room.message", + Sender: testUserIDA, + Depth: depth + 1, + })) + MustWriteEvents(t, db, events) + fromLatest, err := db.EventPositionInTopology(ctx, events[len(events)-1].EventID()) + if err != nil { + t.Fatalf("failed to get EventPositionInTopology: %s", err) + } + fromFork, err := db.EventPositionInTopology(ctx, events[len(events)-3].EventID()) // Message 2 + if err != nil { + t.Fatalf("failed to get EventPositionInTopology for event: %s", err) + } + // head towards the beginning of time + to := types.NewTopologyToken(0, 0) + + testCases := []struct { + Name string + From types.TopologyToken + Limit int + Wants []gomatrixserverlib.HeaderedEvent + }{ + { + Name: "Pagination over the whole fork", + From: fromLatest, + Limit: 5, + Wants: reversed(events[len(events)-5:]), + }, + { + Name: "Paginating to the middle of the fork", + From: fromLatest, + Limit: 2, + Wants: reversed(events[len(events)-2:]), + }, + { + Name: "Pagination FROM the middle of the fork", + From: fromFork, + Limit: 3, + Wants: reversed(events[len(events)-5 : len(events)-2]), + }, + } + + for _, tc := range testCases { + // backpaginate messages starting at the latest position. + paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &tc.From, &to, testRoomID, tc.Limit, true) + if err != nil { + t.Fatalf("%s GetEventsInRange returned an error: %s", tc.Name, err) + } + gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll) + assertEventsEqual(t, tc.Name, true, gots, tc.Wants) + } +} + +// The purpose of this test is to make sure that the query to pull out events is honouring the room ID correctly. +// It works by creating two rooms with the same events in them, then selecting events by topological range. +// Specifically, we know that events with the same depth but lower stream positions are selected, and it's possible +// that this check isn't using the room ID if the brackets are wrong in the SQL query. +func TestGetEventsInTopologicalRangeMultiRoom(t *testing.T) { + t.Parallel() + db := MustCreateDatabase(t) + + makeEvents := func(roomID string) (events []gomatrixserverlib.HeaderedEvent) { + events = append(events, MustCreateEvent(t, roomID, nil, &gomatrixserverlib.EventBuilder{ + Content: []byte(fmt.Sprintf(`{"room_version":"4","creator":"%s"}`, testUserIDA)), + Type: "m.room.create", + StateKey: &emptyStateKey, + Sender: testUserIDA, + Depth: int64(len(events) + 1), + })) + events = append(events, MustCreateEvent(t, roomID, []gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ + Content: []byte(fmt.Sprintf(`{"membership":"join"}`)), + Type: "m.room.member", + StateKey: &testUserIDA, + Sender: testUserIDA, + Depth: int64(len(events) + 1), + })) + return + } + + roomA := "!room_a:" + string(testOrigin) + roomB := "!room_b:" + string(testOrigin) + eventsA := makeEvents(roomA) + eventsB := makeEvents(roomB) + MustWriteEvents(t, db, eventsA) + MustWriteEvents(t, db, eventsB) + from, err := db.MaxTopologicalPosition(ctx, roomB) + if err != nil { + t.Fatalf("failed to get MaxTopologicalPosition: %s", err) + } + // head towards the beginning of time + to := types.NewTopologyToken(0, 0) + + // Query using room B as room A was inserted first and hence A will have lower stream positions but identical depths, + // allowing this bug to surface. + paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, roomB, 5, true) + if err != nil { + t.Fatalf("GetEventsInRange returned an error: %s", err) + } + gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll) + assertEventsEqual(t, "", true, gots, reversed(eventsB)) +} + +// The purpose of this test is to make sure that events are returned in the right *order* when they have been inserted in a manner similar to +// how any kind of backfill operation will insert the events. This test inserts the SimpleRoom events in a manner similar to how backfill over +// federation would: +// - First inserts join event of test user C +// - Inserts chunks of history in strata e.g (25-30, 20-25, 15-20, 10-15, 5-10, 0-5). +// The test then does a backfill to ensure that the response is ordered correctly according to depth. +func TestGetEventsInRangeWithEventsInsertedLikeBackfill(t *testing.T) { + t.Parallel() + db := MustCreateDatabase(t) + events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) + + // "federation" join + userC := fmt.Sprintf("@radiance:%s", testOrigin) + joinEvent := MustCreateEvent(t, testRoomID, []gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ + Content: []byte(fmt.Sprintf(`{"membership":"join"}`)), + Type: "m.room.member", + StateKey: &userC, + Sender: userC, + Depth: int64(len(events) + 1), + }) + MustWriteEvents(t, db, []gomatrixserverlib.HeaderedEvent{joinEvent}) + + // Sync will return this for the prev_batch + from := topologyTokenBefore(t, db, joinEvent.EventID()) + + // inject events in batches as if they were from backfill + // e.g [1,2,3,4,5,6] => [4,5,6] , [1,2,3] + chunkSize := 5 + for i := len(events); i >= 0; i -= chunkSize { + start := i - chunkSize + if start < 0 { + start = 0 + } + backfill := events[start:i] + MustWriteEvents(t, db, backfill) + } + + // head towards the beginning of time + to := types.NewTopologyToken(0, 0) + + // starting at `from`, backpaginate to the beginning of time, asserting as we go. + chunkSize = 3 + events = reversed(events) + for i := 0; i < len(events); i += chunkSize { + paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, from, &to, testRoomID, chunkSize, true) + if err != nil { + t.Fatalf("GetEventsInRange returned an error: %s", err) + } + gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll) + endi := i + chunkSize + if endi > len(events) { + endi = len(events) + } + assertEventsEqual(t, from.String(), true, gots, events[i:endi]) + from = topologyTokenBefore(t, db, paginatedEvents[len(paginatedEvents)-1].EventID()) + } +} + +func TestSendToDeviceBehaviour(t *testing.T) { + //t.Parallel() + db := MustCreateDatabase(t) + + // At this point there should be no messages. We haven't sent anything + // yet. + events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0)) + if err != nil { + t.Fatal(err) + } + if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 { + t.Fatal("first call should have no updates") + } + err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, 0)) + if err != nil { + return + } + + // Try sending a message. + streamPos, err := db.StoreNewSendForDeviceMessage(ctx, types.StreamPosition(0), "alice", "one", gomatrixserverlib.SendToDeviceEvent{ + Sender: "bob", + Type: "m.type", + Content: json.RawMessage("{}"), + }) + if err != nil { + t.Fatal(err) + } + + // At this point we should get exactly one message. We're sending the sync position + // that we were given from the update and the send-to-device update will be updated + // in the database to reflect that this was the sync position we sent the message at. + events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos)) + if err != nil { + t.Fatal(err) + } + if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 { + t.Fatal("second call should have one update") + } + err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos)) + if err != nil { + return + } + + // At this point we should still have one message because we haven't progressed the + // sync position yet. This is equivalent to the client failing to /sync and retrying + // with the same position. + events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos)) + if err != nil { + t.Fatal(err) + } + if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 { + t.Fatal("third call should have one update still") + } + err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos)) + if err != nil { + return + } + + // At this point we should now have no updates, because we've progressed the sync + // position. Therefore the update from before will not be sent again. + events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1)) + if err != nil { + t.Fatal(err) + } + if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 { + t.Fatal("fourth call should have no updates") + } + err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos+1)) + if err != nil { + return + } + + // At this point we should still have no updates, because no new updates have been + // sent. + events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2)) + if err != nil { + t.Fatal(err) + } + if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 { + t.Fatal("fifth call should have no updates") + } +} + +func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatrixserverlib.ClientEvent, wants []gomatrixserverlib.HeaderedEvent) { + if len(gots) != len(wants) { + t.Fatalf("%s response returned %d events, want %d", msg, len(gots), len(wants)) + } + for i := range gots { + g := gots[i] + w := wants[i] + if g.EventID != w.EventID() { + t.Errorf("%s event[%d] event_id mismatch: got %s want %s", msg, i, g.EventID, w.EventID()) + } + if g.Sender != w.Sender() { + t.Errorf("%s event[%d] sender mismatch: got %s want %s", msg, i, g.Sender, w.Sender()) + } + if checkRoomID && g.RoomID != w.RoomID() { + t.Errorf("%s event[%d] room_id mismatch: got %s want %s", msg, i, g.RoomID, w.RoomID()) + } + if g.Type != w.Type() { + t.Errorf("%s event[%d] event type mismatch: got %s want %s", msg, i, g.Type, w.Type()) + } + if g.OriginServerTS != w.OriginServerTS() { + t.Errorf("%s event[%d] origin_server_ts mismatch: got %v want %v", msg, i, g.OriginServerTS, w.OriginServerTS()) + } + if string(g.Content) != string(w.Content()) { + t.Errorf("%s event[%d] content mismatch: got %s want %s", msg, i, string(g.Content), string(w.Content())) + } + if string(g.Unsigned) != string(w.Unsigned()) { + t.Errorf("%s event[%d] unsigned mismatch: got %s want %s", msg, i, string(g.Unsigned), string(w.Unsigned())) + } + if (g.StateKey == nil && w.StateKey() != nil) || (g.StateKey != nil && w.StateKey() == nil) { + t.Errorf("%s event[%d] state_key [not] missing: got %v want %v", msg, i, g.StateKey, w.StateKey()) + continue + } + if g.StateKey != nil { + if !w.StateKeyEquals(*g.StateKey) { + t.Errorf("%s event[%d] state_key mismatch: got %s want %s", msg, i, *g.StateKey, *w.StateKey()) + } + } + } +} + +func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *types.TopologyToken { + tok, err := db.EventPositionInTopology(ctx, eventID) + if err != nil { + t.Fatalf("failed to get EventPositionInTopology: %s", err) + } + tok.Decrement() + return &tok +} + +func reversed(in []gomatrixserverlib.HeaderedEvent) []gomatrixserverlib.HeaderedEvent { + out := make([]gomatrixserverlib.HeaderedEvent, len(in)) + for i := 0; i < len(in); i++ { + out[i] = in[len(in)-i-1] + } + return out +} diff --git a/syncapi/storage/storage_wasm.go b/syncapi/storage/storage_wasm.go index 43806a012..0886b8c21 100644 --- a/syncapi/storage/storage_wasm.go +++ b/syncapi/storage/storage_wasm.go @@ -18,11 +18,15 @@ import ( "fmt" "net/url" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" ) // NewPublicRoomsServerDatabase opens a database connection. -func NewSyncServerDatasource(dataSourceName string) (Database, error) { +func NewSyncServerDatasource( + dataSourceName string, + dbProperties sqlutil.DbProperties, // nolint:unparam +) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { return nil, fmt.Errorf("Cannot use postgres implementation") @@ -31,7 +35,7 @@ func NewSyncServerDatasource(dataSourceName string) (Database, error) { case "postgres": return nil, fmt.Errorf("Cannot use postgres implementation") case "file": - return sqlite3.NewSyncServerDatasource(dataSourceName) + return sqlite3.NewDatabase(dataSourceName) default: return nil, fmt.Errorf("Cannot use postgres implementation") } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go new file mode 100644 index 000000000..0b7d15951 --- /dev/null +++ b/syncapi/storage/tables/interface.go @@ -0,0 +1,135 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +type AccountData interface { + InsertAccountData(ctx context.Context, txn *sql.Tx, userID, roomID, dataType string) (pos types.StreamPosition, err error) + // SelectAccountDataInRange returns a map of room ID to a list of `dataType`. + SelectAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter) (data map[string][]string, err error) + SelectMaxAccountDataID(ctx context.Context, txn *sql.Tx) (id int64, err error) +} + +type Invites interface { + InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent) (streamPos types.StreamPosition, err error) + DeleteInviteEvent(ctx context.Context, inviteEventID string) error + // SelectInviteEventsInRange returns a map of room ID to invite events. + SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (map[string]gomatrixserverlib.HeaderedEvent, error) + SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error) +} + +type Events interface { + SelectStateInRange(ctx context.Context, txn *sql.Tx, r types.Range, stateFilter *gomatrixserverlib.StateFilter) (map[string]map[string]bool, map[string]types.StreamEvent, error) + SelectMaxEventID(ctx context.Context, txn *sql.Tx) (id int64, err error) + InsertEvent(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, transactionID *api.TransactionID, excludeFromSync bool) (streamPos types.StreamPosition, err error) + // SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high. + // If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude from sync. + // Returns up to `limit` events. + SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, limit int, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, error) + // SelectEarlyEvents returns the earliest events in the given room. + SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, limit int) ([]types.StreamEvent, error) + SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error) +} + +// Topology keeps track of the depths and stream positions for all events. +// These positions are used as types.TopologyToken when backfilling events locally. +type Topology interface { + // InsertEventInTopology inserts the given event in the room's topology, based on the event's depth. + // `pos` is the stream position of this event in the events table, and is used to order events which have the same depth. + InsertEventInTopology(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition) (err error) + // SelectEventIDsInRange selects the IDs of events whose depths are within a given range in a given room's topological order. + // Events with `minDepth` are *exclusive*, as is the event which has exactly `minDepth`,`maxStreamPos`. + // `maxStreamPos` is only used when events have the same depth as `maxDepth`, which results in events less than `maxStreamPos` being returned. + // Returns an empty slice if no events match the given range. + SelectEventIDsInRange(ctx context.Context, txn *sql.Tx, roomID string, minDepth, maxDepth, maxStreamPos types.StreamPosition, limit int, chronologicalOrder bool) (eventIDs []string, err error) + // SelectPositionInTopology returns the depth and stream position of a given event in the topology of the room it belongs to. + SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error) + // SelectMaxPositionInTopology returns the event which has the highest depth, and if there are multiple, the event with the highest stream position. + SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error) +} + +type CurrentRoomState interface { + SelectStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) + SelectEventsWithEventIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error) + UpsertRoomState(ctx context.Context, txn *sql.Tx, event gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition) error + DeleteRoomStateByEventID(ctx context.Context, txn *sql.Tx, eventID string) error + // SelectCurrentState returns all the current state events for the given room. + SelectCurrentState(ctx context.Context, txn *sql.Tx, roomID string, stateFilter *gomatrixserverlib.StateFilter) ([]gomatrixserverlib.HeaderedEvent, error) + // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. + SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error) + // SelectJoinedUsers returns a map of room ID to a list of joined user IDs. + SelectJoinedUsers(ctx context.Context) (map[string][]string, error) +} + +// BackwardsExtremities keeps track of backwards extremities for a room. +// Backwards extremities are the earliest (DAG-wise) known events which we have +// the entire event JSON. These event IDs are used in federation requests to fetch +// even earlier events. +// +// We persist the previous event IDs as well, one per row, so when we do fetch even +// earlier events we can simply delete rows which referenced it. Consider the graph: +// A +// | Event C has 1 prev_event ID: A. +// B C +// |___| Event D has 2 prev_event IDs: B and C. +// | +// D +// The earliest known event we have is D, so this table has 2 rows. +// A backfill request gives us C but not B. We delete rows where prev_event=C. This +// still means that D is a backwards extremity as we do not have event B. However, event +// C is *also* a backwards extremity at this point as we do not have event A. Later, +// when we fetch event B, we delete rows where prev_event=B. This then removes D as +// a backwards extremity because there are no more rows with event_id=B. +type BackwardsExtremities interface { + // InsertsBackwardExtremity inserts a new backwards extremity. + InsertsBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string) (err error) + // SelectBackwardExtremitiesForRoom retrieves all backwards extremities for the room, as a map of event_id to list of prev_event_ids. + SelectBackwardExtremitiesForRoom(ctx context.Context, roomID string) (bwExtrems map[string][]string, err error) + // DeleteBackwardExtremity removes a backwards extremity for a room, if one existed. + DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error) +} + +// SendToDevice tracks send-to-device messages which are sent to individual +// clients. Each message gets inserted into this table at the point that we +// receive it from the EDU server. +// +// We're supposed to try and do our best to deliver send-to-device messages +// once, but the only way that we can really guarantee that they have been +// delivered is if the client successfully requests the next sync as given +// in the next_batch. Each time the device syncs, we will request all of the +// updates that either haven't been sent yet, along with all updates that we +// *have* sent but we haven't confirmed to have been received yet. If it's the +// first time we're sending a given update then we update the table to say +// what the "since" parameter was when we tried to send it. +// +// When the client syncs again, if their "since" parameter is *later* than +// the recorded one, we drop the entry from the DB as it's "sent". If the +// sync parameter isn't later then we will keep including the updates in the +// sync response, as the client is seemingly trying to repeat the same /sync. +type SendToDevice interface { + InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (err error) + SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (events []types.SendToDeviceEvent, err error) + UpdateSentSendToDeviceMessages(ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID) (err error) + DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error) + CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error) +} diff --git a/syncapi/sync/notifier.go b/syncapi/sync/notifier.go index 0d8050112..325e75351 100644 --- a/syncapi/sync/notifier.go +++ b/syncapi/sync/notifier.go @@ -36,9 +36,9 @@ type Notifier struct { // Protects currPos and userStreams. streamLock *sync.Mutex // The latest sync position - currPos types.PaginationToken - // A map of user_id => UserStream which can be used to wake a given user's /sync request. - userStreams map[string]*UserStream + currPos types.StreamingToken + // A map of user_id => device_id => UserStream which can be used to wake a given user's /sync request. + userDeviceStreams map[string]map[string]*UserDeviceStream // The last time we cleaned out stale entries from the userStreams map lastCleanUpTime time.Time } @@ -46,11 +46,11 @@ type Notifier struct { // NewNotifier creates a new notifier set to the given sync position. // In order for this to be of any use, the Notifier needs to be told all rooms and // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). -func NewNotifier(pos types.PaginationToken) *Notifier { +func NewNotifier(pos types.StreamingToken) *Notifier { return &Notifier{ currPos: pos, roomIDToJoinedUsers: make(map[string]userIDSet), - userStreams: make(map[string]*UserStream), + userDeviceStreams: make(map[string]map[string]*UserDeviceStream), streamLock: &sync.Mutex{}, lastCleanUpTime: time.Now(), } @@ -68,7 +68,7 @@ func NewNotifier(pos types.PaginationToken) *Notifier { // event type it handles, leaving other fields as 0. func (n *Notifier) OnNewEvent( ev *gomatrixserverlib.HeaderedEvent, roomID string, userIDs []string, - posUpdate types.PaginationToken, + posUpdate types.StreamingToken, ) { // update the current position then notify relevant /sync streams. // This needs to be done PRIOR to waking up users as they will read this value. @@ -120,10 +120,22 @@ func (n *Notifier) OnNewEvent( } } +func (n *Notifier) OnNewSendToDevice( + userID string, deviceIDs []string, + posUpdate types.StreamingToken, +) { + n.streamLock.Lock() + defer n.streamLock.Unlock() + latestPos := n.currPos.WithUpdates(posUpdate) + n.currPos = latestPos + + n.wakeupUserDevice(userID, deviceIDs, latestPos) +} + // GetListener returns a UserStreamListener that can be used to wait for // updates for a user. Must be closed. // notify for anything before sincePos -func (n *Notifier) GetListener(req syncRequest) UserStreamListener { +func (n *Notifier) GetListener(req syncRequest) UserDeviceStreamListener { // Do what synapse does: https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/notifier.py#L298 // - Bucket request into a lookup map keyed off a list of joined room IDs and separately a user ID // - Incoming events wake requests for a matching room ID @@ -137,7 +149,7 @@ func (n *Notifier) GetListener(req syncRequest) UserStreamListener { n.removeEmptyUserStreams() - return n.fetchUserStream(req.device.UserID, true).GetListener(req.ctx) + return n.fetchUserDeviceStream(req.device.UserID, req.device.ID, true).GetListener(req.ctx) } // Load the membership states required to notify users correctly. @@ -151,7 +163,7 @@ func (n *Notifier) Load(ctx context.Context, db storage.Database) error { } // CurrentPosition returns the current sync position -func (n *Notifier) CurrentPosition() types.PaginationToken { +func (n *Notifier) CurrentPosition() types.StreamingToken { n.streamLock.Lock() defer n.streamLock.Unlock() @@ -173,27 +185,69 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) { } } -func (n *Notifier) wakeupUsers(userIDs []string, newPos types.PaginationToken) { +// wakeupUsers will wake up the sync strems for all of the devices for all of the +// specified user IDs. +func (n *Notifier) wakeupUsers(userIDs []string, newPos types.StreamingToken) { for _, userID := range userIDs { - stream := n.fetchUserStream(userID, false) - if stream != nil { + for _, stream := range n.fetchUserStreams(userID) { + if stream == nil { + continue + } stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream } } } -// fetchUserStream retrieves a stream unique to the given user. If makeIfNotExists is true, +// wakeupUserDevice will wake up the sync stream for a specific user device. Other +// device streams will be left alone. +// nolint:unused +func (n *Notifier) wakeupUserDevice(userID string, deviceIDs []string, newPos types.StreamingToken) { + for _, deviceID := range deviceIDs { + if stream := n.fetchUserDeviceStream(userID, deviceID, false); stream != nil { + stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream + } + } +} + +// fetchUserDeviceStream retrieves a stream unique to the given device. If makeIfNotExists is true, +// a stream will be made for this device if one doesn't exist and it will be returned. This +// function does not wait for data to be available on the stream. +// NB: Callers should have locked the mutex before calling this function. +func (n *Notifier) fetchUserDeviceStream(userID, deviceID string, makeIfNotExists bool) *UserDeviceStream { + _, ok := n.userDeviceStreams[userID] + if !ok { + if !makeIfNotExists { + return nil + } + n.userDeviceStreams[userID] = map[string]*UserDeviceStream{} + } + stream, ok := n.userDeviceStreams[userID][deviceID] + if !ok { + if !makeIfNotExists { + return nil + } + // TODO: Unbounded growth of streams (1 per user) + if stream = NewUserDeviceStream(userID, deviceID, n.currPos); stream != nil { + n.userDeviceStreams[userID][deviceID] = stream + } + } + return stream +} + +// fetchUserStreams retrieves all streams for the given user. If makeIfNotExists is true, // a stream will be made for this user if one doesn't exist and it will be returned. This // function does not wait for data to be available on the stream. // NB: Callers should have locked the mutex before calling this function. -func (n *Notifier) fetchUserStream(userID string, makeIfNotExists bool) *UserStream { - stream, ok := n.userStreams[userID] - if !ok && makeIfNotExists { - // TODO: Unbounded growth of streams (1 per user) - stream = NewUserStream(userID, n.currPos) - n.userStreams[userID] = stream +func (n *Notifier) fetchUserStreams(userID string) []*UserDeviceStream { + user, ok := n.userDeviceStreams[userID] + if !ok { + return []*UserDeviceStream{} } - return stream + streams := []*UserDeviceStream{} + for _, stream := range user { + streams = append(streams, stream) + } + return streams } // Not thread-safe: must be called on the OnNewEvent goroutine only @@ -236,9 +290,14 @@ func (n *Notifier) removeEmptyUserStreams() { n.lastCleanUpTime = now deleteBefore := now.Add(-5 * time.Minute) - for key, value := range n.userStreams { - if value.TimeOfLastNonEmpty().Before(deleteBefore) { - delete(n.userStreams, key) + for user, byUser := range n.userDeviceStreams { + for device, stream := range byUser { + if stream.TimeOfLastNonEmpty().Before(deleteBefore) { + delete(n.userDeviceStreams[user], device) + } + if len(n.userDeviceStreams[user]) == 0 { + delete(n.userDeviceStreams, user) + } } } } diff --git a/syncapi/sync/notifier_test.go b/syncapi/sync/notifier_test.go index 350d757c6..ecc4fcbfc 100644 --- a/syncapi/sync/notifier_test.go +++ b/syncapi/sync/notifier_test.go @@ -22,9 +22,8 @@ import ( "testing" "time" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -33,40 +32,22 @@ var ( randomMessageEvent gomatrixserverlib.HeaderedEvent aliceInviteBobEvent gomatrixserverlib.HeaderedEvent bobLeaveEvent gomatrixserverlib.HeaderedEvent - syncPositionVeryOld types.PaginationToken - syncPositionBefore types.PaginationToken - syncPositionAfter types.PaginationToken - syncPositionNewEDU types.PaginationToken - syncPositionAfter2 types.PaginationToken + syncPositionVeryOld = types.NewStreamToken(5, 0) + syncPositionBefore = types.NewStreamToken(11, 0) + syncPositionAfter = types.NewStreamToken(12, 0) + syncPositionNewEDU = types.NewStreamToken(syncPositionAfter.PDUPosition(), 1) + syncPositionAfter2 = types.NewStreamToken(13, 0) ) var ( - roomID = "!test:localhost" - alice = "@alice:localhost" - bob = "@bob:localhost" + roomID = "!test:localhost" + alice = "@alice:localhost" + aliceDev = "alicedevice" + bob = "@bob:localhost" + bobDev = "bobdev" ) func init() { - baseSyncPos := types.PaginationToken{ - PDUPosition: 0, - EDUTypingPosition: 0, - } - - syncPositionVeryOld = baseSyncPos - syncPositionVeryOld.PDUPosition = 5 - - syncPositionBefore = baseSyncPos - syncPositionBefore.PDUPosition = 11 - - syncPositionAfter = baseSyncPos - syncPositionAfter.PDUPosition = 12 - - syncPositionNewEDU = syncPositionAfter - syncPositionNewEDU.EDUTypingPosition = 1 - - syncPositionAfter2 = baseSyncPos - syncPositionAfter2.PDUPosition = 13 - var err error err = json.Unmarshal([]byte(`{ "_room_version": "1", @@ -118,16 +99,20 @@ func init() { } } +func mustEqualPositions(t *testing.T, got, want types.StreamingToken) { + if got.String() != want.String() { + t.Fatalf("mustEqualPositions got %s want %s", got.String(), want.String()) + } +} + // Test that the current position is returned if a request is already behind. func TestImmediateNotification(t *testing.T) { n := NewNotifier(syncPositionBefore) - pos, err := waitForEvents(n, newTestSyncRequest(alice, syncPositionVeryOld)) + pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionVeryOld)) if err != nil { t.Fatalf("TestImmediateNotification error: %s", err) } - if pos != syncPositionBefore { - t.Fatalf("TestImmediateNotification want %v, got %v", syncPositionBefore, pos) - } + mustEqualPositions(t, pos, syncPositionBefore) } // Test that new events to a joined room unblocks the request. @@ -140,17 +125,15 @@ func TestNewEventAndJoinedToRoom(t *testing.T) { var wg sync.WaitGroup wg.Add(1) go func() { - pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) + pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore)) if err != nil { t.Errorf("TestNewEventAndJoinedToRoom error: %w", err) } - if pos != syncPositionAfter { - t.Errorf("TestNewEventAndJoinedToRoom want %v, got %v", syncPositionAfter, pos) - } + mustEqualPositions(t, pos, syncPositionAfter) wg.Done() }() - stream := lockedFetchUserStream(n, bob) + stream := lockedFetchUserStream(n, bob, bobDev) waitForBlocking(stream, 1) n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter) @@ -158,6 +141,43 @@ func TestNewEventAndJoinedToRoom(t *testing.T) { wg.Wait() } +func TestCorrectStream(t *testing.T) { + n := NewNotifier(syncPositionBefore) + stream := lockedFetchUserStream(n, bob, bobDev) + if stream.UserID != bob { + t.Fatalf("expected user %q, got %q", bob, stream.UserID) + } + if stream.DeviceID != bobDev { + t.Fatalf("expected device %q, got %q", bobDev, stream.DeviceID) + } +} + +func TestCorrectStreamWakeup(t *testing.T) { + n := NewNotifier(syncPositionBefore) + awoken := make(chan string) + + streamone := lockedFetchUserStream(n, alice, "one") + streamtwo := lockedFetchUserStream(n, alice, "two") + + go func() { + select { + case <-streamone.signalChannel: + awoken <- "one" + case <-streamtwo.signalChannel: + awoken <- "two" + } + }() + + time.Sleep(1 * time.Second) + + wake := "two" + n.wakeupUserDevice(alice, []string{wake}, syncPositionAfter) + + if result := <-awoken; result != wake { + t.Fatalf("expected to wake %q, got %q", wake, result) + } +} + // Test that an invite unblocks the request func TestNewInviteEventForUser(t *testing.T) { n := NewNotifier(syncPositionBefore) @@ -168,17 +188,15 @@ func TestNewInviteEventForUser(t *testing.T) { var wg sync.WaitGroup wg.Add(1) go func() { - pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) + pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore)) if err != nil { t.Errorf("TestNewInviteEventForUser error: %w", err) } - if pos != syncPositionAfter { - t.Errorf("TestNewInviteEventForUser want %v, got %v", syncPositionAfter, pos) - } + mustEqualPositions(t, pos, syncPositionAfter) wg.Done() }() - stream := lockedFetchUserStream(n, bob) + stream := lockedFetchUserStream(n, bob, bobDev) waitForBlocking(stream, 1) n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionAfter) @@ -196,17 +214,15 @@ func TestEDUWakeup(t *testing.T) { var wg sync.WaitGroup wg.Add(1) go func() { - pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionAfter)) + pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionAfter)) if err != nil { t.Errorf("TestNewInviteEventForUser error: %w", err) } - if pos != syncPositionNewEDU { - t.Errorf("TestNewInviteEventForUser want %v, got %v", syncPositionNewEDU, pos) - } + mustEqualPositions(t, pos, syncPositionNewEDU) wg.Done() }() - stream := lockedFetchUserStream(n, bob) + stream := lockedFetchUserStream(n, bob, bobDev) waitForBlocking(stream, 1) n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionNewEDU) @@ -224,20 +240,18 @@ func TestMultipleRequestWakeup(t *testing.T) { var wg sync.WaitGroup wg.Add(3) poll := func() { - pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) + pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore)) if err != nil { t.Errorf("TestMultipleRequestWakeup error: %w", err) } - if pos != syncPositionAfter { - t.Errorf("TestMultipleRequestWakeup want %v, got %v", syncPositionAfter, pos) - } + mustEqualPositions(t, pos, syncPositionAfter) wg.Done() } go poll() go poll() go poll() - stream := lockedFetchUserStream(n, bob) + stream := lockedFetchUserStream(n, bob, bobDev) waitForBlocking(stream, 3) n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter) @@ -264,38 +278,34 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { // Make bob leave the room leaveWG.Add(1) go func() { - pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) + pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore)) if err != nil { t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err) } - if pos != syncPositionAfter { - t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %v, got %v", syncPositionAfter, pos) - } + mustEqualPositions(t, pos, syncPositionAfter) leaveWG.Done() }() - bobStream := lockedFetchUserStream(n, bob) + bobStream := lockedFetchUserStream(n, bob, bobDev) waitForBlocking(bobStream, 1) n.OnNewEvent(&bobLeaveEvent, "", nil, syncPositionAfter) leaveWG.Wait() // send an event into the room. Make sure alice gets it. Bob should not. var aliceWG sync.WaitGroup - aliceStream := lockedFetchUserStream(n, alice) + aliceStream := lockedFetchUserStream(n, alice, aliceDev) aliceWG.Add(1) go func() { - pos, err := waitForEvents(n, newTestSyncRequest(alice, syncPositionAfter)) + pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionAfter)) if err != nil { t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err) } - if pos != syncPositionAfter2 { - t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %v, got %v", syncPositionAfter2, pos) - } + mustEqualPositions(t, pos, syncPositionAfter2) aliceWG.Done() }() go func() { // this should timeout with an error (but the main goroutine won't wait for the timeout explicitly) - _, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionAfter)) + _, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionAfter)) if err == nil { t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom expect error but got nil") } @@ -312,13 +322,13 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { time.Sleep(1 * time.Millisecond) } -func waitForEvents(n *Notifier, req syncRequest) (types.PaginationToken, error) { +func waitForEvents(n *Notifier, req syncRequest) (types.StreamingToken, error) { listener := n.GetListener(req) defer listener.Close() select { case <-time.After(5 * time.Second): - return types.PaginationToken{}, fmt.Errorf( + return types.StreamingToken{}, fmt.Errorf( "waitForEvents timed out waiting for %s (pos=%v)", req.device.UserID, req.since, ) case <-listener.GetNotifyChannel(*req.since): @@ -328,7 +338,7 @@ func waitForEvents(n *Notifier, req syncRequest) (types.PaginationToken, error) } // Wait until something is Wait()ing on the user stream. -func waitForBlocking(s *UserStream, numBlocking uint) { +func waitForBlocking(s *UserDeviceStream, numBlocking uint) { for numBlocking != s.NumWaiting() { // This is horrible but I don't want to add a signalling mechanism JUST for testing. time.Sleep(1 * time.Microsecond) @@ -337,16 +347,19 @@ func waitForBlocking(s *UserStream, numBlocking uint) { // lockedFetchUserStream invokes Notifier.fetchUserStream, respecting Notifier.streamLock. // A new stream is made if it doesn't exist already. -func lockedFetchUserStream(n *Notifier, userID string) *UserStream { +func lockedFetchUserStream(n *Notifier, userID, deviceID string) *UserDeviceStream { n.streamLock.Lock() defer n.streamLock.Unlock() - return n.fetchUserStream(userID, true) + return n.fetchUserDeviceStream(userID, deviceID, true) } -func newTestSyncRequest(userID string, since types.PaginationToken) syncRequest { +func newTestSyncRequest(userID, deviceID string, since types.StreamingToken) syncRequest { return syncRequest{ - device: authtypes.Device{UserID: userID}, + device: userapi.Device{ + UserID: userID, + ID: deviceID, + }, timeout: 1 * time.Minute, since: &since, wantFullState: false, diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go index f2e199d23..5dd92c853 100644 --- a/syncapi/sync/request.go +++ b/syncapi/sync/request.go @@ -16,13 +16,13 @@ package sync import ( "context" + "encoding/json" "net/http" "strconv" "time" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" ) @@ -30,24 +30,52 @@ import ( const defaultSyncTimeout = time.Duration(0) const defaultTimelineLimit = 20 +type filter struct { + Room struct { + Timeline struct { + Limit *int `json:"limit"` + } `json:"timeline"` + } `json:"room"` +} + // syncRequest represents a /sync request, with sensible defaults/sanity checks applied. type syncRequest struct { ctx context.Context - device authtypes.Device + device userapi.Device limit int timeout time.Duration - since *types.PaginationToken // nil means that no since token was supplied + since *types.StreamingToken // nil means that no since token was supplied wantFullState bool log *log.Entry } -func newSyncRequest(req *http.Request, device authtypes.Device) (*syncRequest, error) { +func newSyncRequest(req *http.Request, device userapi.Device) (*syncRequest, error) { timeout := getTimeout(req.URL.Query().Get("timeout")) fullState := req.URL.Query().Get("full_state") wantFullState := fullState != "" && fullState != "false" - since, err := getPaginationToken(req.URL.Query().Get("since")) - if err != nil { - return nil, err + var since *types.StreamingToken + sinceStr := req.URL.Query().Get("since") + if sinceStr != "" { + tok, err := types.NewStreamTokenFromString(sinceStr) + if err != nil { + return nil, err + } + since = &tok + } + if since == nil { + tok := types.NewStreamToken(0, 0) + since = &tok + } + timelineLimit := defaultTimelineLimit + // TODO: read from stored filters too + filterQuery := req.URL.Query().Get("filter") + if filterQuery != "" && filterQuery[0] == '{' { + // attempt to parse the timeline limit at least + var f filter + err := json.Unmarshal([]byte(filterQuery), &f) + if err == nil && f.Room.Timeline.Limit != nil { + timelineLimit = *f.Room.Timeline.Limit + } } // TODO: Additional query params: set_presence, filter return &syncRequest{ @@ -56,7 +84,7 @@ func newSyncRequest(req *http.Request, device authtypes.Device) (*syncRequest, e timeout: timeout, since: since, wantFullState: wantFullState, - limit: defaultTimelineLimit, // TODO: read from filter + limit: timelineLimit, log: util.GetLogger(req.Context()), }, nil } @@ -71,16 +99,3 @@ func getTimeout(timeoutMS string) time.Duration { } return time.Duration(i) * time.Millisecond } - -// getSyncStreamPosition tries to parse a 'since' token taken from the API to a -// types.PaginationToken. If the string is empty then (nil, nil) is returned. -// There are two forms of tokens: The full length form containing all PDU and EDU -// positions separated by "_", and the short form containing only the PDU -// position. Short form can be used for, e.g., `prev_batch` tokens. -func getPaginationToken(since string) (*types.PaginationToken, error) { - if since == "" { - return nil, nil - } - - return types.NewPaginationTokenFromString(since) -} diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 69efd8aa8..743c63a62 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -1,4 +1,6 @@ // Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,14 +17,14 @@ package sync import ( + "context" "net/http" "time" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" @@ -30,24 +32,23 @@ import ( // RequestPool manages HTTP long-poll connections for /sync type RequestPool struct { - db storage.Database - accountDB accounts.Database - notifier *Notifier + db storage.Database + userAPI userapi.UserInternalAPI + notifier *Notifier } // NewRequestPool makes a new RequestPool -func NewRequestPool(db storage.Database, n *Notifier, adb accounts.Database) *RequestPool { - return &RequestPool{db, adb, n} +func NewRequestPool(db storage.Database, n *Notifier, userAPI userapi.UserInternalAPI) *RequestPool { + return &RequestPool{db, userAPI, n} } // OnIncomingSyncRequest is called when a client makes a /sync request. This function MUST be // called in a dedicated goroutine for this request. This function will block the goroutine // until a response is ready, or it times out. -func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtypes.Device) util.JSONResponse { +func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.Device) util.JSONResponse { var syncData *types.Response // Extract values from request - userID := device.UserID syncReq, err := newSyncRequest(req, *device) if err != nil { return util.JSONResponse{ @@ -55,15 +56,18 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype JSON: jsonerror.Unknown(err.Error()), } } + logger := util.GetLogger(req.Context()).WithFields(log.Fields{ - "userID": userID, - "since": syncReq.since, - "timeout": syncReq.timeout, + "user_id": device.UserID, + "device_id": device.ID, + "since": syncReq.since, + "timeout": syncReq.timeout, + "limit": syncReq.limit, }) currPos := rp.notifier.CurrentPosition() - if shouldReturnImmediately(syncReq) { + if rp.shouldReturnImmediately(syncReq) { syncData, err = rp.currentSyncForUser(*syncReq, currPos) if err != nil { logger.WithError(err).Error("rp.currentSyncForUser failed") @@ -115,7 +119,6 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype // response. This ensures that we don't waste the hard work // of calculating the sync only to get timed out before we // can respond - syncData, err = rp.currentSyncForUser(*syncReq, currPos) if err != nil { logger.WithError(err).Error("rp.currentSyncForUser failed") @@ -132,23 +135,64 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype } } -func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.PaginationToken) (res *types.Response, err error) { - // TODO: handle ignored users - if req.since == nil { - res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit) - } else { - res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, latestPos, req.limit, req.wantFullState) +func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) { + res = types.NewResponse() + + since := types.NewStreamToken(0, 0) + if req.since != nil { + since = *req.since } + // See if we have any new tasks to do for the send-to-device messaging. + events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, since) + if err != nil { + return nil, err + } + + // TODO: handle ignored users + if req.since == nil { + res, err = rp.db.CompleteSync(req.ctx, res, req.device, req.limit) + } else { + res, err = rp.db.IncrementalSync(req.ctx, res, req.device, *req.since, latestPos, req.limit, req.wantFullState) + } if err != nil { return } accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead - res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition, &accountDataFilter) + res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition(), &accountDataFilter) + if err != nil { + return + } + + // Before we return the sync response, make sure that we take action on + // any send-to-device database updates or deletions that we need to do. + // Then add the updates into the sync response. + if len(updates) > 0 || len(deletions) > 0 { + // Handle the updates and deletions in the database. + err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, since) + if err != nil { + return + } + } + if len(events) > 0 { + // Add the updates into the sync response. + for _, event := range events { + res.ToDevice.Events = append(res.ToDevice.Events, event.SendToDeviceEvent) + } + + // Get the next_batch from the sync response and increase the + // EDU counter. + if pos, perr := types.NewStreamTokenFromString(res.NextBatch); perr == nil { + pos.Positions[1]++ + res.NextBatch = pos.String() + } + } + return } +// nolint:gocyclo func (rp *RequestPool) appendAccountData( data *types.Response, userID string, req syncRequest, currentPos types.StreamPosition, accountDataFilter *gomatrixserverlib.EventFilter, @@ -158,67 +202,101 @@ func (rp *RequestPool) appendAccountData( // data keys were set between two message. This isn't a huge issue since the // duplicate data doesn't represent a huge quantity of data, but an optimisation // here would be making sure each data is sent only once to the client. - localpart, _, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - return nil, err - } - if req.since == nil { // If this is the initial sync, we don't need to check if a data has // already been sent. Instead, we send the whole batch. - var global []gomatrixserverlib.ClientEvent - var rooms map[string][]gomatrixserverlib.ClientEvent - global, rooms, err = rp.accountDB.GetAccountData(req.ctx, localpart) - if err != nil { + dataReq := &userapi.QueryAccountDataRequest{ + UserID: userID, + } + dataRes := &userapi.QueryAccountDataResponse{} + if err := rp.userAPI.QueryAccountData(req.ctx, dataReq, dataRes); err != nil { return nil, err } - data.AccountData.Events = global - + for datatype, databody := range dataRes.GlobalAccountData { + data.AccountData.Events = append( + data.AccountData.Events, + gomatrixserverlib.ClientEvent{ + Type: datatype, + Content: gomatrixserverlib.RawJSON(databody), + }, + ) + } for r, j := range data.Rooms.Join { - if len(rooms[r]) > 0 { - j.AccountData.Events = rooms[r] + for datatype, databody := range dataRes.RoomAccountData[r] { + j.AccountData.Events = append( + j.AccountData.Events, + gomatrixserverlib.ClientEvent{ + Type: datatype, + Content: gomatrixserverlib.RawJSON(databody), + }, + ) data.Rooms.Join[r] = j } } - return data, nil } + r := types.Range{ + From: req.since.PDUPosition(), + To: currentPos, + } + // If both positions are the same, it means that the data was saved after the + // latest room event. In that case, we need to decrement the old position as + // results are exclusive of Low. + if r.Low() == r.High() { + r.From-- + } + // Sync is not initial, get all account data since the latest sync dataTypes, err := rp.db.GetAccountDataInRange( - req.ctx, userID, - types.StreamPosition(req.since.PDUPosition), types.StreamPosition(currentPos), - accountDataFilter, + req.ctx, userID, r, accountDataFilter, ) if err != nil { return nil, err } if len(dataTypes) == 0 { - return data, nil + // TODO: this fixes the sytest but is it the right thing to do? + dataTypes[""] = []string{"m.push_rules"} } // Iterate over the rooms for roomID, dataTypes := range dataTypes { - events := []gomatrixserverlib.ClientEvent{} // Request the missing data from the database for _, dataType := range dataTypes { - event, err := rp.accountDB.GetAccountDataByType( - req.ctx, localpart, roomID, dataType, - ) - if err != nil { - return nil, err + dataReq := userapi.QueryAccountDataRequest{ + UserID: userID, + RoomID: roomID, + DataType: dataType, + } + dataRes := userapi.QueryAccountDataResponse{} + err = rp.userAPI.QueryAccountData(req.ctx, &dataReq, &dataRes) + if err != nil { + continue + } + if roomID == "" { + if globalData, ok := dataRes.GlobalAccountData[dataType]; ok { + data.AccountData.Events = append( + data.AccountData.Events, + gomatrixserverlib.ClientEvent{ + Type: dataType, + Content: gomatrixserverlib.RawJSON(globalData), + }, + ) + } + } else { + if roomData, ok := dataRes.RoomAccountData[roomID][dataType]; ok { + joinData := data.Rooms.Join[roomID] + joinData.AccountData.Events = append( + joinData.AccountData.Events, + gomatrixserverlib.ClientEvent{ + Type: dataType, + Content: gomatrixserverlib.RawJSON(roomData), + }, + ) + data.Rooms.Join[roomID] = joinData + } } - events = append(events, *event) - } - - // Append the data to the response - if len(roomID) > 0 { - jr := data.Rooms.Join[roomID] - jr.AccountData.Events = events - data.Rooms.Join[roomID] = jr - } else { - data.AccountData.Events = events } } @@ -228,6 +306,10 @@ func (rp *RequestPool) appendAccountData( // shouldReturnImmediately returns whether the /sync request is an initial sync, // or timeout=0, or full_state=true, in any of the cases the request should // return immediately. -func shouldReturnImmediately(syncReq *syncRequest) bool { - return syncReq.since == nil || syncReq.timeout == 0 || syncReq.wantFullState +func (rp *RequestPool) shouldReturnImmediately(syncReq *syncRequest) bool { + if syncReq.since == nil || syncReq.timeout == 0 || syncReq.wantFullState { + return true + } + waiting, werr := rp.db.SendToDeviceUpdatesWaiting(context.TODO(), syncReq.device.UserID, syncReq.device.ID) + return werr == nil && waiting } diff --git a/syncapi/sync/userstream.go b/syncapi/sync/userstream.go index 88867005e..ff9a4d003 100644 --- a/syncapi/sync/userstream.go +++ b/syncapi/sync/userstream.go @@ -23,36 +23,38 @@ import ( "github.com/matrix-org/dendrite/syncapi/types" ) -// UserStream represents a communication mechanism between the /sync request goroutine +// UserDeviceStream represents a communication mechanism between the /sync request goroutine // and the underlying sync server goroutines. // Goroutines can get a UserStreamListener to wait for updates, and can Broadcast() // updates. -type UserStream struct { - UserID string +type UserDeviceStream struct { + UserID string + DeviceID string // The lock that protects changes to this struct lock sync.Mutex // Closed when there is an update. signalChannel chan struct{} // The last sync position that there may have been an update for the user - pos types.PaginationToken + pos types.StreamingToken // The last time when we had some listeners waiting timeOfLastChannel time.Time // The number of listeners waiting numWaiting uint } -// UserStreamListener allows a sync request to wait for updates for a user. -type UserStreamListener struct { - userStream *UserStream +// UserDeviceStreamListener allows a sync request to wait for updates for a user. +type UserDeviceStreamListener struct { + userStream *UserDeviceStream // Whether the stream has been closed hasClosed bool } -// NewUserStream creates a new user stream -func NewUserStream(userID string, currPos types.PaginationToken) *UserStream { - return &UserStream{ +// NewUserDeviceStream creates a new user stream +func NewUserDeviceStream(userID, deviceID string, currPos types.StreamingToken) *UserDeviceStream { + return &UserDeviceStream{ UserID: userID, + DeviceID: deviceID, timeOfLastChannel: time.Now(), pos: currPos, signalChannel: make(chan struct{}), @@ -62,18 +64,18 @@ func NewUserStream(userID string, currPos types.PaginationToken) *UserStream { // GetListener returns UserStreamListener that a sync request can use to wait // for new updates with. // UserStreamListener must be closed -func (s *UserStream) GetListener(ctx context.Context) UserStreamListener { +func (s *UserDeviceStream) GetListener(ctx context.Context) UserDeviceStreamListener { s.lock.Lock() defer s.lock.Unlock() s.numWaiting++ // We decrement when UserStreamListener is closed - listener := UserStreamListener{ + listener := UserDeviceStreamListener{ userStream: s, } // Lets be a bit paranoid here and check that Close() is being called - runtime.SetFinalizer(&listener, func(l *UserStreamListener) { + runtime.SetFinalizer(&listener, func(l *UserDeviceStreamListener) { if !l.hasClosed { l.Close() } @@ -83,7 +85,7 @@ func (s *UserStream) GetListener(ctx context.Context) UserStreamListener { } // Broadcast a new sync position for this user. -func (s *UserStream) Broadcast(pos types.PaginationToken) { +func (s *UserDeviceStream) Broadcast(pos types.StreamingToken) { s.lock.Lock() defer s.lock.Unlock() @@ -96,7 +98,7 @@ func (s *UserStream) Broadcast(pos types.PaginationToken) { // NumWaiting returns the number of goroutines waiting for waiting for updates. // Used for metrics and testing. -func (s *UserStream) NumWaiting() uint { +func (s *UserDeviceStream) NumWaiting() uint { s.lock.Lock() defer s.lock.Unlock() return s.numWaiting @@ -105,7 +107,7 @@ func (s *UserStream) NumWaiting() uint { // TimeOfLastNonEmpty returns the last time that the number of waiting listeners // was non-empty, may be time.Now() if number of waiting listeners is currently // non-empty. -func (s *UserStream) TimeOfLastNonEmpty() time.Time { +func (s *UserDeviceStream) TimeOfLastNonEmpty() time.Time { s.lock.Lock() defer s.lock.Unlock() @@ -116,9 +118,9 @@ func (s *UserStream) TimeOfLastNonEmpty() time.Time { return s.timeOfLastChannel } -// GetStreamPosition returns last sync position which the UserStream was +// GetSyncPosition returns last sync position which the UserStream was // notified about -func (s *UserStreamListener) GetSyncPosition() types.PaginationToken { +func (s *UserDeviceStreamListener) GetSyncPosition() types.StreamingToken { s.userStream.lock.Lock() defer s.userStream.lock.Unlock() @@ -130,7 +132,7 @@ func (s *UserStreamListener) GetSyncPosition() types.PaginationToken { // sincePos specifies from which point we want to be notified about. If there // has already been an update after sincePos we'll return a closed channel // immediately. -func (s *UserStreamListener) GetNotifyChannel(sincePos types.PaginationToken) <-chan struct{} { +func (s *UserDeviceStreamListener) GetNotifyChannel(sincePos types.StreamingToken) <-chan struct{} { s.userStream.lock.Lock() defer s.userStream.lock.Unlock() @@ -147,7 +149,7 @@ func (s *UserStreamListener) GetNotifyChannel(sincePos types.PaginationToken) <- } // Close cleans up resources used -func (s *UserStreamListener) Close() { +func (s *UserDeviceStreamListener) Close() { s.userStream.lock.Lock() defer s.userStream.lock.Unlock() diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 1535d2b13..caf91e27e 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -17,32 +17,32 @@ package syncapi import ( "context" + "github.com/Shopify/sarama" + "github.com/gorilla/mux" "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/common/basecomponent" - "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/syncapi/consumers" "github.com/matrix-org/dendrite/syncapi/routing" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" ) -// SetupSyncAPIComponent sets up and registers HTTP handlers for the SyncAPI +// AddPublicRoutes sets up and registers HTTP handlers for the SyncAPI // component. -func SetupSyncAPIComponent( - base *basecomponent.BaseDendrite, - deviceDB devices.Database, - accountsDB accounts.Database, - queryAPI api.RoomserverQueryAPI, +func AddPublicRoutes( + router *mux.Router, + consumer sarama.Consumer, + userAPI userapi.UserInternalAPI, + rsAPI api.RoomserverInternalAPI, federation *gomatrixserverlib.FederationClient, cfg *config.Dendrite, ) { - syncDB, err := storage.NewSyncServerDatasource(string(base.Cfg.Database.SyncAPI)) + syncDB, err := storage.NewSyncServerDatasource(string(cfg.Database.SyncAPI), cfg.DbProperties()) if err != nil { logrus.WithError(err).Panicf("failed to connect to sync db") } @@ -58,28 +58,35 @@ func SetupSyncAPIComponent( logrus.WithError(err).Panicf("failed to start notifier") } - requestPool := sync.NewRequestPool(syncDB, notifier, accountsDB) + requestPool := sync.NewRequestPool(syncDB, notifier, userAPI) roomConsumer := consumers.NewOutputRoomEventConsumer( - base.Cfg, base.KafkaConsumer, notifier, syncDB, queryAPI, + cfg, consumer, notifier, syncDB, rsAPI, ) if err = roomConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start room server consumer") } clientConsumer := consumers.NewOutputClientDataConsumer( - base.Cfg, base.KafkaConsumer, notifier, syncDB, + cfg, consumer, notifier, syncDB, ) if err = clientConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start client data consumer") } typingConsumer := consumers.NewOutputTypingEventConsumer( - base.Cfg, base.KafkaConsumer, notifier, syncDB, + cfg, consumer, notifier, syncDB, ) if err = typingConsumer.Start(); err != nil { - logrus.WithError(err).Panicf("failed to start typing server consumer") + logrus.WithError(err).Panicf("failed to start typing consumer") } - routing.Setup(base.APIMux, requestPool, syncDB, deviceDB, federation, queryAPI, cfg) + sendToDeviceConsumer := consumers.NewOutputSendToDeviceEventConsumer( + cfg, consumer, notifier, syncDB, + ) + if err = sendToDeviceConsumer.Start(); err != nil { + logrus.WithError(err).Panicf("failed to start send-to-device consumer") + } + + routing.Setup(router, requestPool, syncDB, userAPI, federation, rsAPI, cfg) } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 718906ecd..1094416a1 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -23,22 +23,23 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/tidwall/gjson" ) var ( - // ErrInvalidPaginationTokenType is returned when an attempt at creating a - // new instance of PaginationToken with an invalid type (i.e. neither "s" + // ErrInvalidSyncTokenType is returned when an attempt at creating a + // new instance of SyncToken with an invalid type (i.e. neither "s" // nor "t"). - ErrInvalidPaginationTokenType = fmt.Errorf("Pagination token has an unknown prefix (should be either s or t)") - // ErrInvalidPaginationTokenLen is returned when the pagination token is an + ErrInvalidSyncTokenType = fmt.Errorf("Sync token has an unknown prefix (should be either s or t)") + // ErrInvalidSyncTokenLen is returned when the pagination token is an // invalid length - ErrInvalidPaginationTokenLen = fmt.Errorf("Pagination token has an invalid length") + ErrInvalidSyncTokenLen = fmt.Errorf("Sync token has an invalid length") ) // StreamPosition represents the offset in the sync stream a client is at. type StreamPosition int64 -// Same as gomatrixserverlib.Event but also has the PDU stream position for this event. +// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event. type StreamEvent struct { gomatrixserverlib.HeaderedEvent StreamPosition StreamPosition @@ -46,112 +47,235 @@ type StreamEvent struct { ExcludeFromSync bool } -// PaginationTokenType represents the type of a pagination token. +// Range represents a range between two stream positions. +type Range struct { + // From is the position the client has already received. + From StreamPosition + // To is the position the client is going towards. + To StreamPosition + // True if the client is going backwards + Backwards bool +} + +// Low returns the low number of the range. +// This represents the position the client already has and hence is exclusive. +func (r *Range) Low() StreamPosition { + if !r.Backwards { + return r.From + } + return r.To +} + +// High returns the high number of the range +// This represents the position the client is going towards and hence is inclusive. +func (r *Range) High() StreamPosition { + if !r.Backwards { + return r.To + } + return r.From +} + +// SyncTokenType represents the type of a sync token. // It can be either "s" (representing a position in the whole stream of events) // or "t" (representing a position in a room's topology/depth). -type PaginationTokenType string +type SyncTokenType string const ( - // PaginationTokenTypeStream represents a position in the server's whole + // SyncTokenTypeStream represents a position in the server's whole // stream of events - PaginationTokenTypeStream PaginationTokenType = "s" - // PaginationTokenTypeTopology represents a position in a room's topology. - PaginationTokenTypeTopology PaginationTokenType = "t" + SyncTokenTypeStream SyncTokenType = "s" + // SyncTokenTypeTopology represents a position in a room's topology. + SyncTokenTypeTopology SyncTokenType = "t" ) -// PaginationToken represents a pagination token, used for interactions with -// /sync or /messages, for example. -type PaginationToken struct { - //Position StreamPosition - Type PaginationTokenType - PDUPosition StreamPosition - EDUTypingPosition StreamPosition +type StreamingToken struct { + syncToken } -// NewPaginationTokenFromString takes a string of the form "xyyyy..." where "x" -// represents the type of a pagination token and "yyyy..." the token itself, and -// parses it in order to create a new instance of PaginationToken. Returns an -// error if the token couldn't be parsed into an int64, or if the token type -// isn't a known type (returns ErrInvalidPaginationTokenType in the latter -// case). -func NewPaginationTokenFromString(s string) (token *PaginationToken, err error) { - if len(s) == 0 { - return nil, ErrInvalidPaginationTokenLen - } +func (t *StreamingToken) PDUPosition() StreamPosition { + return t.Positions[0] +} +func (t *StreamingToken) EDUPosition() StreamPosition { + return t.Positions[1] +} +func (t *StreamingToken) String() string { + return t.syncToken.String() +} - token = new(PaginationToken) - var positions []string - - switch t := PaginationTokenType(s[:1]); t { - case PaginationTokenTypeStream, PaginationTokenTypeTopology: - token.Type = t - positions = strings.Split(s[1:], "_") - default: - token.Type = PaginationTokenTypeStream - positions = strings.Split(s, "_") - } - - // Try to get the PDU position. - if len(positions) >= 1 { - if pduPos, err := strconv.ParseInt(positions[0], 10, 64); err != nil { - return nil, err - } else if pduPos < 0 { - return nil, errors.New("negative PDU position not allowed") - } else { - token.PDUPosition = StreamPosition(pduPos) +// IsAfter returns true if ANY position in this token is greater than `other`. +func (t *StreamingToken) IsAfter(other StreamingToken) bool { + for i := range other.Positions { + if t.Positions[i] > other.Positions[i] { + return true } } + return false +} - // Try to get the typing position. - if len(positions) >= 2 { - if typPos, err := strconv.ParseInt(positions[1], 10, 64); err != nil { - return nil, err - } else if typPos < 0 { - return nil, errors.New("negative EDU typing position not allowed") - } else { - token.EDUTypingPosition = StreamPosition(typPos) +// WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken. +// If the latter StreamingToken contains a field that is not 0, it is considered an update, +// and its value will replace the corresponding value in the StreamingToken on which WithUpdates is called. +func (t *StreamingToken) WithUpdates(other StreamingToken) (ret StreamingToken) { + ret.Type = t.Type + ret.Positions = make([]StreamPosition, len(t.Positions)) + for i := range t.Positions { + ret.Positions[i] = t.Positions[i] + if other.Positions[i] == 0 { + continue } - } - - return -} - -// NewPaginationTokenFromTypeAndPosition takes a PaginationTokenType and a -// StreamPosition and returns an instance of PaginationToken. -func NewPaginationTokenFromTypeAndPosition( - t PaginationTokenType, pdupos StreamPosition, typpos StreamPosition, -) (p *PaginationToken) { - return &PaginationToken{ - Type: t, - PDUPosition: pdupos, - EDUTypingPosition: typpos, - } -} - -// String translates a PaginationToken to a string of the "xyyyy..." (see -// NewPaginationToken to know what it represents). -func (p *PaginationToken) String() string { - return fmt.Sprintf("%s%d_%d", p.Type, p.PDUPosition, p.EDUTypingPosition) -} - -// WithUpdates returns a copy of the PaginationToken with updates applied from another PaginationToken. -// If the latter PaginationToken contains a field that is not 0, it is considered an update, -// and its value will replace the corresponding value in the PaginationToken on which WithUpdates is called. -func (pt *PaginationToken) WithUpdates(other PaginationToken) PaginationToken { - ret := *pt - if other.PDUPosition != 0 { - ret.PDUPosition = other.PDUPosition - } - if other.EDUTypingPosition != 0 { - ret.EDUTypingPosition = other.EDUTypingPosition + ret.Positions[i] = other.Positions[i] } return ret } -// IsAfter returns whether one PaginationToken refers to states newer than another PaginationToken. -func (sp *PaginationToken) IsAfter(other PaginationToken) bool { - return sp.PDUPosition > other.PDUPosition || - sp.EDUTypingPosition > other.EDUTypingPosition +type TopologyToken struct { + syncToken +} + +func (t *TopologyToken) Depth() StreamPosition { + return t.Positions[0] +} +func (t *TopologyToken) PDUPosition() StreamPosition { + return t.Positions[1] +} +func (t *TopologyToken) StreamToken() StreamingToken { + return NewStreamToken(t.PDUPosition(), 0) +} +func (t *TopologyToken) String() string { + return t.syncToken.String() +} + +// Decrement the topology token to one event earlier. +func (t *TopologyToken) Decrement() { + depth := t.Positions[0] + pduPos := t.Positions[1] + if depth-1 <= 0 { + // nothing can be lower than this + depth = 1 + } else { + // this assumes that we will never have 1000 events all with the same + // depth. TODO: work out what the right PDU position is to use, probably needs a db hit. + depth-- + pduPos += 1000 + } + // The lowest token value is 1, therefore we need to manually set it to that + // value if we're below it. + if depth < 1 { + depth = 1 + } + t.Positions = []StreamPosition{ + depth, pduPos, + } +} + +// NewSyncTokenFromString takes a string of the form "xyyyy..." where "x" +// represents the type of a pagination token and "yyyy..." the token itself, and +// parses it in order to create a new instance of SyncToken. Returns an +// error if the token couldn't be parsed into an int64, or if the token type +// isn't a known type (returns ErrInvalidSyncTokenType in the latter +// case). +func newSyncTokenFromString(s string) (token *syncToken, err error) { + if len(s) == 0 { + return nil, ErrInvalidSyncTokenLen + } + + token = new(syncToken) + var positions []string + + switch t := SyncTokenType(s[:1]); t { + case SyncTokenTypeStream, SyncTokenTypeTopology: + token.Type = t + positions = strings.Split(s[1:], "_") + default: + return nil, ErrInvalidSyncTokenType + } + + for _, pos := range positions { + if posInt, err := strconv.ParseInt(pos, 10, 64); err != nil { + return nil, err + } else if posInt < 0 { + return nil, errors.New("negative position not allowed") + } else { + token.Positions = append(token.Positions, StreamPosition(posInt)) + } + } + return +} + +// NewTopologyToken creates a new sync token for /messages +func NewTopologyToken(depth, streamPos StreamPosition) TopologyToken { + if depth < 0 { + depth = 1 + } + return TopologyToken{ + syncToken: syncToken{ + Type: SyncTokenTypeTopology, + Positions: []StreamPosition{depth, streamPos}, + }, + } +} +func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) { + t, err := newSyncTokenFromString(tok) + if err != nil { + return + } + if t.Type != SyncTokenTypeTopology { + err = fmt.Errorf("token %s is not a topology token", tok) + return + } + if len(t.Positions) < 2 { + err = fmt.Errorf("token %s wrong number of values, got %d want at least 2", tok, len(t.Positions)) + return + } + return TopologyToken{ + syncToken: *t, + }, nil +} + +// NewStreamToken creates a new sync token for /sync +func NewStreamToken(pduPos, eduPos StreamPosition) StreamingToken { + return StreamingToken{ + syncToken: syncToken{ + Type: SyncTokenTypeStream, + Positions: []StreamPosition{pduPos, eduPos}, + }, + } +} +func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { + t, err := newSyncTokenFromString(tok) + if err != nil { + return + } + if t.Type != SyncTokenTypeStream { + err = fmt.Errorf("token %s is not a streaming token", tok) + return + } + if len(t.Positions) < 2 { + err = fmt.Errorf("token %s wrong number of values, got %d want at least 2", tok, len(t.Positions)) + return + } + return StreamingToken{ + syncToken: *t, + }, nil +} + +// syncToken represents a syncapi token, used for interactions with +// /sync or /messages, for example. +type syncToken struct { + Type SyncTokenType + // A list of stream positions, their meanings vary depending on the token type. + Positions []StreamPosition +} + +// String translates a SyncToken to a string of the "xyyyy..." (see +// NewSyncToken to know what it represents). +func (p *syncToken) String() string { + posStr := make([]string, len(p.Positions)) + for i := range p.Positions { + posStr[i] = strconv.FormatInt(int64(p.Positions[i]), 10) + } + + return fmt.Sprintf("%s%s", p.Type, strings.Join(posStr, "_")) } // PrevEventRef represents a reference to a previous event in a state event upgrade @@ -175,13 +299,14 @@ type Response struct { Invite map[string]InviteResponse `json:"invite"` Leave map[string]LeaveResponse `json:"leave"` } `json:"rooms"` + ToDevice struct { + Events []gomatrixserverlib.SendToDeviceEvent `json:"events"` + } `json:"to_device"` } // NewResponse creates an empty response with initialised maps. -func NewResponse(token PaginationToken) *Response { - res := Response{ - NextBatch: token.String(), - } +func NewResponse() *Response { + res := Response{} // Pre-initialise the maps. Synapse will return {} even if there are no rooms under a specific section, // so let's do the same thing. Bonus: this means we can't get dreaded 'assignment to entry in nil map' errors. res.Rooms.Join = make(map[string]JoinResponse) @@ -194,14 +319,7 @@ func NewResponse(token PaginationToken) *Response { // This also applies to NewJoinResponse, NewInviteResponse and NewLeaveResponse. res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0) res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0) - - // Fill next_batch with a pagination token. Since this is a response to a sync request, we can assume - // we'll always return a stream token. - res.NextBatch = NewPaginationTokenFromTypeAndPosition( - PaginationTokenTypeStream, - StreamPosition(token.PDUPosition), - StreamPosition(token.EDUTypingPosition), - ).String() + res.ToDevice.Events = make([]gomatrixserverlib.SendToDeviceEvent, 0) return &res } @@ -213,7 +331,8 @@ func (r *Response) IsEmpty() bool { len(r.Rooms.Invite) == 0 && len(r.Rooms.Leave) == 0 && len(r.AccountData.Events) == 0 && - len(r.Presence.Events) == 0 + len(r.Presence.Events) == 0 && + len(r.ToDevice.Events) == 0 } // JoinResponse represents a /sync response for a room which is under the 'join' key. @@ -247,14 +366,17 @@ func NewJoinResponse() *JoinResponse { // InviteResponse represents a /sync response for a room which is under the 'invite' key. type InviteResponse struct { InviteState struct { - Events []gomatrixserverlib.ClientEvent `json:"events"` + Events json.RawMessage `json:"events"` } `json:"invite_state"` } // NewInviteResponse creates an empty response with initialised arrays. -func NewInviteResponse() *InviteResponse { +func NewInviteResponse(event gomatrixserverlib.HeaderedEvent) *InviteResponse { res := InviteResponse{} - res.InviteState.Events = make([]gomatrixserverlib.ClientEvent, 0) + res.InviteState.Events = json.RawMessage{'[', ']'} + if inviteRoomState := gjson.GetBytes(event.Unsigned(), "invite_room_state"); inviteRoomState.Exists() { + res.InviteState.Events = json.RawMessage(inviteRoomState.Raw) + } return &res } @@ -277,3 +399,13 @@ func NewLeaveResponse() *LeaveResponse { res.Timeline.Events = make([]gomatrixserverlib.ClientEvent, 0) return &res } + +type SendToDeviceNID int + +type SendToDeviceEvent struct { + gomatrixserverlib.SendToDeviceEvent + ID SendToDeviceNID + UserID string + DeviceID string + SentByToken *StreamingToken +} diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index f4c84e0d1..1e27a8e32 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -2,26 +2,11 @@ package types import "testing" -func TestNewPaginationTokenFromString(t *testing.T) { - shouldPass := map[string]PaginationToken{ - "2": PaginationToken{ - Type: PaginationTokenTypeStream, - PDUPosition: 2, - }, - "s4": PaginationToken{ - Type: PaginationTokenTypeStream, - PDUPosition: 4, - }, - "s3_1": PaginationToken{ - Type: PaginationTokenTypeStream, - PDUPosition: 3, - EDUTypingPosition: 1, - }, - "t3_1_4": PaginationToken{ - Type: PaginationTokenTypeTopology, - PDUPosition: 3, - EDUTypingPosition: 1, - }, +func TestNewSyncTokenFromString(t *testing.T) { + shouldPass := map[string]syncToken{ + "s4_0": NewStreamToken(4, 0).syncToken, + "s3_1": NewStreamToken(3, 1).syncToken, + "t3_1": NewTopologyToken(3, 1).syncToken, } shouldFail := []string{ @@ -32,20 +17,21 @@ func TestNewPaginationTokenFromString(t *testing.T) { "b", "b-1", "-4", + "2", } for test, expected := range shouldPass { - result, err := NewPaginationTokenFromString(test) + result, err := newSyncTokenFromString(test) if err != nil { t.Error(err) } - if *result != expected { - t.Errorf("expected %v but got %v", expected.String(), result.String()) + if result.String() != expected.String() { + t.Errorf("%s expected %v but got %v", test, expected.String(), result.String()) } } for _, test := range shouldFail { - if _, err := NewPaginationTokenFromString(test); err == nil { + if _, err := newSyncTokenFromString(test); err == nil { t.Errorf("input '%v' should have errored but didn't", test) } } diff --git a/sytest-blacklist b/sytest-blacklist index caad25455..9f140ed1c 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -39,3 +39,12 @@ Ignore invite in incremental sync # Blacklisted because this test calls /r0/events which we don't implement New room members see their own join event Existing members see new members' join events + +# Blacklisted because the federation work for these hasn't been finished yet. +Can recv device messages over federation +Device messages over federation wake up /sync +Wildcard device messages over federation wake up /sync + +# We don't implement soft-failed events yet, but because the /send response is vague, +# this test thinks it's all fine... +Inbound federation accepts a second soft-failed event diff --git a/sytest-whitelist b/sytest-whitelist index 7bd2a63c4..0036d60ea 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -53,6 +53,8 @@ PUT /rooms/:room_id/send/:event_type/:txn_id deduplicates the same txn id GET /rooms/:room_id/state/m.room.power_levels can fetch levels PUT /rooms/:room_id/state/m.room.power_levels can set levels PUT power_levels should not explode if the old power levels were empty +GET /rooms/:room_id/state/m.room.member/:user_id?format=event fetches my membership event +GET /rooms/:room_id/joined_members fetches my membership Both GET and PUT work POST /rooms/:room_id/read_markers can create read marker User signups are forbidden from starting with '_' @@ -60,6 +62,7 @@ Request to logout with invalid an access token is rejected Request to logout without an access token is rejected Room creation reports m.room.create to myself Room creation reports m.room.member to myself +Invited user can see room metadata # Blacklisted because these tests call /r0/events which we don't implement # New room members see their own join event # Existing members see new members' join events @@ -113,6 +116,8 @@ User can invite local user to room with version 1 Should reject keys claiming to belong to a different user Can add account data Can add account data to room +Can get account data without syncing +Can get room account data without syncing #Latest account data appears in v2 /sync New account data appears in incremental v2 /sync Checking local federation server @@ -253,3 +258,108 @@ User can invite local user to room with version 3 User can invite local user to room with version 4 A pair of servers can establish a join in a v2 room Can logout all devices +State from remote users is included in the timeline in an incremental sync +User can invite remote user to room with version 1 +User can invite remote user to room with version 2 +User can invite remote user to room with version 3 +User can invite remote user to room with version 4 +User can create and send/receive messages in a room with version 5 +local user can join room with version 5 +User can invite local user to room with version 5 +remote user can join room with version 5 +User can invite remote user to room with version 5 +Remote user can backfill in a room with version 5 +Inbound federation can receive v1 /send_join +Inbound federation can get state for a room +Inbound federation of state requires event_id as a mandatory paramater +Inbound federation can get state_ids for a room +Inbound federation of state_ids requires event_id as a mandatory paramater +Federation rejects inbound events where the prev_events cannot be found +Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state +Alternative server names do not cause a routing loop +Events whose auth_events are in the wrong room do not mess up the room state +Inbound federation can return events +Inbound federation can return missing events for world_readable visibility +Inbound federation can return missing events for invite visibility +Inbound federation can get public room list +An event which redacts itself should be ignored +A pair of events which redact each other should be ignored +Outbound federation can backfill events +Inbound federation can backfill events +Backfill checks the events requested belong to the room +Backfilled events whose prev_events are in a different room do not allow cross-room back-pagination +Outbound federation can request missing events +New room members see their own join event +Existing members see new members' join events +Inbound federation can receive events +Inbound federation can receive redacted events +Can logout current device +Can send a message directly to a device using PUT /sendToDevice +Can recv a device message using /sync +Can recv device messages until they are acknowledged +Device messages with the same txn_id are deduplicated +Device messages wake up /sync +# TODO: separate PR for: Can recv device messages over federation +# TODO: separate PR for: Device messages over federation wake up /sync +Can send messages with a wildcard device id +Can send messages with a wildcard device id to two devices +Wildcard device messages wake up /sync +# TODO: separate PR for: Wildcard device messages over federation wake up /sync +Can send a to-device message to two users which both receive it using /sync +User can create and send/receive messages in a room with version 6 +local user can join room with version 6 +User can invite local user to room with version 6 +remote user can join room with version 6 +User can invite remote user to room with version 6 +Remote user can backfill in a room with version 6 +Inbound: send_join rejects invalid JSON for room version 6 +Outbound federation rejects backfill containing invalid JSON for events in room version 6 +Invalid JSON integers +Invalid JSON special values +Invalid JSON floats +Outbound federation will ignore a missing event with bad JSON for room version 6 +Server correctly handles transactions that break edu limits +Server rejects invalid JSON in a version 6 room +Can download without a file name over federation +POST /media/r0/upload can create an upload +GET /media/r0/download can fetch the value again +Remote users can join room by alias +Alias creators can delete alias with no ops +Alias creators can delete canonical alias with no ops +Room members can override their displayname on a room-specific basis +displayname updates affect room member events +avatar_url updates affect room member events +Real non-joined users can get individual state for world_readable rooms after leaving +Can upload with Unicode file name +POSTed media can be thumbnailed +Remote media can be thumbnailed +Can download with Unicode file name locally +Can download file 'ascii' +Can download file 'name with spaces' +Can download file 'name;with;semicolons' +Can download specifying a different ASCII file name +Can download with Unicode file name over federation +Can download specifying a different Unicode file name +Inbound /v1/send_join rejects joins from other servers +Outbound federation can query v1 /send_join +Inbound /v1/send_join rejects incorrectly-signed joins +POST /rooms/:room_id/state/m.room.name sets name +GET /rooms/:room_id/state/m.room.name gets name +POST /rooms/:room_id/state/m.room.topic sets topic +GET /rooms/:room_id/state/m.room.topic gets topic +GET /rooms/:room_id/state fetches entire room state +Setting room topic reports m.room.topic to myself +setting 'm.room.name' respects room powerlevel +Syncing a new room with a large timeline limit isn't limited +Left rooms appear in the leave section of sync +Banned rooms appear in the leave section of sync +Getting state checks the events requested belong to the room +Getting state IDs checks the events requested belong to the room +Can invite users to invite-only rooms +Uninvited users cannot join the room +Invited user can reject invite +Invited user can reject invite for empty room +Invited user can reject local invite after originator leaves +Typing notification sent to local room members +Typing notifications also sent to remote room members +Typing can be explicitly stopped diff --git a/userapi/api/api.go b/userapi/api/api.go new file mode 100644 index 000000000..cf0f05633 --- /dev/null +++ b/userapi/api/api.go @@ -0,0 +1,193 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + "encoding/json" + + "github.com/matrix-org/gomatrixserverlib" +) + +// UserInternalAPI is the internal API for information about users and devices. +type UserInternalAPI interface { + InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error + PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error + PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error + QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error + QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error + QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error + QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error +} + +// InputAccountDataRequest is the request for InputAccountData +type InputAccountDataRequest struct { + UserID string // required: the user to set account data for + RoomID string // optional: the room to associate the account data with + DataType string // required: the data type of the data + AccountData json.RawMessage // required: the message content +} + +// InputAccountDataResponse is the response for InputAccountData +type InputAccountDataResponse struct { +} + +// QueryAccessTokenRequest is the request for QueryAccessToken +type QueryAccessTokenRequest struct { + AccessToken string + // optional user ID, valid only if the token is an appservice. + // https://matrix.org/docs/spec/application_service/r0.1.2#using-sync-and-events + AppServiceUserID string +} + +// QueryAccessTokenResponse is the response for QueryAccessToken +type QueryAccessTokenResponse struct { + Device *Device + Err error // e.g ErrorForbidden +} + +// QueryAccountDataRequest is the request for QueryAccountData +type QueryAccountDataRequest struct { + UserID string // required: the user to get account data for. + RoomID string // optional: the room ID, or global account data if not specified. + DataType string // optional: the data type, or all types if not specified. +} + +// QueryAccountDataResponse is the response for QueryAccountData +type QueryAccountDataResponse struct { + GlobalAccountData map[string]json.RawMessage // type -> data + RoomAccountData map[string]map[string]json.RawMessage // room -> type -> data +} + +// QueryDevicesRequest is the request for QueryDevices +type QueryDevicesRequest struct { + UserID string +} + +// QueryDevicesResponse is the response for QueryDevices +type QueryDevicesResponse struct { + UserExists bool + Devices []Device +} + +// QueryProfileRequest is the request for QueryProfile +type QueryProfileRequest struct { + // The user ID to query + UserID string +} + +// QueryProfileResponse is the response for QueryProfile +type QueryProfileResponse struct { + // True if the user exists. Querying for a profile does not create them. + UserExists bool + // The current display name if set. + DisplayName string + // The current avatar URL if set. + AvatarURL string +} + +// PerformAccountCreationRequest is the request for PerformAccountCreation +type PerformAccountCreationRequest struct { + AccountType AccountType // Required: whether this is a guest or user account + Localpart string // Required: The localpart for this account. Ignored if account type is guest. + + AppServiceID string // optional: the application service ID (not user ID) creating this account, if any. + Password string // optional: if missing then this account will be a passwordless account + OnConflict Conflict +} + +// PerformAccountCreationResponse is the response for PerformAccountCreation +type PerformAccountCreationResponse struct { + AccountCreated bool + Account *Account +} + +// PerformDeviceCreationRequest is the request for PerformDeviceCreation +type PerformDeviceCreationRequest struct { + Localpart string + AccessToken string // optional: if blank one will be made on your behalf + // optional: if nil an ID is generated for you. If set, replaces any existing device session, + // which will generate a new access token and invalidate the old one. + DeviceID *string + // optional: if nil no display name will be associated with this device. + DeviceDisplayName *string +} + +// PerformDeviceCreationResponse is the response for PerformDeviceCreation +type PerformDeviceCreationResponse struct { + DeviceCreated bool + Device *Device +} + +// Device represents a client's device (mobile, web, etc) +type Device struct { + ID string + UserID string + // The access_token granted to this device. + // This uniquely identifies the device from all other devices and clients. + AccessToken string + // The unique ID of the session identified by the access token. + // Can be used as a secure substitution in places where data needs to be + // associated with access tokens. + SessionID int64 + // TODO: display name, last used timestamp, keys, etc + DisplayName string +} + +// Account represents a Matrix account on this home server. +type Account struct { + UserID string + Localpart string + ServerName gomatrixserverlib.ServerName + AppServiceID string + // TODO: Other flags like IsAdmin, IsGuest + // TODO: Associations (e.g. with application services) +} + +// ErrorForbidden is an error indicating that the supplied access token is forbidden +type ErrorForbidden struct { + Message string +} + +func (e *ErrorForbidden) Error() string { + return "Forbidden: " + e.Message +} + +// ErrorConflict is an error indicating that there was a conflict which resulted in the request being aborted. +type ErrorConflict struct { + Message string +} + +func (e *ErrorConflict) Error() string { + return "Conflict: " + e.Message +} + +// Conflict is an enum representing what to do when encountering conflicting when creating profiles/devices +type Conflict int + +// AccountType is an enum representing the kind of account +type AccountType int + +const ( + // ConflictUpdate will update matching records returning no error + ConflictUpdate Conflict = 1 + // ConflictAbort will reject the request with ErrorConflict + ConflictAbort Conflict = 2 + + // AccountTypeUser indicates this is a user account + AccountTypeUser AccountType = 1 + // AccountTypeGuest indicates this is a guest account + AccountTypeGuest AccountType = 2 +) diff --git a/userapi/internal/api.go b/userapi/internal/api.go new file mode 100644 index 000000000..b081eca49 --- /dev/null +++ b/userapi/internal/api.go @@ -0,0 +1,237 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + + "github.com/matrix-org/dendrite/appservice/types" + "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" + "github.com/matrix-org/gomatrixserverlib" +) + +type UserInternalAPI struct { + AccountDB accounts.Database + DeviceDB devices.Database + ServerName gomatrixserverlib.ServerName + // AppServices is the list of all registered AS + AppServices []config.ApplicationService +} + +func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { + local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return err + } + if domain != a.ServerName { + return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) + } + if req.DataType == "" { + return fmt.Errorf("data type must not be empty") + } + return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData) +} + +func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { + if req.AccountType == api.AccountTypeGuest { + acc, err := a.AccountDB.CreateGuestAccount(ctx) + if err != nil { + return err + } + res.AccountCreated = true + res.Account = acc + return nil + } + acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID) + if err != nil { + if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists + switch req.OnConflict { + case api.ConflictUpdate: + break + case api.ConflictAbort: + return &api.ErrorConflict{ + Message: err.Error(), + } + } + } + // account already exists + res.AccountCreated = false + res.Account = &api.Account{ + AppServiceID: req.AppServiceID, + Localpart: req.Localpart, + ServerName: a.ServerName, + UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), + } + return nil + } + res.AccountCreated = true + res.Account = acc + return nil +} +func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error { + dev, err := a.DeviceDB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName) + if err != nil { + return err + } + res.DeviceCreated = true + res.Device = dev + return nil +} + +func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfileRequest, res *api.QueryProfileResponse) error { + local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return err + } + if domain != a.ServerName { + return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) + } + prof, err := a.AccountDB.GetProfileByLocalpart(ctx, local) + if err != nil { + if err == sql.ErrNoRows { + return nil + } + return err + } + res.UserExists = true + res.AvatarURL = prof.AvatarURL + res.DisplayName = prof.DisplayName + return nil +} + +func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error { + local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return err + } + if domain != a.ServerName { + return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName) + } + devs, err := a.DeviceDB.GetDevicesByLocalpart(ctx, local) + if err != nil { + return err + } + res.Devices = devs + return nil +} + +func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAccountDataRequest, res *api.QueryAccountDataResponse) error { + local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return err + } + if domain != a.ServerName { + return fmt.Errorf("cannot query account data of remote users: got %s want %s", domain, a.ServerName) + } + if req.DataType != "" { + var data json.RawMessage + data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) + if err != nil { + return err + } + res.RoomAccountData = make(map[string]map[string]json.RawMessage) + res.GlobalAccountData = make(map[string]json.RawMessage) + if data != nil { + if req.RoomID != "" { + if _, ok := res.RoomAccountData[req.RoomID]; !ok { + res.RoomAccountData[req.RoomID] = make(map[string]json.RawMessage) + } + res.RoomAccountData[req.RoomID][req.DataType] = data + } else { + res.GlobalAccountData[req.DataType] = data + } + } + return nil + } + global, rooms, err := a.AccountDB.GetAccountData(ctx, local) + if err != nil { + return err + } + res.RoomAccountData = rooms + res.GlobalAccountData = global + return nil +} + +func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAccessTokenRequest, res *api.QueryAccessTokenResponse) error { + if req.AppServiceUserID != "" { + appServiceDevice, err := a.queryAppServiceToken(ctx, req.AccessToken, req.AppServiceUserID) + res.Device = appServiceDevice + res.Err = err + return nil + } + device, err := a.DeviceDB.GetDeviceByAccessToken(ctx, req.AccessToken) + if err != nil { + if err == sql.ErrNoRows { + return nil + } + return err + } + res.Device = device + return nil +} + +// Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem +// creating a 'device'. +func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) { + // Search for app service with given access_token + var appService *config.ApplicationService + for _, as := range a.AppServices { + if as.ASToken == token { + appService = &as + break + } + } + if appService == nil { + return nil, nil + } + + // Create a dummy device for AS user + dev := api.Device{ + // Use AS dummy device ID + ID: types.AppServiceDeviceID, + // AS dummy device has AS's token. + AccessToken: token, + } + + localpart, err := userutil.ParseUsernameParam(appServiceUserID, &a.ServerName) + if err != nil { + return nil, err + } + + if localpart != "" { // AS is masquerading as another user + // Verify that the user is registered + account, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart) + // Verify that account exists & appServiceID matches + if err == nil && account.AppServiceID == appService.ID { + // Set the userID of dummy device + dev.UserID = appServiceUserID + return &dev, nil + } + return nil, &api.ErrorForbidden{Message: "appservice has not registered this user"} + } + + // AS is not masquerading as any user, so use AS's sender_localpart + dev.UserID = appService.SenderLocalpart + return &dev, nil +} diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go new file mode 100644 index 000000000..4ab0d690e --- /dev/null +++ b/userapi/inthttp/client.go @@ -0,0 +1,130 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inthttp + +import ( + "context" + "errors" + "net/http" + + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/opentracing/opentracing-go" +) + +// HTTP paths for the internal HTTP APIs +const ( + InputAccountDataPath = "/userapi/inputAccountData" + + PerformDeviceCreationPath = "/userapi/performDeviceCreation" + PerformAccountCreationPath = "/userapi/performAccountCreation" + + QueryProfilePath = "/userapi/queryProfile" + QueryAccessTokenPath = "/userapi/queryAccessToken" + QueryDevicesPath = "/userapi/queryDevices" + QueryAccountDataPath = "/userapi/queryAccountData" +) + +// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. +// If httpClient is nil an error is returned +func NewUserAPIClient( + apiURL string, + httpClient *http.Client, +) (api.UserInternalAPI, error) { + if httpClient == nil { + return nil, errors.New("NewUserAPIClient: httpClient is ") + } + return &httpUserInternalAPI{ + apiURL: apiURL, + httpClient: httpClient, + }, nil +} + +type httpUserInternalAPI struct { + apiURL string + httpClient *http.Client +} + +func (h *httpUserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "InputAccountData") + defer span.Finish() + + apiURL := h.apiURL + InputAccountDataPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpUserInternalAPI) PerformAccountCreation( + ctx context.Context, + request *api.PerformAccountCreationRequest, + response *api.PerformAccountCreationResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAccountCreation") + defer span.Finish() + + apiURL := h.apiURL + PerformAccountCreationPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) PerformDeviceCreation( + ctx context.Context, + request *api.PerformDeviceCreationRequest, + response *api.PerformDeviceCreationResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceCreation") + defer span.Finish() + + apiURL := h.apiURL + PerformDeviceCreationPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) QueryProfile( + ctx context.Context, + request *api.QueryProfileRequest, + response *api.QueryProfileResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryProfile") + defer span.Finish() + + apiURL := h.apiURL + QueryProfilePath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) QueryAccessToken( + ctx context.Context, + request *api.QueryAccessTokenRequest, + response *api.QueryAccessTokenResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAccessToken") + defer span.Finish() + + apiURL := h.apiURL + QueryAccessTokenPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDevices") + defer span.Finish() + + apiURL := h.apiURL + QueryDevicesPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpUserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAccountDataRequest, res *api.QueryAccountDataResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAccountData") + defer span.Finish() + + apiURL := h.apiURL + QueryAccountDataPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go new file mode 100644 index 000000000..8f3be7738 --- /dev/null +++ b/userapi/inthttp/server.go @@ -0,0 +1,106 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inthttp + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { + internalAPIMux.Handle(PerformAccountCreationPath, + httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse { + request := api.PerformAccountCreationRequest{} + response := api.PerformAccountCreationResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformAccountCreation(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(PerformDeviceCreationPath, + httputil.MakeInternalAPI("performDeviceCreation", func(req *http.Request) util.JSONResponse { + request := api.PerformDeviceCreationRequest{} + response := api.PerformDeviceCreationResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformDeviceCreation(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(QueryProfilePath, + httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse { + request := api.QueryProfileRequest{} + response := api.QueryProfileResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryProfile(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(QueryAccessTokenPath, + httputil.MakeInternalAPI("queryAccessToken", func(req *http.Request) util.JSONResponse { + request := api.QueryAccessTokenRequest{} + response := api.QueryAccessTokenResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryAccessToken(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(QueryDevicesPath, + httputil.MakeInternalAPI("queryDevices", func(req *http.Request) util.JSONResponse { + request := api.QueryDevicesRequest{} + response := api.QueryDevicesResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryDevices(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(QueryAccountDataPath, + httputil.MakeInternalAPI("queryAccountData", func(req *http.Request) util.JSONResponse { + request := api.QueryAccountDataRequest{} + response := api.QueryAccountDataResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryAccountData(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) +} diff --git a/clientapi/auth/storage/accounts/interface.go b/userapi/storage/accounts/interface.go similarity index 73% rename from clientapi/auth/storage/accounts/interface.go rename to userapi/storage/accounts/interface.go index a5052b047..c6692879b 100644 --- a/clientapi/auth/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -16,28 +16,37 @@ package accounts import ( "context" + "encoding/json" "errors" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" ) type Database interface { - common.PartitionStorer - GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*authtypes.Account, error) + internal.PartitionStorer + GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error SetDisplayName(ctx context.Context, localpart string, displayName string) error - CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*authtypes.Account, error) - CreateGuestAccount(ctx context.Context) (*authtypes.Account, error) + // CreateAccount makes a new account with the given login name and password, and creates an empty profile + // for this account. If no password is supplied, the account will be a passwordless account. If the + // account already exists, it will return nil, ErrUserExists. + CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*api.Account, error) + CreateGuestAccount(ctx context.Context) (*api.Account, error) UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error) GetRoomIDsByLocalPart(ctx context.Context, localpart string) ([]string, error) GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error) - SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error - GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error) - GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data *gomatrixserverlib.ClientEvent, err error) + SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error + GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) + // GetAccountDataByType returns account data matching a given + // localpart, room ID and type. + // If no account data could be found, returns nil + // Returns an error if there was an issue with the retrieval + GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error) GetNewNumericLocalpart(ctx context.Context) (int64, error) SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error) RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) @@ -46,7 +55,7 @@ type Database interface { GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error) PutFilter(ctx context.Context, localpart string, filter *gomatrixserverlib.Filter) (string, error) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) - GetAccountByLocalpart(ctx context.Context, localpart string) (*authtypes.Account, error) + GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/clientapi/auth/storage/accounts/postgres/account_data_table.go b/userapi/storage/accounts/postgres/account_data_table.go similarity index 77% rename from clientapi/auth/storage/accounts/postgres/account_data_table.go rename to userapi/storage/accounts/postgres/account_data_table.go index 9198a7440..90c79e878 100644 --- a/clientapi/auth/storage/accounts/postgres/account_data_table.go +++ b/userapi/storage/accounts/postgres/account_data_table.go @@ -17,10 +17,9 @@ package postgres import ( "context" "database/sql" + "encoding/json" - "github.com/matrix-org/dendrite/common" - - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" ) const accountDataSchema = ` @@ -74,7 +73,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { } func (s *accountDataStatements) insertAccountData( - ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string, + ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ) (err error) { stmt := txn.Stmt(s.insertAccountDataStmt) _, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content) @@ -84,18 +83,18 @@ func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) selectAccountData( ctx context.Context, localpart string, ) ( - global []gomatrixserverlib.ClientEvent, - rooms map[string][]gomatrixserverlib.ClientEvent, - err error, + /* global */ map[string]json.RawMessage, + /* rooms */ map[string]map[string]json.RawMessage, + error, ) { rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) if err != nil { - return + return nil, nil, err } - defer common.CloseAndLogIfError(ctx, rows, "selectAccountData: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectAccountData: rows.close() failed") - global = []gomatrixserverlib.ClientEvent{} - rooms = make(map[string][]gomatrixserverlib.ClientEvent) + global := map[string]json.RawMessage{} + rooms := map[string]map[string]json.RawMessage{} for rows.Next() { var roomID string @@ -103,41 +102,33 @@ func (s *accountDataStatements) selectAccountData( var content []byte if err = rows.Scan(&roomID, &dataType, &content); err != nil { - return + return nil, nil, err } - ac := gomatrixserverlib.ClientEvent{ - Type: dataType, - Content: content, - } - - if len(roomID) > 0 { - rooms[roomID] = append(rooms[roomID], ac) + if roomID != "" { + if _, ok := rooms[roomID]; !ok { + rooms[roomID] = map[string]json.RawMessage{} + } + rooms[roomID][dataType] = content } else { - global = append(global, ac) + global[dataType] = content } } + return global, rooms, rows.Err() } func (s *accountDataStatements) selectAccountDataByType( ctx context.Context, localpart, roomID, dataType string, -) (data *gomatrixserverlib.ClientEvent, err error) { +) (data json.RawMessage, err error) { + var bytes []byte stmt := s.selectAccountDataByTypeStmt - var content []byte - - if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil { + if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil { if err == sql.ErrNoRows { return nil, nil } - return } - - data = &gomatrixserverlib.ClientEvent{ - Type: dataType, - Content: content, - } - + data = json.RawMessage(bytes) return } diff --git a/clientapi/auth/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go similarity index 96% rename from clientapi/auth/storage/accounts/postgres/accounts_table.go rename to userapi/storage/accounts/postgres/accounts_table.go index 85c1938a1..931ffb73d 100644 --- a/clientapi/auth/storage/accounts/postgres/accounts_table.go +++ b/userapi/storage/accounts/postgres/accounts_table.go @@ -19,8 +19,8 @@ import ( "database/sql" "time" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" @@ -92,7 +92,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // on success. func (s *accountsStatements) insertAccount( ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, -) (*authtypes.Account, error) { +) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := txn.Stmt(s.insertAccountStmt) @@ -106,7 +106,7 @@ func (s *accountsStatements) insertAccount( return nil, err } - return &authtypes.Account{ + return &api.Account{ Localpart: localpart, UserID: userutil.MakeUserID(localpart, s.serverName), ServerName: s.serverName, @@ -123,9 +123,9 @@ func (s *accountsStatements) selectPasswordHash( func (s *accountsStatements) selectAccountByLocalpart( ctx context.Context, localpart string, -) (*authtypes.Account, error) { +) (*api.Account, error) { var appserviceIDPtr sql.NullString - var acc authtypes.Account + var acc api.Account stmt := s.selectAccountByLocalpartStmt err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) diff --git a/clientapi/auth/storage/accounts/postgres/filter_table.go b/userapi/storage/accounts/postgres/filter_table.go similarity index 100% rename from clientapi/auth/storage/accounts/postgres/filter_table.go rename to userapi/storage/accounts/postgres/filter_table.go diff --git a/clientapi/auth/storage/accounts/postgres/membership_table.go b/userapi/storage/accounts/postgres/membership_table.go similarity index 97% rename from clientapi/auth/storage/accounts/postgres/membership_table.go rename to userapi/storage/accounts/postgres/membership_table.go index 04e9095e9..623530acc 100644 --- a/clientapi/auth/storage/accounts/postgres/membership_table.go +++ b/userapi/storage/accounts/postgres/membership_table.go @@ -18,10 +18,9 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" - "github.com/lib/pq" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal" ) const membershipSchema = ` @@ -127,7 +126,7 @@ func (s *membershipStatements) selectMembershipsByLocalpart( memberships = []authtypes.Membership{} - defer common.CloseAndLogIfError(ctx, rows, "selectMembershipsByLocalpart: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectMembershipsByLocalpart: rows.close() failed") for rows.Next() { var m authtypes.Membership m.Localpart = localpart diff --git a/clientapi/auth/storage/accounts/postgres/profile_table.go b/userapi/storage/accounts/postgres/profile_table.go similarity index 100% rename from clientapi/auth/storage/accounts/postgres/profile_table.go rename to userapi/storage/accounts/postgres/profile_table.go diff --git a/clientapi/auth/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go similarity index 92% rename from clientapi/auth/storage/accounts/postgres/storage.go rename to userapi/storage/accounts/postgres/storage.go index 8ce367a3e..e55099800 100644 --- a/clientapi/auth/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -17,12 +17,13 @@ package postgres import ( "context" "database/sql" + "encoding/json" "errors" "strconv" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "golang.org/x/crypto/bcrypt" @@ -33,7 +34,7 @@ import ( // Database represents an account database type Database struct { db *sql.DB - common.PartitionOffsetStatements + sqlutil.PartitionOffsetStatements accounts accountsStatements profiles profilesStatements memberships membershipStatements @@ -44,13 +45,13 @@ type Database struct { } // NewDatabase creates a new accounts and profiles database -func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) { +func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, serverName gomatrixserverlib.ServerName) (*Database, error) { var db *sql.DB var err error - if db, err = sqlutil.Open("postgres", dataSourceName); err != nil { + if db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { return nil, err } - partitions := common.PartitionOffsetStatements{} + partitions := sqlutil.PartitionOffsetStatements{} if err = partitions.Prepare(db, "account"); err != nil { return nil, err } @@ -85,7 +86,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByPassword( ctx context.Context, localpart, plaintextPassword string, -) (*authtypes.Account, error) { +) (*api.Account, error) { hash, err := d.accounts.selectPasswordHash(ctx, localpart) if err != nil { return nil, err @@ -122,8 +123,8 @@ func (d *Database) SetDisplayName( // CreateGuestAccount makes a new guest account and creates an empty profile // for this account. -func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Account, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { +func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { var numLocalpart int64 numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) if err != nil { @@ -138,11 +139,11 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Accou // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the -// account already exists, it will return nil, nil. +// account already exists, it will return nil, sqlutil.ErrUserExists. func (d *Database) CreateAccount( ctx context.Context, localpart, plaintextPassword, appserviceID string, -) (acc *authtypes.Account, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { +) (acc *api.Account, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) return err }) @@ -151,7 +152,7 @@ func (d *Database) CreateAccount( func (d *Database) createAccount( ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, -) (*authtypes.Account, error) { +) (*api.Account, error) { var err error // Generate a password hash if this is not a password-less user @@ -163,13 +164,13 @@ func (d *Database) createAccount( } } if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil { - if common.IsUniqueConstraintViolationErr(err) { - return nil, nil + if sqlutil.IsUniqueConstraintViolationErr(err) { + return nil, sqlutil.ErrUserExists } return nil, err } - if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{ + if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ "global": { "content": [], "override": [], @@ -177,7 +178,7 @@ func (d *Database) createAccount( "sender": [], "underride": [] } - }`); err != nil { + }`)); err != nil { return nil, err } return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID) @@ -210,7 +211,7 @@ func (d *Database) removeMembershipsByEventIDs( func (d *Database) UpdateMemberships( ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string, ) error { - return common.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil { return err } @@ -295,9 +296,9 @@ func (d *Database) newMembership( // update the corresponding row with the new content // Returns a SQL error if there was an issue with the insertion/update func (d *Database) SaveAccountData( - ctx context.Context, localpart, roomID, dataType, content string, + ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ) error { - return common.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) }) } @@ -306,8 +307,8 @@ func (d *Database) SaveAccountData( // If no account data could be found, returns an empty arrays // Returns an error if there was an issue with the retrieval func (d *Database) GetAccountData(ctx context.Context, localpart string) ( - global []gomatrixserverlib.ClientEvent, - rooms map[string][]gomatrixserverlib.ClientEvent, + global map[string]json.RawMessage, + rooms map[string]map[string]json.RawMessage, err error, ) { return d.accountDatas.selectAccountData(ctx, localpart) @@ -319,7 +320,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) ( // Returns an error if there was an issue with the retrieval func (d *Database) GetAccountDataByType( ctx context.Context, localpart, roomID, dataType string, -) (data *gomatrixserverlib.ClientEvent, err error) { +) (data json.RawMessage, err error) { return d.accountDatas.selectAccountDataByType( ctx, localpart, roomID, dataType, ) @@ -348,7 +349,7 @@ var Err3PIDInUse = errors.New("This third-party identifier is already in use") func (d *Database) SaveThreePIDAssociation( ctx context.Context, threepid, localpart, medium string, ) (err error) { - return common.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { user, err := d.threepids.selectLocalpartForThreePID( ctx, txn, threepid, medium, ) @@ -428,6 +429,6 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin // This function assumes the request is authenticated or the account data is used only internally. // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, -) (*authtypes.Account, error) { +) (*api.Account, error) { return d.accounts.selectAccountByLocalpart(ctx, localpart) } diff --git a/clientapi/auth/storage/accounts/postgres/threepid_table.go b/userapi/storage/accounts/postgres/threepid_table.go similarity index 95% rename from clientapi/auth/storage/accounts/postgres/threepid_table.go rename to userapi/storage/accounts/postgres/threepid_table.go index 851b4a90b..7de96350c 100644 --- a/clientapi/auth/storage/accounts/postgres/threepid_table.go +++ b/userapi/storage/accounts/postgres/threepid_table.go @@ -18,7 +18,7 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -82,7 +82,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) { func (s *threepidStatements) selectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, ) (localpart string, err error) { - stmt := common.TxStmt(txn, s.selectLocalpartForThreePIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) if err == sql.ErrNoRows { return "", nil @@ -117,7 +117,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( func (s *threepidStatements) insertThreePID( ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ) (err error) { - stmt := common.TxStmt(txn, s.insertThreePIDStmt) + stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) _, err = stmt.ExecContext(ctx, threepid, medium, localpart) return } diff --git a/clientapi/auth/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go similarity index 80% rename from clientapi/auth/storage/accounts/sqlite3/account_data_table.go rename to userapi/storage/accounts/sqlite3/account_data_table.go index b6bb63617..d048dbd19 100644 --- a/clientapi/auth/storage/accounts/sqlite3/account_data_table.go +++ b/userapi/storage/accounts/sqlite3/account_data_table.go @@ -17,8 +17,7 @@ package sqlite3 import ( "context" "database/sql" - - "github.com/matrix-org/gomatrixserverlib" + "encoding/json" ) const accountDataSchema = ` @@ -72,7 +71,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { } func (s *accountDataStatements) insertAccountData( - ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string, + ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ) (err error) { _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) return @@ -81,17 +80,17 @@ func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) selectAccountData( ctx context.Context, localpart string, ) ( - global []gomatrixserverlib.ClientEvent, - rooms map[string][]gomatrixserverlib.ClientEvent, - err error, + /* global */ map[string]json.RawMessage, + /* rooms */ map[string]map[string]json.RawMessage, + error, ) { rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) if err != nil { - return + return nil, nil, err } - global = []gomatrixserverlib.ClientEvent{} - rooms = make(map[string][]gomatrixserverlib.ClientEvent) + global := map[string]json.RawMessage{} + rooms := map[string]map[string]json.RawMessage{} for rows.Next() { var roomID string @@ -99,42 +98,33 @@ func (s *accountDataStatements) selectAccountData( var content []byte if err = rows.Scan(&roomID, &dataType, &content); err != nil { - return + return nil, nil, err } - ac := gomatrixserverlib.ClientEvent{ - Type: dataType, - Content: content, - } - - if len(roomID) > 0 { - rooms[roomID] = append(rooms[roomID], ac) + if roomID != "" { + if _, ok := rooms[roomID]; !ok { + rooms[roomID] = map[string]json.RawMessage{} + } + rooms[roomID][dataType] = content } else { - global = append(global, ac) + global[dataType] = content } } - return + return global, rooms, nil } func (s *accountDataStatements) selectAccountDataByType( ctx context.Context, localpart, roomID, dataType string, -) (data *gomatrixserverlib.ClientEvent, err error) { +) (data json.RawMessage, err error) { + var bytes []byte stmt := s.selectAccountDataByTypeStmt - var content []byte - - if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil { + if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil { if err == sql.ErrNoRows { return nil, nil } - return } - - data = &gomatrixserverlib.ClientEvent{ - Type: dataType, - Content: content, - } - + data = json.RawMessage(bytes) return } diff --git a/clientapi/auth/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go similarity index 96% rename from clientapi/auth/storage/accounts/sqlite3/accounts_table.go rename to userapi/storage/accounts/sqlite3/accounts_table.go index fd6a09cde..768f536dd 100644 --- a/clientapi/auth/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -19,8 +19,8 @@ import ( "database/sql" "time" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" @@ -90,7 +90,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // on success. func (s *accountsStatements) insertAccount( ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, -) (*authtypes.Account, error) { +) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := s.insertAccountStmt @@ -104,7 +104,7 @@ func (s *accountsStatements) insertAccount( return nil, err } - return &authtypes.Account{ + return &api.Account{ Localpart: localpart, UserID: userutil.MakeUserID(localpart, s.serverName), ServerName: s.serverName, @@ -121,9 +121,9 @@ func (s *accountsStatements) selectPasswordHash( func (s *accountsStatements) selectAccountByLocalpart( ctx context.Context, localpart string, -) (*authtypes.Account, error) { +) (*api.Account, error) { var appserviceIDPtr sql.NullString - var acc authtypes.Account + var acc api.Account stmt := s.selectAccountByLocalpartStmt err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) diff --git a/roomserver/storage/sqlite3/prepare.go b/userapi/storage/accounts/sqlite3/constraint.go similarity index 50% rename from roomserver/storage/sqlite3/prepare.go rename to userapi/storage/accounts/sqlite3/constraint.go index 482dfa2b9..32f96c8e4 100644 --- a/roomserver/storage/sqlite3/prepare.go +++ b/userapi/storage/accounts/sqlite3/constraint.go @@ -1,5 +1,4 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,24 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. +// +build !wasm + package sqlite3 import ( - "database/sql" + "errors" + + "github.com/mattn/go-sqlite3" ) -// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. -type statementList []struct { - statement **sql.Stmt - sql string -} - -// prepare the SQL for each statement in the list and assign the result to the prepared statement. -func (s statementList) prepare(db *sql.DB) (err error) { - for _, statement := range s { - if *statement.statement, err = db.Prepare(statement.sql); err != nil { - return - } - } - return +func isConstraintError(err error) bool { + return errors.Is(err, sqlite3.ErrConstraint) } diff --git a/userapi/storage/accounts/sqlite3/constraint_wasm.go b/userapi/storage/accounts/sqlite3/constraint_wasm.go new file mode 100644 index 000000000..0dd5b1fea --- /dev/null +++ b/userapi/storage/accounts/sqlite3/constraint_wasm.go @@ -0,0 +1,21 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build wasm + +package sqlite3 + +func isConstraintError(err error) bool { + return false +} diff --git a/clientapi/auth/storage/accounts/sqlite3/filter_table.go b/userapi/storage/accounts/sqlite3/filter_table.go similarity index 100% rename from clientapi/auth/storage/accounts/sqlite3/filter_table.go rename to userapi/storage/accounts/sqlite3/filter_table.go diff --git a/clientapi/auth/storage/accounts/sqlite3/membership_table.go b/userapi/storage/accounts/sqlite3/membership_table.go similarity index 95% rename from clientapi/auth/storage/accounts/sqlite3/membership_table.go rename to userapi/storage/accounts/sqlite3/membership_table.go index bd9838b6b..67958f27d 100644 --- a/clientapi/auth/storage/accounts/sqlite3/membership_table.go +++ b/userapi/storage/accounts/sqlite3/membership_table.go @@ -20,7 +20,8 @@ import ( "strings" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" ) const membershipSchema = ` @@ -95,7 +96,7 @@ func (s *membershipStatements) insertMembership( func (s *membershipStatements) deleteMembershipsByEventIDs( ctx context.Context, txn *sql.Tx, eventIDs []string, ) (err error) { - sqlStr := strings.Replace(deleteMembershipsByEventIDsSQL, "($1)", common.QueryVariadic(len(eventIDs)), 1) + sqlStr := strings.Replace(deleteMembershipsByEventIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1) iEventIDs := make([]interface{}, len(eventIDs)) for i, e := range eventIDs { iEventIDs[i] = e @@ -125,7 +126,7 @@ func (s *membershipStatements) selectMembershipsByLocalpart( memberships = []authtypes.Membership{} - defer common.CloseAndLogIfError(ctx, rows, "selectMembershipsByLocalpart: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectMembershipsByLocalpart: rows.close() failed") for rows.Next() { var m authtypes.Membership m.Localpart = localpart diff --git a/clientapi/auth/storage/accounts/sqlite3/profile_table.go b/userapi/storage/accounts/sqlite3/profile_table.go similarity index 100% rename from clientapi/auth/storage/accounts/sqlite3/profile_table.go rename to userapi/storage/accounts/sqlite3/profile_table.go diff --git a/clientapi/auth/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go similarity index 92% rename from clientapi/auth/storage/accounts/sqlite3/storage.go rename to userapi/storage/accounts/sqlite3/storage.go index e190ba6c2..dbf6606c3 100644 --- a/clientapi/auth/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -17,24 +17,23 @@ package sqlite3 import ( "context" "database/sql" + "encoding/json" "errors" "strconv" "sync" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "golang.org/x/crypto/bcrypt" - - // Import the postgres database driver. - _ "github.com/mattn/go-sqlite3" + // Import the sqlite3 database driver. ) // Database represents an account database type Database struct { db *sql.DB - common.PartitionOffsetStatements + sqlutil.PartitionOffsetStatements accounts accountsStatements profiles profilesStatements memberships membershipStatements @@ -50,10 +49,14 @@ type Database struct { func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) { var db *sql.DB var err error - if db, err = sqlutil.Open(common.SQLiteDriverName(), dataSourceName); err != nil { + cs, err := sqlutil.ParseFileURI(dataSourceName) + if err != nil { return nil, err } - partitions := common.PartitionOffsetStatements{} + if db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil); err != nil { + return nil, err + } + partitions := sqlutil.PartitionOffsetStatements{} if err = partitions.Prepare(db, "account"); err != nil { return nil, err } @@ -88,7 +91,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByPassword( ctx context.Context, localpart, plaintextPassword string, -) (*authtypes.Account, error) { +) (*api.Account, error) { hash, err := d.accounts.selectPasswordHash(ctx, localpart) if err != nil { return nil, err @@ -125,8 +128,8 @@ func (d *Database) SetDisplayName( // CreateGuestAccount makes a new guest account and creates an empty profile // for this account. -func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Account, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { +func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { // We need to lock so we sequentially create numeric localparts. If we don't, two calls to // this function will cause the same number to be selected and one will fail with 'database is locked' // when the first txn upgrades to a write txn. @@ -148,11 +151,11 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Accou // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the -// account already exists, it will return nil, nil. +// account already exists, it will return nil, ErrUserExists. func (d *Database) CreateAccount( ctx context.Context, localpart, plaintextPassword, appserviceID string, -) (acc *authtypes.Account, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { +) (acc *api.Account, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) return err }) @@ -161,7 +164,7 @@ func (d *Database) CreateAccount( func (d *Database) createAccount( ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, -) (*authtypes.Account, error) { +) (*api.Account, error) { var err error // Generate a password hash if this is not a password-less user hash := "" @@ -172,13 +175,13 @@ func (d *Database) createAccount( } } if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil { - if common.IsUniqueConstraintViolationErr(err) { - return nil, nil + if isConstraintError(err) { + return nil, sqlutil.ErrUserExists } return nil, err } - if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{ + if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ "global": { "content": [], "override": [], @@ -186,7 +189,7 @@ func (d *Database) createAccount( "sender": [], "underride": [] } - }`); err != nil { + }`)); err != nil { return nil, err } return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID) @@ -219,7 +222,7 @@ func (d *Database) removeMembershipsByEventIDs( func (d *Database) UpdateMemberships( ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string, ) error { - return common.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil { return err } @@ -304,9 +307,9 @@ func (d *Database) newMembership( // update the corresponding row with the new content // Returns a SQL error if there was an issue with the insertion/update func (d *Database) SaveAccountData( - ctx context.Context, localpart, roomID, dataType, content string, + ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ) error { - return common.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) }) } @@ -315,8 +318,8 @@ func (d *Database) SaveAccountData( // If no account data could be found, returns an empty arrays // Returns an error if there was an issue with the retrieval func (d *Database) GetAccountData(ctx context.Context, localpart string) ( - global []gomatrixserverlib.ClientEvent, - rooms map[string][]gomatrixserverlib.ClientEvent, + global map[string]json.RawMessage, + rooms map[string]map[string]json.RawMessage, err error, ) { return d.accountDatas.selectAccountData(ctx, localpart) @@ -328,7 +331,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) ( // Returns an error if there was an issue with the retrieval func (d *Database) GetAccountDataByType( ctx context.Context, localpart, roomID, dataType string, -) (data *gomatrixserverlib.ClientEvent, err error) { +) (data json.RawMessage, err error) { return d.accountDatas.selectAccountDataByType( ctx, localpart, roomID, dataType, ) @@ -357,7 +360,7 @@ var Err3PIDInUse = errors.New("This third-party identifier is already in use") func (d *Database) SaveThreePIDAssociation( ctx context.Context, threepid, localpart, medium string, ) (err error) { - return common.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { user, err := d.threepids.selectLocalpartForThreePID( ctx, txn, threepid, medium, ) @@ -437,6 +440,6 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin // This function assumes the request is authenticated or the account data is used only internally. // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, -) (*authtypes.Account, error) { +) (*api.Account, error) { return d.accounts.selectAccountByLocalpart(ctx, localpart) } diff --git a/clientapi/auth/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/accounts/sqlite3/threepid_table.go similarity index 92% rename from clientapi/auth/storage/accounts/sqlite3/threepid_table.go rename to userapi/storage/accounts/sqlite3/threepid_table.go index 29ee4c3d0..0200dee7f 100644 --- a/clientapi/auth/storage/accounts/sqlite3/threepid_table.go +++ b/userapi/storage/accounts/sqlite3/threepid_table.go @@ -18,7 +18,8 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -82,7 +83,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) { func (s *threepidStatements) selectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, ) (localpart string, err error) { - stmt := common.TxStmt(txn, s.selectLocalpartForThreePIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) if err == sql.ErrNoRows { return "", nil @@ -97,7 +98,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( if err != nil { return } - defer common.CloseAndLogIfError(ctx, rows, "selectThreePIDsForLocalpart: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectThreePIDsForLocalpart: rows.close() failed") threepids = []authtypes.ThreePID{} for rows.Next() { @@ -117,7 +118,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( func (s *threepidStatements) insertThreePID( ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ) (err error) { - stmt := common.TxStmt(txn, s.insertThreePIDStmt) + stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) _, err = stmt.ExecContext(ctx, threepid, medium, localpart) return } diff --git a/clientapi/auth/storage/accounts/storage.go b/userapi/storage/accounts/storage.go similarity index 57% rename from clientapi/auth/storage/accounts/storage.go rename to userapi/storage/accounts/storage.go index c643a4d0a..87f626bf9 100644 --- a/clientapi/auth/storage/accounts/storage.go +++ b/userapi/storage/accounts/storage.go @@ -19,22 +19,25 @@ package accounts import ( "net/url" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/postgres" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/sqlite3" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres" + "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3" "github.com/matrix-org/gomatrixserverlib" ) -func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) { +// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) +// and sets postgres connection parameters +func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, serverName gomatrixserverlib.ServerName) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { - return postgres.NewDatabase(dataSourceName, serverName) + return postgres.NewDatabase(dataSourceName, dbProperties, serverName) } switch uri.Scheme { case "postgres": - return postgres.NewDatabase(dataSourceName, serverName) + return postgres.NewDatabase(dataSourceName, dbProperties, serverName) case "file": return sqlite3.NewDatabase(dataSourceName, serverName) default: - return postgres.NewDatabase(dataSourceName, serverName) + return postgres.NewDatabase(dataSourceName, dbProperties, serverName) } } diff --git a/clientapi/auth/storage/accounts/storage_wasm.go b/userapi/storage/accounts/storage_wasm.go similarity index 79% rename from clientapi/auth/storage/accounts/storage_wasm.go rename to userapi/storage/accounts/storage_wasm.go index 828afc6b4..692567059 100644 --- a/clientapi/auth/storage/accounts/storage_wasm.go +++ b/userapi/storage/accounts/storage_wasm.go @@ -18,11 +18,16 @@ import ( "fmt" "net/url" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/sqlite3" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3" "github.com/matrix-org/gomatrixserverlib" ) -func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) { +func NewDatabase( + dataSourceName string, + dbProperties sqlutil.DbProperties, // nolint:unparam + serverName gomatrixserverlib.ServerName, +) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { return nil, fmt.Errorf("Cannot use postgres implementation") diff --git a/clientapi/auth/storage/devices/interface.go b/userapi/storage/devices/interface.go similarity index 64% rename from clientapi/auth/storage/devices/interface.go rename to userapi/storage/devices/interface.go index 95291e4a7..4bdb57850 100644 --- a/clientapi/auth/storage/devices/interface.go +++ b/userapi/storage/devices/interface.go @@ -17,14 +17,20 @@ package devices import ( "context" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/userapi/api" ) type Database interface { - GetDeviceByAccessToken(ctx context.Context, token string) (*authtypes.Device, error) - GetDeviceByID(ctx context.Context, localpart, deviceID string) (*authtypes.Device, error) - GetDevicesByLocalpart(ctx context.Context, localpart string) ([]authtypes.Device, error) - CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string) (dev *authtypes.Device, returnErr error) + GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) + GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) + GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error) + // CreateDevice makes a new device associated with the given user ID localpart. + // If there is already a device with the same device ID for this user, that access token will be revoked + // and replaced with the given accessToken. If the given accessToken is already in use for another device, + // an error will be returned. + // If no device ID is given one is generated. + // Returns the device on success. + CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string) (dev *api.Device, returnErr error) UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error RemoveDevice(ctx context.Context, deviceID, localpart string) error RemoveDevices(ctx context.Context, localpart string, devices []string) error diff --git a/clientapi/auth/storage/devices/postgres/devices_table.go b/userapi/storage/devices/postgres/devices_table.go similarity index 88% rename from clientapi/auth/storage/devices/postgres/devices_table.go rename to userapi/storage/devices/postgres/devices_table.go index ee5591706..1d036d1b3 100644 --- a/clientapi/auth/storage/devices/postgres/devices_table.go +++ b/userapi/storage/devices/postgres/devices_table.go @@ -20,9 +20,10 @@ import ( "time" "github.com/lib/pq" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/userutil" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" ) @@ -134,14 +135,14 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN func (s *devicesStatements) insertDevice( ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, -) (*authtypes.Device, error) { +) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 var sessionID int64 - stmt := common.TxStmt(txn, s.insertDeviceStmt) + stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName).Scan(&sessionID); err != nil { return nil, err } - return &authtypes.Device{ + return &api.Device{ ID: id, UserID: userutil.MakeUserID(localpart, s.serverName), AccessToken: accessToken, @@ -153,7 +154,7 @@ func (s *devicesStatements) insertDevice( func (s *devicesStatements) deleteDevice( ctx context.Context, txn *sql.Tx, id, localpart string, ) error { - stmt := common.TxStmt(txn, s.deleteDeviceStmt) + stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) _, err := stmt.ExecContext(ctx, id, localpart) return err } @@ -163,7 +164,7 @@ func (s *devicesStatements) deleteDevice( func (s *devicesStatements) deleteDevices( ctx context.Context, txn *sql.Tx, localpart string, devices []string, ) error { - stmt := common.TxStmt(txn, s.deleteDevicesStmt) + stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt) _, err := stmt.ExecContext(ctx, localpart, pq.Array(devices)) return err } @@ -173,7 +174,7 @@ func (s *devicesStatements) deleteDevices( func (s *devicesStatements) deleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart string, ) error { - stmt := common.TxStmt(txn, s.deleteDevicesByLocalpartStmt) + stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) _, err := stmt.ExecContext(ctx, localpart) return err } @@ -181,15 +182,15 @@ func (s *devicesStatements) deleteDevicesByLocalpart( func (s *devicesStatements) updateDeviceName( ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ) error { - stmt := common.TxStmt(txn, s.updateDeviceNameStmt) + stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) return err } func (s *devicesStatements) selectDeviceByToken( ctx context.Context, accessToken string, -) (*authtypes.Device, error) { - var dev authtypes.Device +) (*api.Device, error) { + var dev api.Device var localpart string stmt := s.selectDeviceByTokenStmt err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart) @@ -204,11 +205,10 @@ func (s *devicesStatements) selectDeviceByToken( // localpart and deviceID func (s *devicesStatements) selectDeviceByID( ctx context.Context, localpart, deviceID string, -) (*authtypes.Device, error) { - var dev authtypes.Device - var created sql.NullInt64 +) (*api.Device, error) { + var dev api.Device stmt := s.selectDeviceByIDStmt - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&created) + err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName) if err == nil { dev.ID = deviceID dev.UserID = userutil.MakeUserID(localpart, s.serverName) @@ -218,22 +218,29 @@ func (s *devicesStatements) selectDeviceByID( func (s *devicesStatements) selectDevicesByLocalpart( ctx context.Context, localpart string, -) ([]authtypes.Device, error) { - devices := []authtypes.Device{} +) ([]api.Device, error) { + devices := []api.Device{} rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart) if err != nil { return devices, err } - defer common.CloseAndLogIfError(ctx, rows, "selectDevicesByLocalpart: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByLocalpart: rows.close() failed") for rows.Next() { - var dev authtypes.Device - err = rows.Scan(&dev.ID, &dev.DisplayName) + var dev api.Device + var id, displayname sql.NullString + err = rows.Scan(&id, &displayname) if err != nil { return devices, err } + if id.Valid { + dev.ID = id.String + } + if displayname.Valid { + dev.DisplayName = displayname.String + } dev.UserID = userutil.MakeUserID(localpart, s.serverName) devices = append(devices, dev) } diff --git a/clientapi/auth/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go similarity index 87% rename from clientapi/auth/storage/devices/postgres/storage.go rename to userapi/storage/devices/postgres/storage.go index 3f613cf32..801657bd5 100644 --- a/clientapi/auth/storage/devices/postgres/storage.go +++ b/userapi/storage/devices/postgres/storage.go @@ -20,9 +20,8 @@ import ( "database/sql" "encoding/base64" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" ) @@ -36,10 +35,10 @@ type Database struct { } // NewDatabase creates a new device database -func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) { +func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, serverName gomatrixserverlib.ServerName) (*Database, error) { var db *sql.DB var err error - if db, err = sqlutil.Open("postgres", dataSourceName); err != nil { + if db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { return nil, err } d := devicesStatements{} @@ -53,7 +52,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) // Returns sql.ErrNoRows if no matching device was found. func (d *Database) GetDeviceByAccessToken( ctx context.Context, token string, -) (*authtypes.Device, error) { +) (*api.Device, error) { return d.devices.selectDeviceByToken(ctx, token) } @@ -61,14 +60,14 @@ func (d *Database) GetDeviceByAccessToken( // Returns sql.ErrNoRows if no matching device was found. func (d *Database) GetDeviceByID( ctx context.Context, localpart, deviceID string, -) (*authtypes.Device, error) { +) (*api.Device, error) { return d.devices.selectDeviceByID(ctx, localpart, deviceID) } // GetDevicesByLocalpart returns the devices matching the given localpart. func (d *Database) GetDevicesByLocalpart( ctx context.Context, localpart string, -) ([]authtypes.Device, error) { +) ([]api.Device, error) { return d.devices.selectDevicesByLocalpart(ctx, localpart) } @@ -81,9 +80,9 @@ func (d *Database) GetDevicesByLocalpart( func (d *Database) CreateDevice( ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, -) (dev *authtypes.Device, returnErr error) { +) (dev *api.Device, returnErr error) { if deviceID != nil { - returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { + returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { var err error // Revoke existing tokens for this device if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { @@ -103,7 +102,7 @@ func (d *Database) CreateDevice( return } - returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { + returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { var err error dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName) return err @@ -133,7 +132,7 @@ func generateDeviceID() (string, error) { func (d *Database) UpdateDevice( ctx context.Context, localpart, deviceID string, displayName *string, ) error { - return common.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) }) } @@ -145,7 +144,7 @@ func (d *Database) UpdateDevice( func (d *Database) RemoveDevice( ctx context.Context, deviceID, localpart string, ) error { - return common.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { return err } @@ -160,7 +159,7 @@ func (d *Database) RemoveDevice( func (d *Database) RemoveDevices( ctx context.Context, localpart string, devices []string, ) error { - return common.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { return err } @@ -174,7 +173,7 @@ func (d *Database) RemoveDevices( func (d *Database) RemoveAllDevices( ctx context.Context, localpart string, ) error { - return common.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { return err } diff --git a/clientapi/auth/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go similarity index 87% rename from clientapi/auth/storage/devices/sqlite3/devices_table.go rename to userapi/storage/devices/sqlite3/devices_table.go index f69810b7d..07ea5dca3 100644 --- a/clientapi/auth/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -20,9 +20,9 @@ import ( "strings" "time" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/gomatrixserverlib" ) @@ -125,11 +125,11 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN func (s *devicesStatements) insertDevice( ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, -) (*authtypes.Device, error) { +) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 var sessionID int64 - countStmt := common.TxStmt(txn, s.selectDevicesCountStmt) - insertStmt := common.TxStmt(txn, s.insertDeviceStmt) + countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt) + insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil { return nil, err } @@ -137,7 +137,7 @@ func (s *devicesStatements) insertDevice( if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil { return nil, err } - return &authtypes.Device{ + return &api.Device{ ID: id, UserID: userutil.MakeUserID(localpart, s.serverName), AccessToken: accessToken, @@ -148,7 +148,7 @@ func (s *devicesStatements) insertDevice( func (s *devicesStatements) deleteDevice( ctx context.Context, txn *sql.Tx, id, localpart string, ) error { - stmt := common.TxStmt(txn, s.deleteDeviceStmt) + stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) _, err := stmt.ExecContext(ctx, id, localpart) return err } @@ -156,12 +156,12 @@ func (s *devicesStatements) deleteDevice( func (s *devicesStatements) deleteDevices( ctx context.Context, txn *sql.Tx, localpart string, devices []string, ) error { - orig := strings.Replace(deleteDevicesSQL, "($1)", common.QueryVariadic(len(devices)), 1) + orig := strings.Replace(deleteDevicesSQL, "($1)", sqlutil.QueryVariadic(len(devices)), 1) prep, err := s.db.Prepare(orig) if err != nil { return err } - stmt := common.TxStmt(txn, prep) + stmt := sqlutil.TxStmt(txn, prep) params := make([]interface{}, len(devices)+1) params[0] = localpart for i, v := range devices { @@ -175,7 +175,7 @@ func (s *devicesStatements) deleteDevices( func (s *devicesStatements) deleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart string, ) error { - stmt := common.TxStmt(txn, s.deleteDevicesByLocalpartStmt) + stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) _, err := stmt.ExecContext(ctx, localpart) return err } @@ -183,15 +183,15 @@ func (s *devicesStatements) deleteDevicesByLocalpart( func (s *devicesStatements) updateDeviceName( ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ) error { - stmt := common.TxStmt(txn, s.updateDeviceNameStmt) + stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) return err } func (s *devicesStatements) selectDeviceByToken( ctx context.Context, accessToken string, -) (*authtypes.Device, error) { - var dev authtypes.Device +) (*api.Device, error) { + var dev api.Device var localpart string stmt := s.selectDeviceByTokenStmt err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart) @@ -206,11 +206,10 @@ func (s *devicesStatements) selectDeviceByToken( // localpart and deviceID func (s *devicesStatements) selectDeviceByID( ctx context.Context, localpart, deviceID string, -) (*authtypes.Device, error) { - var dev authtypes.Device - var created sql.NullInt64 +) (*api.Device, error) { + var dev api.Device stmt := s.selectDeviceByIDStmt - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&created) + err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName) if err == nil { dev.ID = deviceID dev.UserID = userutil.MakeUserID(localpart, s.serverName) @@ -220,8 +219,8 @@ func (s *devicesStatements) selectDeviceByID( func (s *devicesStatements) selectDevicesByLocalpart( ctx context.Context, localpart string, -) ([]authtypes.Device, error) { - devices := []authtypes.Device{} +) ([]api.Device, error) { + devices := []api.Device{} rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart) @@ -230,11 +229,18 @@ func (s *devicesStatements) selectDevicesByLocalpart( } for rows.Next() { - var dev authtypes.Device - err = rows.Scan(&dev.ID, &dev.DisplayName) + var dev api.Device + var id, displayname sql.NullString + err = rows.Scan(&id, &displayname) if err != nil { return devices, err } + if id.Valid { + dev.ID = id.String + } + if displayname.Valid { + dev.DisplayName = displayname.String + } dev.UserID = userutil.MakeUserID(localpart, s.serverName) devices = append(devices, dev) } diff --git a/clientapi/auth/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go similarity index 88% rename from clientapi/auth/storage/devices/sqlite3/storage.go rename to userapi/storage/devices/sqlite3/storage.go index 85a8def2c..f248abda4 100644 --- a/clientapi/auth/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -20,9 +20,8 @@ import ( "database/sql" "encoding/base64" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" _ "github.com/mattn/go-sqlite3" @@ -41,7 +40,11 @@ type Database struct { func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) { var db *sql.DB var err error - if db, err = sqlutil.Open(common.SQLiteDriverName(), dataSourceName); err != nil { + cs, err := sqlutil.ParseFileURI(dataSourceName) + if err != nil { + return nil, err + } + if db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil); err != nil { return nil, err } d := devicesStatements{} @@ -55,7 +58,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) // Returns sql.ErrNoRows if no matching device was found. func (d *Database) GetDeviceByAccessToken( ctx context.Context, token string, -) (*authtypes.Device, error) { +) (*api.Device, error) { return d.devices.selectDeviceByToken(ctx, token) } @@ -63,14 +66,14 @@ func (d *Database) GetDeviceByAccessToken( // Returns sql.ErrNoRows if no matching device was found. func (d *Database) GetDeviceByID( ctx context.Context, localpart, deviceID string, -) (*authtypes.Device, error) { +) (*api.Device, error) { return d.devices.selectDeviceByID(ctx, localpart, deviceID) } // GetDevicesByLocalpart returns the devices matching the given localpart. func (d *Database) GetDevicesByLocalpart( ctx context.Context, localpart string, -) ([]authtypes.Device, error) { +) ([]api.Device, error) { return d.devices.selectDevicesByLocalpart(ctx, localpart) } @@ -83,9 +86,9 @@ func (d *Database) GetDevicesByLocalpart( func (d *Database) CreateDevice( ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, -) (dev *authtypes.Device, returnErr error) { +) (dev *api.Device, returnErr error) { if deviceID != nil { - returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { + returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { var err error // Revoke existing tokens for this device if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { @@ -105,7 +108,7 @@ func (d *Database) CreateDevice( return } - returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { + returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { var err error dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName) return err @@ -135,7 +138,7 @@ func generateDeviceID() (string, error) { func (d *Database) UpdateDevice( ctx context.Context, localpart, deviceID string, displayName *string, ) error { - return common.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) }) } @@ -147,7 +150,7 @@ func (d *Database) UpdateDevice( func (d *Database) RemoveDevice( ctx context.Context, deviceID, localpart string, ) error { - return common.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { return err } @@ -162,7 +165,7 @@ func (d *Database) RemoveDevice( func (d *Database) RemoveDevices( ctx context.Context, localpart string, devices []string, ) error { - return common.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { return err } @@ -176,7 +179,7 @@ func (d *Database) RemoveDevices( func (d *Database) RemoveAllDevices( ctx context.Context, localpart string, ) error { - return common.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { return err } diff --git a/clientapi/auth/storage/devices/storage.go b/userapi/storage/devices/storage.go similarity index 57% rename from clientapi/auth/storage/devices/storage.go rename to userapi/storage/devices/storage.go index 99211db28..e094d202a 100644 --- a/clientapi/auth/storage/devices/storage.go +++ b/userapi/storage/devices/storage.go @@ -19,22 +19,25 @@ package devices import ( "net/url" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices/postgres" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices/sqlite3" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/devices/postgres" + "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3" "github.com/matrix-org/gomatrixserverlib" ) -func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) { +// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) +// and sets postgres connection parameters +func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, serverName gomatrixserverlib.ServerName) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { - return postgres.NewDatabase(dataSourceName, serverName) + return postgres.NewDatabase(dataSourceName, dbProperties, serverName) } switch uri.Scheme { case "postgres": - return postgres.NewDatabase(dataSourceName, serverName) + return postgres.NewDatabase(dataSourceName, dbProperties, serverName) case "file": return sqlite3.NewDatabase(dataSourceName, serverName) default: - return postgres.NewDatabase(dataSourceName, serverName) + return postgres.NewDatabase(dataSourceName, dbProperties, serverName) } } diff --git a/clientapi/auth/storage/devices/storage_wasm.go b/userapi/storage/devices/storage_wasm.go similarity index 79% rename from clientapi/auth/storage/devices/storage_wasm.go rename to userapi/storage/devices/storage_wasm.go index 322852888..a5a515eff 100644 --- a/clientapi/auth/storage/devices/storage_wasm.go +++ b/userapi/storage/devices/storage_wasm.go @@ -18,11 +18,16 @@ import ( "fmt" "net/url" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices/sqlite3" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3" "github.com/matrix-org/gomatrixserverlib" ) -func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) { +func NewDatabase( + dataSourceName string, + dbProperties sqlutil.DbProperties, // nolint:unparam + serverName gomatrixserverlib.ServerName, +) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { return nil, fmt.Errorf("Cannot use postgres implementation") diff --git a/userapi/userapi.go b/userapi/userapi.go new file mode 100644 index 000000000..7aadec06a --- /dev/null +++ b/userapi/userapi.go @@ -0,0 +1,45 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package userapi + +import ( + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/internal" + "github.com/matrix-org/dendrite/userapi/inthttp" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" + "github.com/matrix-org/gomatrixserverlib" +) + +// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions +// on the given input API. +func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { + inthttp.AddRoutes(router, intAPI) +} + +// NewInternalAPI returns a concerete implementation of the internal API. Callers +// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. +func NewInternalAPI(accountDB accounts.Database, deviceDB devices.Database, + serverName gomatrixserverlib.ServerName, appServices []config.ApplicationService) api.UserInternalAPI { + + return &internal.UserInternalAPI{ + AccountDB: accountDB, + DeviceDB: deviceDB, + ServerName: serverName, + AppServices: appServices, + } +} diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go new file mode 100644 index 000000000..163b10ec7 --- /dev/null +++ b/userapi/userapi_test.go @@ -0,0 +1,112 @@ +package userapi_test + +import ( + "context" + "fmt" + "net/http" + "reflect" + "testing" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/test" + "github.com/matrix-org/dendrite/userapi" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/inthttp" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" + "github.com/matrix-org/gomatrixserverlib" +) + +const ( + serverName = gomatrixserverlib.ServerName("example.com") +) + +func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database, devices.Database) { + accountDB, err := accounts.NewDatabase("file::memory:", nil, serverName) + if err != nil { + t.Fatalf("failed to create account DB: %s", err) + } + deviceDB, err := devices.NewDatabase("file::memory:", nil, serverName) + if err != nil { + t.Fatalf("failed to create device DB: %s", err) + } + + return userapi.NewInternalAPI(accountDB, deviceDB, serverName, nil), accountDB, deviceDB +} + +func TestQueryProfile(t *testing.T) { + aliceAvatarURL := "mxc://example.com/alice" + aliceDisplayName := "Alice" + userAPI, accountDB, _ := MustMakeInternalAPI(t) + _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "") + if err != nil { + t.Fatalf("failed to make account: %s", err) + } + if err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil { + t.Fatalf("failed to set avatar url: %s", err) + } + if err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil { + t.Fatalf("failed to set display name: %s", err) + } + + testCases := []struct { + req api.QueryProfileRequest + wantRes api.QueryProfileResponse + wantErr error + }{ + { + req: api.QueryProfileRequest{ + UserID: fmt.Sprintf("@alice:%s", serverName), + }, + wantRes: api.QueryProfileResponse{ + UserExists: true, + AvatarURL: aliceAvatarURL, + DisplayName: aliceDisplayName, + }, + }, + { + req: api.QueryProfileRequest{ + UserID: fmt.Sprintf("@bob:%s", serverName), + }, + wantRes: api.QueryProfileResponse{ + UserExists: false, + }, + }, + { + req: api.QueryProfileRequest{ + UserID: "@alice:wrongdomain.com", + }, + wantErr: fmt.Errorf("wrong domain"), + }, + } + + runCases := func(testAPI api.UserInternalAPI) { + for _, tc := range testCases { + var gotRes api.QueryProfileResponse + gotErr := testAPI.QueryProfile(context.TODO(), &tc.req, &gotRes) + if tc.wantErr == nil && gotErr != nil || tc.wantErr != nil && gotErr == nil { + t.Errorf("QueryProfile error, got %s want %s", gotErr, tc.wantErr) + continue + } + if !reflect.DeepEqual(tc.wantRes, gotRes) { + t.Errorf("QueryProfile response got %+v want %+v", gotRes, tc.wantRes) + } + } + } + + t.Run("HTTP API", func(t *testing.T) { + router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() + userapi.AddInternalRoutes(router, userAPI) + apiURL, cancel := test.ListenAndServe(t, router, false) + defer cancel() + httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{}) + if err != nil { + t.Fatalf("failed to create HTTP client") + } + runCases(httpAPI) + }) + t.Run("Monolith", func(t *testing.T) { + runCases(userAPI) + }) +}