diff --git a/appservice/api/query.go b/appservice/api/query.go index cf25a9616..4d1cf9474 100644 --- a/appservice/api/query.go +++ b/appservice/api/query.go @@ -26,6 +26,23 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) +// AppServiceInternalAPI is used to query user and room alias data from application +// services +type AppServiceInternalAPI interface { + // Check whether a room alias exists within any application service namespaces + RoomAliasExists( + ctx context.Context, + req *RoomAliasExistsRequest, + resp *RoomAliasExistsResponse, + ) error + // Check whether a user ID exists within any application service namespaces + UserIDExists( + ctx context.Context, + req *UserIDExistsRequest, + resp *UserIDExistsResponse, + ) error +} + // RoomAliasExistsRequest is a request to an application service // about whether a room alias exists type RoomAliasExistsRequest struct { @@ -60,31 +77,14 @@ type UserIDExistsResponse struct { UserIDExists bool `json:"exists"` } -// AppServiceQueryAPI is used to query user and room alias data from application -// services -type AppServiceQueryAPI interface { - // Check whether a room alias exists within any application service namespaces - RoomAliasExists( - ctx context.Context, - req *RoomAliasExistsRequest, - resp *RoomAliasExistsResponse, - ) error - // Check whether a user ID exists within any application service namespaces - UserIDExists( - ctx context.Context, - req *UserIDExistsRequest, - resp *UserIDExistsResponse, - ) error -} - // RetrieveUserProfile is a wrapper that queries both the local database and // application services for a given user's profile // TODO: Remove this, it's called from federationapi and clientapi but is a pure function func RetrieveUserProfile( ctx context.Context, userID string, - asAPI AppServiceQueryAPI, - profileAPI userapi.UserProfileAPI, + asAPI AppServiceInternalAPI, + profileAPI userapi.ClientUserAPI, ) (*authtypes.Profile, error) { localpart, _, err := gomatrixserverlib.SplitID('@', userID) if err != nil { diff --git a/appservice/appservice.go b/appservice/appservice.go index b99091866..bd292767b 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -16,8 +16,6 @@ package appservice import ( "context" - "crypto/tls" - "net/http" "sync" "time" @@ -36,10 +34,11 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" ) // AddInternalRoutes registers HTTP handlers for internal API calls -func AddInternalRoutes(router *mux.Router, queryAPI appserviceAPI.AppServiceQueryAPI) { +func AddInternalRoutes(router *mux.Router, queryAPI appserviceAPI.AppServiceInternalAPI) { inthttp.AddRoutes(queryAPI, router) } @@ -47,22 +46,19 @@ func AddInternalRoutes(router *mux.Router, queryAPI appserviceAPI.AppServiceQuer // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( base *base.BaseDendrite, - userAPI userapi.UserInternalAPI, - rsAPI roomserverAPI.RoomserverInternalAPI, -) appserviceAPI.AppServiceQueryAPI { - client := &http.Client{ - Timeout: time.Second * 30, - Transport: &http.Transport{ - DisableKeepAlives: true, - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: base.Cfg.AppServiceAPI.DisableTLSValidation, - }, - }, - } + userAPI userapi.AppserviceUserAPI, + rsAPI roomserverAPI.AppserviceRoomserverAPI, +) appserviceAPI.AppServiceInternalAPI { + client := gomatrixserverlib.NewClient( + gomatrixserverlib.WithTimeout(time.Second*30), + gomatrixserverlib.WithKeepAlives(false), + gomatrixserverlib.WithSkipVerify(base.Cfg.AppServiceAPI.DisableTLSValidation), + ) + js, _ := jetstream.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) // Create a connection to the appservice postgres DB - appserviceDB, err := storage.NewDatabase(&base.Cfg.AppServiceAPI.Database) + appserviceDB, err := storage.NewDatabase(base, &base.Cfg.AppServiceAPI.Database) if err != nil { logrus.WithError(err).Panicf("failed to connect to appservice db") } @@ -117,7 +113,7 @@ func NewInternalAPI( // `sender_localpart` field of each application service if it doesn't // exist already func generateAppServiceAccount( - userAPI userapi.UserInternalAPI, + userAPI userapi.AppserviceUserAPI, as config.ApplicationService, ) error { var accRes userapi.PerformAccountCreationResponse diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index 31e05caa0..e406e88a7 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -37,7 +37,7 @@ type OutputRoomEventConsumer struct { durable string topic string asDB storage.Database - rsAPI api.RoomserverInternalAPI + rsAPI api.AppserviceRoomserverAPI serverName string workerStates []types.ApplicationServiceWorkerState } @@ -49,7 +49,7 @@ func NewOutputRoomEventConsumer( cfg *config.Dendrite, js nats.JetStreamContext, appserviceDB storage.Database, - rsAPI api.RoomserverInternalAPI, + rsAPI api.AppserviceRoomserverAPI, workerStates []types.ApplicationServiceWorkerState, ) *OutputRoomEventConsumer { return &OutputRoomEventConsumer{ diff --git a/appservice/inthttp/client.go b/appservice/inthttp/client.go index 7e3cb208f..0a8baea99 100644 --- a/appservice/inthttp/client.go +++ b/appservice/inthttp/client.go @@ -29,7 +29,7 @@ type httpAppServiceQueryAPI struct { func NewAppserviceClient( appserviceURL string, httpClient *http.Client, -) (api.AppServiceQueryAPI, error) { +) (api.AppServiceInternalAPI, error) { if httpClient == nil { return nil, errors.New("NewRoomserverAliasAPIHTTP: httpClient is ") } diff --git a/appservice/inthttp/server.go b/appservice/inthttp/server.go index 009b7b5db..645b43871 100644 --- a/appservice/inthttp/server.go +++ b/appservice/inthttp/server.go @@ -11,7 +11,7 @@ import ( ) // AddRoutes adds the AppServiceQueryAPI handlers to the http.ServeMux. -func AddRoutes(a api.AppServiceQueryAPI, internalAPIMux *mux.Router) { +func AddRoutes(a api.AppServiceInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( AppServiceRoomAliasExistsPath, httputil.MakeInternalAPI("appserviceRoomAliasExists", func(req *http.Request) util.JSONResponse { diff --git a/appservice/query/query.go b/appservice/query/query.go index dacd3caa8..b7b0b335a 100644 --- a/appservice/query/query.go +++ b/appservice/query/query.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" opentracing "github.com/opentracing/opentracing-go" log "github.com/sirupsen/logrus" ) @@ -32,7 +33,7 @@ const userIDExistsPath = "/users/" // AppServiceQueryAPI is an implementation of api.AppServiceQueryAPI type AppServiceQueryAPI struct { - HTTPClient *http.Client + HTTPClient *gomatrixserverlib.Client Cfg *config.Dendrite } @@ -64,9 +65,8 @@ func (a *AppServiceQueryAPI) RoomAliasExists( if err != nil { return err } - req = req.WithContext(ctx) - resp, err := a.HTTPClient.Do(req) + resp, err := a.HTTPClient.DoHTTPRequest(ctx, req) if resp != nil { defer func() { err = resp.Body.Close() @@ -130,7 +130,7 @@ func (a *AppServiceQueryAPI) UserIDExists( if err != nil { return err } - resp, err := a.HTTPClient.Do(req.WithContext(ctx)) + resp, err := a.HTTPClient.DoHTTPRequest(ctx, req) if resp != nil { defer func() { err = resp.Body.Close() diff --git a/appservice/storage/postgres/storage.go b/appservice/storage/postgres/storage.go index eaf947ff3..a4c04b2cc 100644 --- a/appservice/storage/postgres/storage.go +++ b/appservice/storage/postgres/storage.go @@ -22,6 +22,7 @@ import ( // Import postgres database driver _ "github.com/lib/pq" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) @@ -35,13 +36,12 @@ type Database struct { } // NewDatabase opens a new database -func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) { +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*Database, error) { var result Database var err error - if result.db, err = sqlutil.Open(dbProperties); err != nil { + if result.db, result.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { return nil, err } - result.writer = sqlutil.NewDummyWriter() if err = result.prepare(); err != nil { return nil, err } diff --git a/appservice/storage/sqlite3/storage.go b/appservice/storage/sqlite3/storage.go index 9260c7fe7..ad62b3628 100644 --- a/appservice/storage/sqlite3/storage.go +++ b/appservice/storage/sqlite3/storage.go @@ -21,6 +21,7 @@ import ( // Import SQLite database driver "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) @@ -34,13 +35,12 @@ type Database struct { } // NewDatabase opens a new database -func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) { +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*Database, error) { var result Database var err error - if result.db, err = sqlutil.Open(dbProperties); err != nil { + if result.db, result.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { return nil, err } - result.writer = sqlutil.NewExclusiveWriter() if err = result.prepare(); err != nil { return nil, err } diff --git a/appservice/storage/storage.go b/appservice/storage/storage.go index 97b8501e2..89d5e0cc2 100644 --- a/appservice/storage/storage.go +++ b/appservice/storage/storage.go @@ -22,17 +22,18 @@ import ( "github.com/matrix-org/dendrite/appservice/storage/postgres" "github.com/matrix-org/dendrite/appservice/storage/sqlite3" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) // NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // and sets DB connection parameters -func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) { +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties) + return sqlite3.NewDatabase(base, dbProperties) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(dbProperties) + return postgres.NewDatabase(base, dbProperties) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/appservice/storage/storage_wasm.go b/appservice/storage/storage_wasm.go index 07d0e9ee1..230254598 100644 --- a/appservice/storage/storage_wasm.go +++ b/appservice/storage/storage_wasm.go @@ -18,13 +18,14 @@ import ( "fmt" "github.com/matrix-org/dendrite/appservice/storage/sqlite3" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) -func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) { +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties) + return sqlite3.NewDatabase(base, dbProperties) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/appservice/workers/transaction_scheduler.go b/appservice/workers/transaction_scheduler.go index 4dab00bd7..47d447c2c 100644 --- a/appservice/workers/transaction_scheduler.go +++ b/appservice/workers/transaction_scheduler.go @@ -42,7 +42,7 @@ var ( // size), then send that off to the AS's /transactions/{txnID} endpoint. It also // handles exponentially backing off in case the AS isn't currently available. func SetupTransactionWorkers( - client *http.Client, + client *gomatrixserverlib.Client, appserviceDB storage.Database, workerStates []types.ApplicationServiceWorkerState, ) error { @@ -58,7 +58,7 @@ func SetupTransactionWorkers( // worker is a goroutine that sends any queued events to the application service // it is given. -func worker(client *http.Client, db storage.Database, ws types.ApplicationServiceWorkerState) { +func worker(client *gomatrixserverlib.Client, db storage.Database, ws types.ApplicationServiceWorkerState) { log.WithFields(log.Fields{ "appservice": ws.AppService.ID, }).Info("Starting application service") @@ -200,7 +200,7 @@ func createTransaction( // send sends events to an application service. Returns an error if an OK was not // received back from the application service or the request timed out. func send( - client *http.Client, + client *gomatrixserverlib.Client, appservice config.ApplicationService, txnID int, transaction []byte, @@ -213,7 +213,7 @@ func send( return err } req.Header.Set("Content-Type", "application/json") - resp, err := client.Do(req) + resp, err := client.DoHTTPRequest(context.TODO(), req) if err != nil { return err } diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index d047f3fff..310ac7dda 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -268,7 +268,6 @@ func (m *DendriteMonolith) Start() { base := base.NewBaseDendrite(cfg, "Monolith") defer base.Close() // nolint: errcheck - accountDB := base.CreateAccountsDB() federation := conn.CreateFederationClient(base, m.PineconeQUIC) serverKeyAPI := &signing.YggdrasilKeys{} @@ -281,7 +280,7 @@ func (m *DendriteMonolith) Start() { ) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) - m.userAPI = userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) + m.userAPI = userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(m.userAPI) asAPI := appservice.NewInternalAPI(base, m.userAPI, rsAPI) @@ -295,7 +294,6 @@ func (m *DendriteMonolith) Start() { monolith := setup.Monolith{ Config: base.Cfg, - AccountDB: accountDB, Client: conn.CreateClient(base, m.PineconeQUIC), FedClient: federation, KeyRing: keyRing, @@ -308,16 +306,7 @@ func (m *DendriteMonolith) Start() { ExtPublicRoomsProvider: roomProvider, ExtUserDirectoryProvider: userProvider, } - monolith.AddAllPublicRoutes( - base.ProcessContext, - base.PublicClientAPIMux, - base.PublicFederationAPIMux, - base.PublicKeyAPIMux, - base.PublicWellKnownAPIMux, - base.PublicMediaAPIMux, - base.SynapseAdminMux, - base.DendriteAdminMux, - ) + monolith.AddAllPublicRoutes(base) httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 4e95e3972..991bc462f 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -107,7 +107,6 @@ func (m *DendriteMonolith) Start() { m.processContext = base.ProcessContext defer base.Close() // nolint: errcheck - accountDB := base.CreateAccountsDB() federation := ygg.CreateFederationClient(base) serverKeyAPI := &signing.YggdrasilKeys{} @@ -120,7 +119,7 @@ func (m *DendriteMonolith) Start() { ) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) - userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) + userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(userAPI) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) @@ -132,7 +131,6 @@ func (m *DendriteMonolith) Start() { monolith := setup.Monolith{ Config: base.Cfg, - AccountDB: accountDB, Client: ygg.CreateClient(base), FedClient: federation, KeyRing: keyRing, @@ -146,16 +144,7 @@ func (m *DendriteMonolith) Start() { ygg, fsAPI, federation, ), } - monolith.AddAllPublicRoutes( - base.ProcessContext, - base.PublicClientAPIMux, - base.PublicFederationAPIMux, - base.PublicKeyAPIMux, - base.PublicWellKnownAPIMux, - base.PublicMediaAPIMux, - base.SynapseAdminMux, - base.DendriteAdminMux, - ) + monolith.AddAllPublicRoutes(base) httpRouter := mux.NewRouter() httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) diff --git a/clientapi/auth/auth.go b/clientapi/auth/auth.go index 575c5377f..93345f4b9 100644 --- a/clientapi/auth/auth.go +++ b/clientapi/auth/auth.go @@ -51,7 +51,7 @@ type AccountDatabase interface { // Note: For an AS user, AS dummy device is returned. // On failure returns an JSON error response which can be sent to the client. func VerifyUserFromRequest( - req *http.Request, userAPI api.UserInternalAPI, + req *http.Request, userAPI api.QueryAcccessTokenAPI, ) (*api.Device, *util.JSONResponse) { // Try to find the Application Service user token, err := ExtractAccessToken(req) diff --git a/clientapi/auth/login.go b/clientapi/auth/login.go index 020731c9f..5f51c662a 100644 --- a/clientapi/auth/login.go +++ b/clientapi/auth/login.go @@ -33,7 +33,7 @@ import ( // called after authorization has completed, with the result of the authorization. // If the final return value is non-nil, an error occurred and the cleanup function // is nil. -func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.UserAccountAPI, userAPI UserInternalAPIForLogin, cfg *config.ClientAPI) (*Login, LoginCleanupFunc, *util.JSONResponse) { +func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.UserLoginAPI, userAPI UserInternalAPIForLogin, cfg *config.ClientAPI) (*Login, LoginCleanupFunc, *util.JSONResponse) { reqBytes, err := ioutil.ReadAll(r) if err != nil { err := &util.JSONResponse{ diff --git a/clientapi/auth/login_test.go b/clientapi/auth/login_test.go index d401469c1..5085f0170 100644 --- a/clientapi/auth/login_test.go +++ b/clientapi/auth/login_test.go @@ -160,7 +160,6 @@ func TestBadLoginFromJSONReader(t *testing.T) { type fakeUserInternalAPI struct { UserInternalAPIForLogin - uapi.UserAccountAPI DeletedTokens []string } @@ -179,6 +178,10 @@ func (ua *fakeUserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, re return nil } +func (ua *fakeUserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *uapi.PerformLoginTokenCreationRequest, res *uapi.PerformLoginTokenCreationResponse) error { + return nil +} + func (*fakeUserInternalAPI) QueryLoginToken(ctx context.Context, req *uapi.QueryLoginTokenRequest, res *uapi.QueryLoginTokenResponse) error { if req.Token == "invalidtoken" { return nil diff --git a/clientapi/auth/user_interactive.go b/clientapi/auth/user_interactive.go index 22c430f97..6caf7dcdc 100644 --- a/clientapi/auth/user_interactive.go +++ b/clientapi/auth/user_interactive.go @@ -110,7 +110,7 @@ type UserInteractive struct { Sessions map[string][]string } -func NewUserInteractive(userAccountAPI api.UserAccountAPI, cfg *config.ClientAPI) *UserInteractive { +func NewUserInteractive(userAccountAPI api.UserLoginAPI, cfg *config.ClientAPI) *UserInteractive { typePassword := &LoginTypePassword{ GetAccountByPassword: userAccountAPI.QueryAccountByPassword, Config: cfg, diff --git a/clientapi/auth/user_interactive_test.go b/clientapi/auth/user_interactive_test.go index a4b4587a3..262e48103 100644 --- a/clientapi/auth/user_interactive_test.go +++ b/clientapi/auth/user_interactive_test.go @@ -24,9 +24,7 @@ var ( } ) -type fakeAccountDatabase struct { - api.UserAccountAPI -} +type fakeAccountDatabase struct{} func (d *fakeAccountDatabase) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { return nil diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index ad277056c..c1e86114b 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -15,7 +15,6 @@ package clientapi import ( - "github.com/gorilla/mux" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/producers" @@ -24,32 +23,28 @@ import ( "github.com/matrix-org/dendrite/internal/transactions" keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/setup/process" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" ) // AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component. func AddPublicRoutes( - process *process.ProcessContext, - router *mux.Router, - synapseAdminRouter *mux.Router, - dendriteAdminRouter *mux.Router, - cfg *config.ClientAPI, + base *base.BaseDendrite, federation *gomatrixserverlib.FederationClient, - rsAPI roomserverAPI.RoomserverInternalAPI, - asAPI appserviceAPI.AppServiceQueryAPI, + rsAPI roomserverAPI.ClientRoomserverAPI, + asAPI appserviceAPI.AppServiceInternalAPI, transactionsCache *transactions.Cache, - fsAPI federationAPI.FederationInternalAPI, - userAPI userapi.UserInternalAPI, - userDirectoryProvider userapi.UserDirectoryProvider, - keyAPI keyserverAPI.KeyInternalAPI, + fsAPI federationAPI.ClientFederationAPI, + userAPI userapi.ClientUserAPI, + userDirectoryProvider userapi.QuerySearchProfilesAPI, + keyAPI keyserverAPI.ClientKeyAPI, extRoomsProvider api.ExtraPublicRoomsProvider, - mscCfg *config.MSCs, ) { - js, natsClient := jetstream.Prepare(process, &cfg.Matrix.JetStream) + cfg := &base.Cfg.ClientAPI + mscCfg := &base.Cfg.MSCs + js, natsClient := jetstream.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) syncProducer := &producers.SyncAPIProducer{ JetStream: js, @@ -63,7 +58,9 @@ func AddPublicRoutes( } routing.Setup( - router, synapseAdminRouter, dendriteAdminRouter, + base.PublicClientAPIMux, + base.SynapseAdminMux, + base.DendriteAdminMux, cfg, rsAPI, asAPI, userAPI, userDirectoryProvider, federation, syncProducer, transactionsCache, fsAPI, keyAPI, diff --git a/clientapi/producers/syncapi.go b/clientapi/producers/syncapi.go index 187e3412d..48b1ae88d 100644 --- a/clientapi/producers/syncapi.go +++ b/clientapi/producers/syncapi.go @@ -38,7 +38,7 @@ type SyncAPIProducer struct { TopicPresenceEvent string JetStream nats.JetStreamContext ServerName gomatrixserverlib.ServerName - UserAPI userapi.UserInternalAPI + UserAPI userapi.ClientUserAPI } // SendData sends account data to the sync API server diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go index d0dd3ab8d..a5a3014ab 100644 --- a/clientapi/routing/account_data.go +++ b/clientapi/routing/account_data.go @@ -33,7 +33,7 @@ import ( // GetAccountData implements GET /user/{userId}/[rooms/{roomid}/]account_data/{type} func GetAccountData( - req *http.Request, userAPI api.UserInternalAPI, device *api.Device, + req *http.Request, userAPI api.ClientUserAPI, device *api.Device, userID string, roomID string, dataType string, ) util.JSONResponse { if userID != device.UserID { @@ -76,7 +76,7 @@ func GetAccountData( // SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type} func SaveAccountData( - req *http.Request, userAPI api.UserInternalAPI, device *api.Device, + req *http.Request, userAPI api.ClientUserAPI, device *api.Device, userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer, ) util.JSONResponse { if userID != device.UserID { @@ -152,7 +152,7 @@ type fullyReadEvent struct { // SaveReadMarker implements POST /rooms/{roomId}/read_markers func SaveReadMarker( req *http.Request, - userAPI api.UserInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI, + userAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI, syncProducer *producers.SyncAPIProducer, device *api.Device, roomID string, ) util.JSONResponse { // Verify that the user is a member of this room diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 31e431c78..125b3847d 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -11,7 +11,7 @@ import ( "github.com/matrix-org/util" ) -func AdminEvacuateRoom(req *http.Request, device *userapi.Device, rsAPI roomserverAPI.RoomserverInternalAPI) util.JSONResponse { +func AdminEvacuateRoom(req *http.Request, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { if device.AccountType != userapi.AccountTypeAdmin { return util.JSONResponse{ Code: http.StatusForbidden, diff --git a/clientapi/routing/admin_whois.go b/clientapi/routing/admin_whois.go index 87bb79366..f1cbd3467 100644 --- a/clientapi/routing/admin_whois.go +++ b/clientapi/routing/admin_whois.go @@ -44,7 +44,7 @@ type connectionInfo struct { // GetAdminWhois implements GET /admin/whois/{userId} func GetAdminWhois( - req *http.Request, userAPI api.UserInternalAPI, device *api.Device, + req *http.Request, userAPI api.ClientUserAPI, device *api.Device, userID string, ) util.JSONResponse { allowed := device.AccountType == api.AccountTypeAdmin || userID == device.UserID diff --git a/clientapi/routing/aliases.go b/clientapi/routing/aliases.go index 8c4830532..504d60265 100644 --- a/clientapi/routing/aliases.go +++ b/clientapi/routing/aliases.go @@ -28,7 +28,7 @@ import ( // GetAliases implements GET /_matrix/client/r0/rooms/{roomId}/aliases func GetAliases( - req *http.Request, rsAPI api.RoomserverInternalAPI, device *userapi.Device, roomID string, + req *http.Request, rsAPI api.ClientRoomserverAPI, device *userapi.Device, roomID string, ) util.JSONResponse { stateTuple := gomatrixserverlib.StateKeyTuple{ EventType: gomatrixserverlib.MRoomHistoryVisibility, diff --git a/clientapi/routing/capabilities.go b/clientapi/routing/capabilities.go index 72668fa5a..b7d47e916 100644 --- a/clientapi/routing/capabilities.go +++ b/clientapi/routing/capabilities.go @@ -26,7 +26,7 @@ import ( // GetCapabilities returns information about the server's supported feature set // and other relevant capabilities to an authenticated user. func GetCapabilities( - req *http.Request, rsAPI roomserverAPI.RoomserverInternalAPI, + req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, ) util.JSONResponse { roomVersionsQueryReq := roomserverAPI.QueryRoomVersionCapabilitiesRequest{} roomVersionsQueryRes := roomserverAPI.QueryRoomVersionCapabilitiesResponse{} diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index 4976b3e50..d40d84a79 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -137,8 +137,8 @@ type fledglingEvent struct { func CreateRoom( req *http.Request, device *api.Device, cfg *config.ClientAPI, - profileAPI api.UserProfileAPI, rsAPI roomserverAPI.RoomserverInternalAPI, - asAPI appserviceAPI.AppServiceQueryAPI, + profileAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI, + asAPI appserviceAPI.AppServiceInternalAPI, ) util.JSONResponse { var r createRoomRequest resErr := httputil.UnmarshalJSONRequest(req, &r) @@ -164,8 +164,8 @@ func createRoom( ctx context.Context, r createRoomRequest, device *api.Device, cfg *config.ClientAPI, - profileAPI api.UserProfileAPI, rsAPI roomserverAPI.RoomserverInternalAPI, - asAPI appserviceAPI.AppServiceQueryAPI, + profileAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI, + asAPI appserviceAPI.AppServiceInternalAPI, evTime time.Time, ) util.JSONResponse { // TODO (#267): Check room ID doesn't clash with an existing one, and we @@ -531,25 +531,23 @@ func createRoom( gomatrixserverlib.NewInviteV2StrippedState(inviteEvent.Event), ) // Send the invite event to the roomserver. - err = roomserverAPI.SendInvite( - ctx, - rsAPI, - inviteEvent.Headered(roomVersion), - inviteStrippedState, // invite room state - cfg.Matrix.ServerName, // send as server - nil, // transaction ID - ) - switch e := err.(type) { - case *roomserverAPI.PerformError: - return e.JSONResponse() - case nil: - default: - util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInvite failed") + var inviteRes roomserverAPI.PerformInviteResponse + event := inviteEvent.Headered(roomVersion) + if err := rsAPI.PerformInvite(ctx, &roomserverAPI.PerformInviteRequest{ + Event: event, + InviteRoomState: inviteStrippedState, + RoomVersion: event.RoomVersion, + SendAsServer: string(cfg.Matrix.ServerName), + }, &inviteRes); err != nil { + util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError(), } } + if inviteRes.Error != nil { + return inviteRes.Error.JSONResponse() + } } } diff --git a/clientapi/routing/deactivate.go b/clientapi/routing/deactivate.go index da1b6dcf9..c8aa6a3bc 100644 --- a/clientapi/routing/deactivate.go +++ b/clientapi/routing/deactivate.go @@ -15,7 +15,7 @@ import ( func Deactivate( req *http.Request, userInteractiveAuth *auth.UserInteractive, - accountAPI api.UserAccountAPI, + accountAPI api.ClientUserAPI, deviceAPI *api.Device, ) util.JSONResponse { ctx := req.Context() diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go index 161bc2731..bb1cf47bd 100644 --- a/clientapi/routing/device.go +++ b/clientapi/routing/device.go @@ -50,7 +50,7 @@ type devicesDeleteJSON struct { // GetDeviceByID handles /devices/{deviceID} func GetDeviceByID( - req *http.Request, userAPI api.UserInternalAPI, device *api.Device, + req *http.Request, userAPI api.ClientUserAPI, device *api.Device, deviceID string, ) util.JSONResponse { var queryRes api.QueryDevicesResponse @@ -88,7 +88,7 @@ func GetDeviceByID( // GetDevicesByLocalpart handles /devices func GetDevicesByLocalpart( - req *http.Request, userAPI api.UserInternalAPI, device *api.Device, + req *http.Request, userAPI api.ClientUserAPI, device *api.Device, ) util.JSONResponse { var queryRes api.QueryDevicesResponse err := userAPI.QueryDevices(req.Context(), &api.QueryDevicesRequest{ @@ -118,7 +118,7 @@ func GetDevicesByLocalpart( // UpdateDeviceByID handles PUT on /devices/{deviceID} func UpdateDeviceByID( - req *http.Request, userAPI api.UserInternalAPI, device *api.Device, + req *http.Request, userAPI api.ClientUserAPI, device *api.Device, deviceID string, ) util.JSONResponse { @@ -161,7 +161,7 @@ func UpdateDeviceByID( // DeleteDeviceById handles DELETE requests to /devices/{deviceId} func DeleteDeviceById( - req *http.Request, userInteractiveAuth *auth.UserInteractive, userAPI api.UserInternalAPI, device *api.Device, + req *http.Request, userInteractiveAuth *auth.UserInteractive, userAPI api.ClientUserAPI, device *api.Device, deviceID string, ) util.JSONResponse { var ( @@ -242,7 +242,7 @@ func DeleteDeviceById( // DeleteDevices handles POST requests to /delete_devices func DeleteDevices( - req *http.Request, userAPI api.UserInternalAPI, device *api.Device, + req *http.Request, userAPI api.ClientUserAPI, device *api.Device, ) util.JSONResponse { ctx := req.Context() payload := devicesDeleteJSON{} diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index ac355b5d4..53ba3f190 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -46,8 +46,8 @@ func DirectoryRoom( roomAlias string, federation *gomatrixserverlib.FederationClient, cfg *config.ClientAPI, - rsAPI roomserverAPI.RoomserverInternalAPI, - fedSenderAPI federationAPI.FederationInternalAPI, + rsAPI roomserverAPI.ClientRoomserverAPI, + fedSenderAPI federationAPI.ClientFederationAPI, ) util.JSONResponse { _, domain, err := gomatrixserverlib.SplitID('#', roomAlias) if err != nil { @@ -117,7 +117,7 @@ func SetLocalAlias( device *userapi.Device, alias string, cfg *config.ClientAPI, - rsAPI roomserverAPI.RoomserverInternalAPI, + rsAPI roomserverAPI.ClientRoomserverAPI, ) util.JSONResponse { _, domain, err := gomatrixserverlib.SplitID('#', alias) if err != nil { @@ -199,7 +199,7 @@ func RemoveLocalAlias( req *http.Request, device *userapi.Device, alias string, - rsAPI roomserverAPI.RoomserverInternalAPI, + rsAPI roomserverAPI.ClientRoomserverAPI, ) util.JSONResponse { queryReq := roomserverAPI.RemoveRoomAliasRequest{ Alias: alias, @@ -237,7 +237,7 @@ type roomVisibility struct { // GetVisibility implements GET /directory/list/room/{roomID} func GetVisibility( - req *http.Request, rsAPI roomserverAPI.RoomserverInternalAPI, + req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, roomID string, ) util.JSONResponse { var res roomserverAPI.QueryPublishedRoomsResponse @@ -265,7 +265,7 @@ func GetVisibility( // SetVisibility implements PUT /directory/list/room/{roomID} // TODO: Allow admin users to edit the room visibility func SetVisibility( - req *http.Request, rsAPI roomserverAPI.RoomserverInternalAPI, dev *userapi.Device, + req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, dev *userapi.Device, roomID string, ) util.JSONResponse { resErr := checkMemberInRoom(req.Context(), rsAPI, dev.UserID, roomID) diff --git a/clientapi/routing/directory_public.go b/clientapi/routing/directory_public.go index 0dacfced5..c3e6141b2 100644 --- a/clientapi/routing/directory_public.go +++ b/clientapi/routing/directory_public.go @@ -50,7 +50,7 @@ type filter struct { // GetPostPublicRooms implements GET and POST /publicRooms func GetPostPublicRooms( - req *http.Request, rsAPI roomserverAPI.RoomserverInternalAPI, + req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, extRoomsProvider api.ExtraPublicRoomsProvider, federation *gomatrixserverlib.FederationClient, cfg *config.ClientAPI, @@ -91,7 +91,7 @@ func GetPostPublicRooms( } func publicRooms( - ctx context.Context, request PublicRoomReq, rsAPI roomserverAPI.RoomserverInternalAPI, extRoomsProvider api.ExtraPublicRoomsProvider, + ctx context.Context, request PublicRoomReq, rsAPI roomserverAPI.ClientRoomserverAPI, extRoomsProvider api.ExtraPublicRoomsProvider, ) (*gomatrixserverlib.RespPublicRooms, error) { response := gomatrixserverlib.RespPublicRooms{ @@ -229,7 +229,7 @@ func sliceInto(slice []gomatrixserverlib.PublicRoom, since int64, limit int16) ( } func refreshPublicRoomCache( - ctx context.Context, rsAPI roomserverAPI.RoomserverInternalAPI, extRoomsProvider api.ExtraPublicRoomsProvider, + ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, extRoomsProvider api.ExtraPublicRoomsProvider, ) []gomatrixserverlib.PublicRoom { cacheMu.Lock() defer cacheMu.Unlock() diff --git a/clientapi/routing/getevent.go b/clientapi/routing/getevent.go index 36f3ee9e3..7f5842800 100644 --- a/clientapi/routing/getevent.go +++ b/clientapi/routing/getevent.go @@ -31,7 +31,6 @@ type getEventRequest struct { roomID string eventID string cfg *config.ClientAPI - federation *gomatrixserverlib.FederationClient requestedEvent *gomatrixserverlib.Event } @@ -43,8 +42,7 @@ func GetEvent( roomID string, eventID string, cfg *config.ClientAPI, - rsAPI api.RoomserverInternalAPI, - federation *gomatrixserverlib.FederationClient, + rsAPI api.ClientRoomserverAPI, ) util.JSONResponse { eventsReq := api.QueryEventsByIDRequest{ EventIDs: []string{eventID}, @@ -72,7 +70,6 @@ func GetEvent( roomID: roomID, eventID: eventID, cfg: cfg, - federation: federation, requestedEvent: requestedEvent, } diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index dc15f4bda..4e6acebc3 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -29,8 +29,8 @@ import ( func JoinRoomByIDOrAlias( req *http.Request, device *api.Device, - rsAPI roomserverAPI.RoomserverInternalAPI, - profileAPI api.UserProfileAPI, + rsAPI roomserverAPI.ClientRoomserverAPI, + profileAPI api.ClientUserAPI, roomIDOrAlias string, ) util.JSONResponse { // Prepare to ask the roomserver to perform the room join. diff --git a/clientapi/routing/key_backup.go b/clientapi/routing/key_backup.go index 9d2ff87fd..28c80415b 100644 --- a/clientapi/routing/key_backup.go +++ b/clientapi/routing/key_backup.go @@ -55,7 +55,7 @@ type keyBackupSessionResponse struct { // Create a new key backup. Request must contain a `keyBackupVersion`. Returns a `keyBackupVersionCreateResponse`. // Implements POST /_matrix/client/r0/room_keys/version -func CreateKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device) util.JSONResponse { +func CreateKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device) util.JSONResponse { var kb keyBackupVersion resErr := httputil.UnmarshalJSONRequest(req, &kb) if resErr != nil { @@ -89,7 +89,7 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, // KeyBackupVersion returns the key backup version specified. If `version` is empty, the latest `keyBackupVersionResponse` is returned. // Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version} -func KeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse { +func KeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse { var queryResp userapi.QueryKeyBackupResponse userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ UserID: device.UserID, @@ -118,7 +118,7 @@ func KeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device // Modify the auth data of a key backup. Version must not be empty. Request must contain a `keyBackupVersion` // Implements PUT /_matrix/client/r0/room_keys/version/{version} -func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse { +func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse { var kb keyBackupVersion resErr := httputil.UnmarshalJSONRequest(req, &kb) if resErr != nil { @@ -159,7 +159,7 @@ func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.UserInter // Delete a version of key backup. Version must not be empty. If the key backup was previously deleted, will return 200 OK. // Implements DELETE /_matrix/client/r0/room_keys/version/{version} -func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse { +func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse { var performKeyBackupResp userapi.PerformKeyBackupResponse if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ UserID: device.UserID, @@ -194,7 +194,7 @@ func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, // Upload a bunch of session keys for a given `version`. func UploadBackupKeys( - req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string, keys *keyBackupSessionRequest, + req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string, keys *keyBackupSessionRequest, ) util.JSONResponse { var performKeyBackupResp userapi.PerformKeyBackupResponse if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ @@ -230,7 +230,7 @@ func UploadBackupKeys( // Get keys from a given backup version. Response returned varies depending on if roomID and sessionID are set. func GetBackupKeys( - req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version, roomID, sessionID string, + req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version, roomID, sessionID string, ) util.JSONResponse { var queryResp userapi.QueryKeyBackupResponse userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index c73e0a10d..8fbb86f7a 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -34,8 +34,8 @@ type crossSigningRequest struct { func UploadCrossSigningDeviceKeys( req *http.Request, userInteractiveAuth *auth.UserInteractive, - keyserverAPI api.KeyInternalAPI, device *userapi.Device, - accountAPI userapi.UserAccountAPI, cfg *config.ClientAPI, + keyserverAPI api.ClientKeyAPI, device *userapi.Device, + accountAPI userapi.ClientUserAPI, cfg *config.ClientAPI, ) util.JSONResponse { uploadReq := &crossSigningRequest{} uploadRes := &api.PerformUploadDeviceKeysResponse{} @@ -105,7 +105,7 @@ func UploadCrossSigningDeviceKeys( } } -func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.KeyInternalAPI, device *userapi.Device) util.JSONResponse { +func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse { uploadReq := &api.PerformUploadDeviceSignaturesRequest{} uploadRes := &api.PerformUploadDeviceSignaturesResponse{} diff --git a/clientapi/routing/keys.go b/clientapi/routing/keys.go index 2d65ac353..fdda34a53 100644 --- a/clientapi/routing/keys.go +++ b/clientapi/routing/keys.go @@ -31,7 +31,7 @@ type uploadKeysRequest struct { OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"` } -func UploadKeys(req *http.Request, keyAPI api.KeyInternalAPI, device *userapi.Device) util.JSONResponse { +func UploadKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse { var r uploadKeysRequest resErr := httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { @@ -100,7 +100,7 @@ func (r *queryKeysRequest) GetTimeout() time.Duration { return time.Duration(r.Timeout) * time.Millisecond } -func QueryKeys(req *http.Request, keyAPI api.KeyInternalAPI, device *userapi.Device) util.JSONResponse { +func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse { var r queryKeysRequest resErr := httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { @@ -138,7 +138,7 @@ func (r *claimKeysRequest) GetTimeout() time.Duration { return time.Duration(r.TimeoutMS) * time.Millisecond } -func ClaimKeys(req *http.Request, keyAPI api.KeyInternalAPI) util.JSONResponse { +func ClaimKeys(req *http.Request, keyAPI api.ClientKeyAPI) util.JSONResponse { var r claimKeysRequest resErr := httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { diff --git a/clientapi/routing/leaveroom.go b/clientapi/routing/leaveroom.go index a34dd02d3..a71661851 100644 --- a/clientapi/routing/leaveroom.go +++ b/clientapi/routing/leaveroom.go @@ -26,7 +26,7 @@ import ( func LeaveRoomByID( req *http.Request, device *api.Device, - rsAPI roomserverAPI.RoomserverInternalAPI, + rsAPI roomserverAPI.ClientRoomserverAPI, roomID string, ) util.JSONResponse { // Prepare to ask the roomserver to perform the room join. diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 2329df504..6017b5840 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -53,7 +53,7 @@ func passwordLogin() flows { // Login implements GET and POST /login func Login( - req *http.Request, userAPI userapi.UserInternalAPI, + req *http.Request, userAPI userapi.ClientUserAPI, cfg *config.ClientAPI, ) util.JSONResponse { if req.Method == http.MethodGet { @@ -79,7 +79,7 @@ func Login( } func completeAuth( - ctx context.Context, serverName gomatrixserverlib.ServerName, userAPI userapi.UserInternalAPI, login *auth.Login, + ctx context.Context, serverName gomatrixserverlib.ServerName, userAPI userapi.ClientUserAPI, login *auth.Login, ipAddr, userAgent string, ) util.JSONResponse { token, err := auth.GenerateAccessToken() diff --git a/clientapi/routing/logout.go b/clientapi/routing/logout.go index cfbb6f9f2..73bae7af7 100644 --- a/clientapi/routing/logout.go +++ b/clientapi/routing/logout.go @@ -24,7 +24,7 @@ import ( // Logout handles POST /logout func Logout( - req *http.Request, userAPI api.UserInternalAPI, device *api.Device, + req *http.Request, userAPI api.ClientUserAPI, device *api.Device, ) util.JSONResponse { var performRes api.PerformDeviceDeletionResponse err := userAPI.PerformDeviceDeletion(req.Context(), &api.PerformDeviceDeletionRequest{ @@ -44,7 +44,7 @@ func Logout( // LogoutAll handles POST /logout/all func LogoutAll( - req *http.Request, userAPI api.UserInternalAPI, device *api.Device, + req *http.Request, userAPI api.ClientUserAPI, device *api.Device, ) util.JSONResponse { var performRes api.PerformDeviceDeletionResponse err := userAPI.PerformDeviceDeletion(req.Context(), &api.PerformDeviceDeletionRequest{ diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index df8447b14..cfdf6f2de 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -27,6 +27,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/threepid" "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -38,9 +39,9 @@ import ( var errMissingUserID = errors.New("'user_id' must be supplied") func SendBan( - req *http.Request, profileAPI userapi.UserProfileAPI, device *userapi.Device, + req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device, roomID string, cfg *config.ClientAPI, - rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, + rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, ) util.JSONResponse { body, evTime, roomVer, reqErr := extractRequestData(req, roomID, rsAPI) if reqErr != nil { @@ -80,10 +81,10 @@ func SendBan( return sendMembership(req.Context(), profileAPI, device, roomID, "ban", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI) } -func sendMembership(ctx context.Context, profileAPI userapi.UserProfileAPI, device *userapi.Device, +func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, device *userapi.Device, roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time, roomVer gomatrixserverlib.RoomVersion, - rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI) util.JSONResponse { + rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI) util.JSONResponse { event, err := buildMembershipEvent( ctx, targetUserID, reason, profileAPI, device, membership, @@ -124,9 +125,9 @@ func sendMembership(ctx context.Context, profileAPI userapi.UserProfileAPI, devi } func SendKick( - req *http.Request, profileAPI userapi.UserProfileAPI, device *userapi.Device, + req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device, roomID string, cfg *config.ClientAPI, - rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, + rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, ) util.JSONResponse { body, evTime, roomVer, reqErr := extractRequestData(req, roomID, rsAPI) if reqErr != nil { @@ -164,9 +165,9 @@ func SendKick( } func SendUnban( - req *http.Request, profileAPI userapi.UserProfileAPI, device *userapi.Device, + req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device, roomID string, cfg *config.ClientAPI, - rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, + rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, ) util.JSONResponse { body, evTime, roomVer, reqErr := extractRequestData(req, roomID, rsAPI) if reqErr != nil { @@ -199,9 +200,9 @@ func SendUnban( } func SendInvite( - req *http.Request, profileAPI userapi.UserProfileAPI, device *userapi.Device, + req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device, roomID string, cfg *config.ClientAPI, - rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, + rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, ) util.JSONResponse { body, evTime, _, reqErr := extractRequestData(req, roomID, rsAPI) if reqErr != nil { @@ -233,12 +234,12 @@ func SendInvite( // sendInvite sends an invitation to a user. Returns a JSONResponse and an error func sendInvite( ctx context.Context, - profileAPI userapi.UserProfileAPI, + profileAPI userapi.ClientUserAPI, device *userapi.Device, roomID, userID, reason string, cfg *config.ClientAPI, - rsAPI roomserverAPI.RoomserverInternalAPI, - asAPI appserviceAPI.AppServiceQueryAPI, evTime time.Time, + rsAPI roomserverAPI.ClientRoomserverAPI, + asAPI appserviceAPI.AppServiceInternalAPI, evTime time.Time, ) (util.JSONResponse, error) { event, err := buildMembershipEvent( ctx, userID, reason, profileAPI, device, "invite", @@ -259,37 +260,36 @@ func sendInvite( return jsonerror.InternalServerError(), err } - err = roomserverAPI.SendInvite( - ctx, rsAPI, - event, - nil, // ask the roomserver to draw up invite room state for us - cfg.Matrix.ServerName, - nil, - ) - switch e := err.(type) { - case *roomserverAPI.PerformError: - return e.JSONResponse(), err - case nil: - return util.JSONResponse{ - Code: http.StatusOK, - JSON: struct{}{}, - }, nil - default: - util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInvite failed") + var inviteRes api.PerformInviteResponse + if err := rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{ + Event: event, + InviteRoomState: nil, // ask the roomserver to draw up invite room state for us + RoomVersion: event.RoomVersion, + SendAsServer: string(cfg.Matrix.ServerName), + }, &inviteRes); err != nil { + util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError(), }, err } + if inviteRes.Error != nil { + return inviteRes.Error.JSONResponse(), inviteRes.Error + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + }, nil } func buildMembershipEvent( ctx context.Context, - targetUserID, reason string, profileAPI userapi.UserProfileAPI, + targetUserID, reason string, profileAPI userapi.ClientUserAPI, device *userapi.Device, membership, roomID string, isDirect bool, cfg *config.ClientAPI, evTime time.Time, - rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, + rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, ) (*gomatrixserverlib.HeaderedEvent, error) { profile, err := loadProfile(ctx, targetUserID, cfg, profileAPI, asAPI) if err != nil { @@ -326,8 +326,8 @@ func loadProfile( ctx context.Context, userID string, cfg *config.ClientAPI, - profileAPI userapi.UserProfileAPI, - asAPI appserviceAPI.AppServiceQueryAPI, + profileAPI userapi.ClientUserAPI, + asAPI appserviceAPI.AppServiceInternalAPI, ) (*authtypes.Profile, error) { _, serverName, err := gomatrixserverlib.SplitID('@', userID) if err != nil { @@ -344,7 +344,7 @@ func loadProfile( return profile, err } -func extractRequestData(req *http.Request, roomID string, rsAPI roomserverAPI.RoomserverInternalAPI) ( +func extractRequestData(req *http.Request, roomID string, rsAPI roomserverAPI.ClientRoomserverAPI) ( body *threepid.MembershipRequest, evTime time.Time, roomVer gomatrixserverlib.RoomVersion, resErr *util.JSONResponse, ) { verReq := roomserverAPI.QueryRoomVersionForRoomRequest{RoomID: roomID} @@ -379,8 +379,8 @@ func checkAndProcessThreepid( device *userapi.Device, body *threepid.MembershipRequest, cfg *config.ClientAPI, - rsAPI roomserverAPI.RoomserverInternalAPI, - profileAPI userapi.UserProfileAPI, + rsAPI roomserverAPI.ClientRoomserverAPI, + profileAPI userapi.ClientUserAPI, roomID string, evTime time.Time, ) (inviteStored bool, errRes *util.JSONResponse) { @@ -418,7 +418,7 @@ func checkAndProcessThreepid( return } -func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.RoomserverInternalAPI, userID, roomID string) *util.JSONResponse { +func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, userID, roomID string) *util.JSONResponse { tuple := gomatrixserverlib.StateKeyTuple{ EventType: gomatrixserverlib.MRoomMember, StateKey: userID, @@ -457,7 +457,7 @@ func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.RoomserverIntern func SendForget( req *http.Request, device *userapi.Device, - roomID string, rsAPI roomserverAPI.RoomserverInternalAPI, + roomID string, rsAPI roomserverAPI.ClientRoomserverAPI, ) util.JSONResponse { ctx := req.Context() logger := util.GetLogger(ctx).WithField("roomID", roomID).WithField("userID", device.UserID) diff --git a/clientapi/routing/memberships.go b/clientapi/routing/memberships.go index 6ddcf1be3..9bdd8a4f4 100644 --- a/clientapi/routing/memberships.go +++ b/clientapi/routing/memberships.go @@ -55,7 +55,7 @@ type databaseJoinedMember struct { func GetMemberships( req *http.Request, device *userapi.Device, roomID string, joinedOnly bool, _ *config.ClientAPI, - rsAPI api.RoomserverInternalAPI, + rsAPI api.ClientRoomserverAPI, ) util.JSONResponse { queryReq := api.QueryMembershipsForRoomRequest{ JoinedOnly: joinedOnly, @@ -100,7 +100,7 @@ func GetMemberships( func GetJoinedRooms( req *http.Request, device *userapi.Device, - rsAPI api.RoomserverInternalAPI, + rsAPI api.ClientRoomserverAPI, ) util.JSONResponse { var res api.QueryRoomsForUserResponse err := rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ diff --git a/clientapi/routing/notification.go b/clientapi/routing/notification.go index ee715d323..8a424a141 100644 --- a/clientapi/routing/notification.go +++ b/clientapi/routing/notification.go @@ -27,7 +27,7 @@ import ( // GetNotifications handles /_matrix/client/r0/notifications func GetNotifications( req *http.Request, device *userapi.Device, - userAPI userapi.UserInternalAPI, + userAPI userapi.ClientUserAPI, ) util.JSONResponse { var limit int64 if limitStr := req.URL.Query().Get("limit"); limitStr != "" { diff --git a/clientapi/routing/openid.go b/clientapi/routing/openid.go index 13656e288..cfb440bea 100644 --- a/clientapi/routing/openid.go +++ b/clientapi/routing/openid.go @@ -34,7 +34,7 @@ type openIDTokenResponse struct { // can supply to an OpenID Relying Party to verify their identity func CreateOpenIDToken( req *http.Request, - userAPI api.UserInternalAPI, + userAPI api.ClientUserAPI, device *api.Device, userID string, cfg *config.ClientAPI, diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index 08ce1ffa1..6dc9af508 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -28,7 +28,7 @@ type newPasswordAuth struct { func Password( req *http.Request, - userAPI api.UserInternalAPI, + userAPI api.ClientUserAPI, device *api.Device, cfg *config.ClientAPI, ) util.JSONResponse { diff --git a/clientapi/routing/peekroom.go b/clientapi/routing/peekroom.go index 41d1ff004..d0eeccf17 100644 --- a/clientapi/routing/peekroom.go +++ b/clientapi/routing/peekroom.go @@ -26,7 +26,7 @@ import ( func PeekRoomByIDOrAlias( req *http.Request, device *api.Device, - rsAPI roomserverAPI.RoomserverInternalAPI, + rsAPI roomserverAPI.ClientRoomserverAPI, roomIDOrAlias string, ) util.JSONResponse { // if this is a remote roomIDOrAlias, we have to ask the roomserver (or federation sender?) to @@ -79,7 +79,7 @@ func PeekRoomByIDOrAlias( func UnpeekRoomByID( req *http.Request, device *api.Device, - rsAPI roomserverAPI.RoomserverInternalAPI, + rsAPI roomserverAPI.ClientRoomserverAPI, roomID string, ) util.JSONResponse { unpeekReq := roomserverAPI.PerformUnpeekRequest{ diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 3f91b4c93..0685c7352 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -35,9 +35,9 @@ import ( // GetProfile implements GET /profile/{userID} func GetProfile( - req *http.Request, profileAPI userapi.UserProfileAPI, cfg *config.ClientAPI, + req *http.Request, profileAPI userapi.ClientUserAPI, cfg *config.ClientAPI, userID string, - asAPI appserviceAPI.AppServiceQueryAPI, + asAPI appserviceAPI.AppServiceInternalAPI, federation *gomatrixserverlib.FederationClient, ) util.JSONResponse { profile, err := getProfile(req.Context(), profileAPI, cfg, userID, asAPI, federation) @@ -64,8 +64,8 @@ func GetProfile( // GetAvatarURL implements GET /profile/{userID}/avatar_url func GetAvatarURL( - req *http.Request, profileAPI userapi.UserProfileAPI, cfg *config.ClientAPI, - userID string, asAPI appserviceAPI.AppServiceQueryAPI, + req *http.Request, profileAPI userapi.ClientUserAPI, cfg *config.ClientAPI, + userID string, asAPI appserviceAPI.AppServiceInternalAPI, federation *gomatrixserverlib.FederationClient, ) util.JSONResponse { profile, err := getProfile(req.Context(), profileAPI, cfg, userID, asAPI, federation) @@ -91,8 +91,8 @@ func GetAvatarURL( // SetAvatarURL implements PUT /profile/{userID}/avatar_url func SetAvatarURL( - req *http.Request, profileAPI userapi.UserProfileAPI, - device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, + req *http.Request, profileAPI userapi.ClientUserAPI, + device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.ClientRoomserverAPI, ) util.JSONResponse { if userID != device.UserID { return util.JSONResponse{ @@ -193,8 +193,8 @@ func SetAvatarURL( // GetDisplayName implements GET /profile/{userID}/displayname func GetDisplayName( - req *http.Request, profileAPI userapi.UserProfileAPI, cfg *config.ClientAPI, - userID string, asAPI appserviceAPI.AppServiceQueryAPI, + req *http.Request, profileAPI userapi.ClientUserAPI, cfg *config.ClientAPI, + userID string, asAPI appserviceAPI.AppServiceInternalAPI, federation *gomatrixserverlib.FederationClient, ) util.JSONResponse { profile, err := getProfile(req.Context(), profileAPI, cfg, userID, asAPI, federation) @@ -220,8 +220,8 @@ func GetDisplayName( // SetDisplayName implements PUT /profile/{userID}/displayname func SetDisplayName( - req *http.Request, profileAPI userapi.UserProfileAPI, - device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, + req *http.Request, profileAPI userapi.ClientUserAPI, + device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.ClientRoomserverAPI, ) util.JSONResponse { if userID != device.UserID { return util.JSONResponse{ @@ -325,9 +325,9 @@ func SetDisplayName( // Returns an error when something goes wrong or specifically // eventutil.ErrProfileNoExists when the profile doesn't exist. func getProfile( - ctx context.Context, profileAPI userapi.UserProfileAPI, cfg *config.ClientAPI, + ctx context.Context, profileAPI userapi.ClientUserAPI, cfg *config.ClientAPI, userID string, - asAPI appserviceAPI.AppServiceQueryAPI, + asAPI appserviceAPI.AppServiceInternalAPI, federation *gomatrixserverlib.FederationClient, ) (*authtypes.Profile, error) { localpart, domain, err := gomatrixserverlib.SplitID('@', userID) @@ -366,7 +366,7 @@ func buildMembershipEvents( ctx context.Context, roomIDs []string, newProfile authtypes.Profile, userID string, cfg *config.ClientAPI, - evTime time.Time, rsAPI api.RoomserverInternalAPI, + evTime time.Time, rsAPI api.ClientRoomserverAPI, ) ([]*gomatrixserverlib.HeaderedEvent, error) { evs := []*gomatrixserverlib.HeaderedEvent{} diff --git a/clientapi/routing/pusher.go b/clientapi/routing/pusher.go index 9d6bef8bd..d6a6eb936 100644 --- a/clientapi/routing/pusher.go +++ b/clientapi/routing/pusher.go @@ -28,7 +28,7 @@ import ( // GetPushers handles /_matrix/client/r0/pushers func GetPushers( req *http.Request, device *userapi.Device, - userAPI userapi.UserInternalAPI, + userAPI userapi.ClientUserAPI, ) util.JSONResponse { var queryRes userapi.QueryPushersResponse localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) @@ -57,7 +57,7 @@ func GetPushers( // The behaviour of this endpoint varies depending on the values in the JSON body. func SetPusher( req *http.Request, device *userapi.Device, - userAPI userapi.UserInternalAPI, + userAPI userapi.ClientUserAPI, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { diff --git a/clientapi/routing/pushrules.go b/clientapi/routing/pushrules.go index 81a33b25a..856f52c75 100644 --- a/clientapi/routing/pushrules.go +++ b/clientapi/routing/pushrules.go @@ -30,7 +30,7 @@ func errorResponse(ctx context.Context, err error, msg string, args ...interface return jsonerror.InternalServerError() } -func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { +func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) if err != nil { return errorResponse(ctx, err, "queryPushRulesJSON failed") @@ -41,7 +41,7 @@ func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userap } } -func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { +func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) if err != nil { return errorResponse(ctx, err, "queryPushRulesJSON failed") @@ -56,7 +56,7 @@ func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Devi } } -func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { +func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) if err != nil { return errorResponse(ctx, err, "queryPushRules failed") @@ -75,7 +75,7 @@ func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi } } -func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { +func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) if err != nil { return errorResponse(ctx, err, "queryPushRules failed") @@ -98,7 +98,7 @@ func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device } } -func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, beforeRuleID string, body io.Reader, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { +func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, beforeRuleID string, body io.Reader, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { var newRule pushrules.Rule if err := json.NewDecoder(body).Decode(&newRule); err != nil { return errorResponse(ctx, err, "JSON Decode failed") @@ -160,7 +160,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}} } -func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { +func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) if err != nil { return errorResponse(ctx, err, "queryPushRules failed") @@ -187,7 +187,7 @@ func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, dev return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}} } -func GetPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { +func GetPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { attrGet, err := pushRuleAttrGetter(attr) if err != nil { return errorResponse(ctx, err, "pushRuleAttrGetter failed") @@ -216,7 +216,7 @@ func GetPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri } } -func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr string, body io.Reader, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse { +func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr string, body io.Reader, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { var newPartialRule pushrules.Rule if err := json.NewDecoder(body).Decode(&newPartialRule); err != nil { return util.JSONResponse{ @@ -266,7 +266,7 @@ func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}} } -func queryPushRules(ctx context.Context, userID string, userAPI userapi.UserInternalAPI) (*pushrules.AccountRuleSets, error) { +func queryPushRules(ctx context.Context, userID string, userAPI userapi.ClientUserAPI) (*pushrules.AccountRuleSets, error) { var res userapi.QueryPushRulesResponse if err := userAPI.QueryPushRules(ctx, &userapi.QueryPushRulesRequest{UserID: userID}, &res); err != nil { util.GetLogger(ctx).WithError(err).Error("userAPI.QueryPushRules failed") @@ -275,7 +275,7 @@ func queryPushRules(ctx context.Context, userID string, userAPI userapi.UserInte return res.RuleSets, nil } -func putPushRules(ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets, userAPI userapi.UserInternalAPI) error { +func putPushRules(ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets, userAPI userapi.ClientUserAPI) error { req := userapi.PerformPushRulesPutRequest{ UserID: userID, RuleSets: ruleSets, diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index e8d14ce34..27f0ba5d0 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -40,7 +40,7 @@ type redactionResponse struct { func SendRedaction( req *http.Request, device *userapi.Device, roomID, eventID string, cfg *config.ClientAPI, - rsAPI roomserverAPI.RoomserverInternalAPI, + rsAPI roomserverAPI.ClientRoomserverAPI, txnID *string, txnCache *transactions.Cache, ) util.JSONResponse { diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 8253f3155..eba4920c6 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -518,7 +518,7 @@ func validateApplicationService( // http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register func Register( req *http.Request, - userAPI userapi.UserRegisterAPI, + userAPI userapi.ClientUserAPI, cfg *config.ClientAPI, ) util.JSONResponse { defer req.Body.Close() // nolint: errcheck @@ -614,7 +614,7 @@ func handleGuestRegistration( req *http.Request, r registerRequest, cfg *config.ClientAPI, - userAPI userapi.UserRegisterAPI, + userAPI userapi.ClientUserAPI, ) util.JSONResponse { if cfg.RegistrationDisabled || cfg.GuestsDisabled { return util.JSONResponse{ @@ -679,7 +679,7 @@ func handleRegistrationFlow( r registerRequest, sessionID string, cfg *config.ClientAPI, - userAPI userapi.UserRegisterAPI, + userAPI userapi.ClientUserAPI, accessToken string, accessTokenErr error, ) util.JSONResponse { @@ -768,7 +768,7 @@ func handleApplicationServiceRegistration( req *http.Request, r registerRequest, cfg *config.ClientAPI, - userAPI userapi.UserRegisterAPI, + userAPI userapi.ClientUserAPI, ) util.JSONResponse { // Check if we previously had issues extracting the access token from the // request. @@ -806,7 +806,7 @@ func checkAndCompleteFlow( r registerRequest, sessionID string, cfg *config.ClientAPI, - userAPI userapi.UserRegisterAPI, + userAPI userapi.ClientUserAPI, ) util.JSONResponse { if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { // This flow was completed, registration can continue @@ -833,7 +833,7 @@ func checkAndCompleteFlow( // not all func completeRegistration( ctx context.Context, - userAPI userapi.UserRegisterAPI, + userAPI userapi.ClientUserAPI, username, password, appserviceID, ipAddr, userAgent, sessionID string, inhibitLogin eventutil.WeakBoolean, displayName, deviceID *string, @@ -992,7 +992,7 @@ type availableResponse struct { func RegisterAvailable( req *http.Request, cfg *config.ClientAPI, - registerAPI userapi.UserRegisterAPI, + registerAPI userapi.ClientUserAPI, ) util.JSONResponse { username := req.URL.Query().Get("username") @@ -1040,7 +1040,7 @@ func RegisterAvailable( } } -func handleSharedSecretRegistration(userAPI userapi.UserInternalAPI, sr *SharedSecretRegistration, req *http.Request) util.JSONResponse { +func handleSharedSecretRegistration(userAPI userapi.ClientUserAPI, sr *SharedSecretRegistration, req *http.Request) util.JSONResponse { ssrr, err := NewSharedSecretRegistrationRequest(req.Body) if err != nil { return util.JSONResponse{ diff --git a/clientapi/routing/room_tagging.go b/clientapi/routing/room_tagging.go index ce173613e..039289569 100644 --- a/clientapi/routing/room_tagging.go +++ b/clientapi/routing/room_tagging.go @@ -31,7 +31,7 @@ import ( // GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags func GetTags( req *http.Request, - userAPI api.UserInternalAPI, + userAPI api.ClientUserAPI, device *api.Device, userID string, roomID string, @@ -62,7 +62,7 @@ func GetTags( // the tag to the "map" and saving the new "map" to the DB func PutTag( req *http.Request, - userAPI api.UserInternalAPI, + userAPI api.ClientUserAPI, device *api.Device, userID string, roomID string, @@ -113,7 +113,7 @@ func PutTag( // the "map" and then saving the new "map" in the DB func DeleteTag( req *http.Request, - userAPI api.UserInternalAPI, + userAPI api.ClientUserAPI, device *api.Device, userID string, roomID string, @@ -167,7 +167,7 @@ func obtainSavedTags( req *http.Request, userID string, roomID string, - userAPI api.UserInternalAPI, + userAPI api.ClientUserAPI, ) (tags gomatrix.TagContent, err error) { dataReq := api.QueryAccountDataRequest{ UserID: userID, @@ -194,7 +194,7 @@ func saveTagData( req *http.Request, userID string, roomID string, - userAPI api.UserInternalAPI, + userAPI api.ClientUserAPI, Tag gomatrix.TagContent, ) error { newTagData, err := json.Marshal(Tag) diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index ba1c76b81..94becf465 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -50,15 +50,15 @@ import ( func Setup( publicAPIMux, synapseAdminRouter, dendriteAdminRouter *mux.Router, cfg *config.ClientAPI, - rsAPI roomserverAPI.RoomserverInternalAPI, - asAPI appserviceAPI.AppServiceQueryAPI, - userAPI userapi.UserInternalAPI, - userDirectoryProvider userapi.UserDirectoryProvider, + rsAPI roomserverAPI.ClientRoomserverAPI, + asAPI appserviceAPI.AppServiceInternalAPI, + userAPI userapi.ClientUserAPI, + userDirectoryProvider userapi.QuerySearchProfilesAPI, federation *gomatrixserverlib.FederationClient, syncProducer *producers.SyncAPIProducer, transactionsCache *transactions.Cache, - federationSender federationAPI.FederationInternalAPI, - keyAPI keyserverAPI.KeyInternalAPI, + federationSender federationAPI.ClientFederationAPI, + keyAPI keyserverAPI.ClientKeyAPI, extRoomsProvider api.ExtraPublicRoomsProvider, mscCfg *config.MSCs, natsClient *nats.Conn, ) { @@ -325,7 +325,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return GetEvent(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, federation) + return GetEvent(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -897,7 +897,7 @@ func Setup( if resErr := clientutil.UnmarshalJSONRequest(req, &postContent); resErr != nil { return *resErr } - return *SearchUserDirectory( + return SearchUserDirectory( req.Context(), device, userAPI, diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 1211fa72d..5f84739d0 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -70,7 +70,7 @@ func SendEvent( device *userapi.Device, roomID, eventType string, txnID, stateKey *string, cfg *config.ClientAPI, - rsAPI api.RoomserverInternalAPI, + rsAPI api.ClientRoomserverAPI, txnCache *transactions.Cache, ) util.JSONResponse { verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} @@ -207,7 +207,7 @@ func generateSendEvent( device *userapi.Device, roomID, eventType string, stateKey *string, cfg *config.ClientAPI, - rsAPI api.RoomserverInternalAPI, + rsAPI api.ClientRoomserverAPI, evTime time.Time, ) (*gomatrixserverlib.Event, *util.JSONResponse) { // parse the incoming http request diff --git a/clientapi/routing/sendtyping.go b/clientapi/routing/sendtyping.go index 6a27ee615..3f92e4227 100644 --- a/clientapi/routing/sendtyping.go +++ b/clientapi/routing/sendtyping.go @@ -32,7 +32,7 @@ type typingContentJSON struct { // sends the typing events to client API typingProducer func SendTyping( req *http.Request, device *userapi.Device, roomID string, - userID string, rsAPI roomserverAPI.RoomserverInternalAPI, + userID string, rsAPI roomserverAPI.ClientRoomserverAPI, syncProducer *producers.SyncAPIProducer, ) util.JSONResponse { if device.UserID != userID { diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index eec3d7e38..9edeed2f7 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -21,6 +21,7 @@ import ( "net/http" "time" + "github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/tokens" @@ -55,9 +56,9 @@ func SendServerNotice( req *http.Request, cfgNotices *config.ServerNotices, cfgClient *config.ClientAPI, - userAPI userapi.UserInternalAPI, - rsAPI api.RoomserverInternalAPI, - asAPI appserviceAPI.AppServiceQueryAPI, + userAPI userapi.ClientUserAPI, + rsAPI api.ClientRoomserverAPI, + asAPI appserviceAPI.AppServiceInternalAPI, device *userapi.Device, senderDevice *userapi.Device, txnID *string, @@ -95,29 +96,16 @@ func SendServerNotice( // get rooms for specified user allUserRooms := []string{} userRooms := api.QueryRoomsForUserResponse{} - if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ - UserID: r.UserID, - WantMembership: "join", - }, &userRooms); err != nil { - return util.ErrorResponse(err) + // Get rooms the user is either joined, invited or has left. + for _, membership := range []string{"join", "invite", "leave"} { + if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ + UserID: r.UserID, + WantMembership: membership, + }, &userRooms); err != nil { + return util.ErrorResponse(err) + } + allUserRooms = append(allUserRooms, userRooms.RoomIDs...) } - allUserRooms = append(allUserRooms, userRooms.RoomIDs...) - // get invites for specified user - if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ - UserID: r.UserID, - WantMembership: "invite", - }, &userRooms); err != nil { - return util.ErrorResponse(err) - } - allUserRooms = append(allUserRooms, userRooms.RoomIDs...) - // get left rooms for specified user - if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ - UserID: r.UserID, - WantMembership: "leave", - }, &userRooms); err != nil { - return util.ErrorResponse(err) - } - allUserRooms = append(allUserRooms, userRooms.RoomIDs...) // get rooms of the sender senderUserID := fmt.Sprintf("@%s:%s", cfgNotices.LocalPart, cfgClient.Matrix.ServerName) @@ -145,7 +133,7 @@ func SendServerNotice( var ( roomID string - roomVersion = gomatrixserverlib.RoomVersionV6 + roomVersion = version.DefaultRoomVersion() ) // create a new room for the user @@ -194,14 +182,21 @@ func SendServerNotice( // if we didn't get a createRoomResponse, we probably received an error, so return that. return roomRes } - } else { // we've found a room in common, check the membership roomID = commonRooms[0] - // re-invite the user - res, err := sendInvite(ctx, userAPI, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now()) + membershipRes := api.QueryMembershipForUserResponse{} + err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: r.UserID, RoomID: roomID}, &membershipRes) if err != nil { - return res + util.GetLogger(ctx).WithError(err).Error("unable to query membership for user") + return jsonerror.InternalServerError() + } + if !membershipRes.IsInRoom { + // re-invite the user + res, err := sendInvite(ctx, userAPI, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now()) + if err != nil { + return res + } } } @@ -281,7 +276,7 @@ func (r sendServerNoticeRequest) valid() (ok bool) { // It returns an userapi.Device, which is used for building the event func getSenderDevice( ctx context.Context, - userAPI userapi.UserInternalAPI, + userAPI userapi.ClientUserAPI, cfg *config.ClientAPI, ) (*userapi.Device, error) { var accRes userapi.PerformAccountCreationResponse diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index d25ee8237..c6e9e91d0 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -41,7 +41,7 @@ type stateEventInStateResp struct { // TODO: Check if the user is in the room. If not, check if the room's history // is publicly visible. Current behaviour is returning an empty array if the // user cannot see the room's history. -func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI api.RoomserverInternalAPI, roomID string) util.JSONResponse { +func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI api.ClientRoomserverAPI, roomID string) util.JSONResponse { var worldReadable bool var wantLatestState bool @@ -162,7 +162,7 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a // is then (by default) we return the content, otherwise a 404. // If eventFormat=true, sends the whole event else just the content. func OnIncomingStateTypeRequest( - ctx context.Context, device *userapi.Device, rsAPI api.RoomserverInternalAPI, + ctx context.Context, device *userapi.Device, rsAPI api.ClientRoomserverAPI, roomID, evType, stateKey string, eventFormat bool, ) util.JSONResponse { var worldReadable bool diff --git a/clientapi/routing/threepid.go b/clientapi/routing/threepid.go index a4898ca46..94b658ee3 100644 --- a/clientapi/routing/threepid.go +++ b/clientapi/routing/threepid.go @@ -40,7 +40,7 @@ type threePIDsResponse struct { // RequestEmailToken implements: // POST /account/3pid/email/requestToken // POST /register/email/requestToken -func RequestEmailToken(req *http.Request, threePIDAPI api.UserThreePIDAPI, cfg *config.ClientAPI) util.JSONResponse { +func RequestEmailToken(req *http.Request, threePIDAPI api.ClientUserAPI, cfg *config.ClientAPI) util.JSONResponse { var body threepid.EmailAssociationRequest if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { return *reqErr @@ -90,7 +90,7 @@ func RequestEmailToken(req *http.Request, threePIDAPI api.UserThreePIDAPI, cfg * // CheckAndSave3PIDAssociation implements POST /account/3pid func CheckAndSave3PIDAssociation( - req *http.Request, threePIDAPI api.UserThreePIDAPI, device *api.Device, + req *http.Request, threePIDAPI api.ClientUserAPI, device *api.Device, cfg *config.ClientAPI, ) util.JSONResponse { var body threepid.EmailAssociationCheckRequest @@ -158,7 +158,7 @@ func CheckAndSave3PIDAssociation( // GetAssociated3PIDs implements GET /account/3pid func GetAssociated3PIDs( - req *http.Request, threepidAPI api.UserThreePIDAPI, device *api.Device, + req *http.Request, threepidAPI api.ClientUserAPI, device *api.Device, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { @@ -182,7 +182,7 @@ func GetAssociated3PIDs( } // Forget3PID implements POST /account/3pid/delete -func Forget3PID(req *http.Request, threepidAPI api.UserThreePIDAPI) util.JSONResponse { +func Forget3PID(req *http.Request, threepidAPI api.ClientUserAPI) util.JSONResponse { var body authtypes.ThreePID if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { return *reqErr diff --git a/clientapi/routing/upgrade_room.go b/clientapi/routing/upgrade_room.go index 00bde36b3..744e2d889 100644 --- a/clientapi/routing/upgrade_room.go +++ b/clientapi/routing/upgrade_room.go @@ -40,9 +40,9 @@ type upgradeRoomResponse struct { func UpgradeRoom( req *http.Request, device *userapi.Device, cfg *config.ClientAPI, - roomID string, profileAPI userapi.UserProfileAPI, - rsAPI roomserverAPI.RoomserverInternalAPI, - asAPI appserviceAPI.AppServiceQueryAPI, + roomID string, profileAPI userapi.ClientUserAPI, + rsAPI roomserverAPI.ClientRoomserverAPI, + asAPI appserviceAPI.AppServiceInternalAPI, ) util.JSONResponse { var r upgradeRoomRequest if rErr := httputil.UnmarshalJSONRequest(req, &r); rErr != nil { diff --git a/clientapi/routing/userdirectory.go b/clientapi/routing/userdirectory.go index ab73cf430..f311457a0 100644 --- a/clientapi/routing/userdirectory.go +++ b/clientapi/routing/userdirectory.go @@ -34,13 +34,13 @@ type UserDirectoryResponse struct { func SearchUserDirectory( ctx context.Context, device *userapi.Device, - userAPI userapi.UserInternalAPI, - rsAPI api.RoomserverInternalAPI, - provider userapi.UserDirectoryProvider, + userAPI userapi.ClientUserAPI, + rsAPI api.ClientRoomserverAPI, + provider userapi.QuerySearchProfilesAPI, serverName gomatrixserverlib.ServerName, searchString string, limit int, -) *util.JSONResponse { +) util.JSONResponse { if limit < 10 { limit = 10 } @@ -58,8 +58,7 @@ func SearchUserDirectory( } userRes := &userapi.QuerySearchProfilesResponse{} if err := provider.QuerySearchProfiles(ctx, userReq, userRes); err != nil { - errRes := util.ErrorResponse(fmt.Errorf("userAPI.QuerySearchProfiles: %w", err)) - return &errRes + return util.ErrorResponse(fmt.Errorf("userAPI.QuerySearchProfiles: %w", err)) } for _, user := range userRes.Profiles { @@ -94,8 +93,7 @@ func SearchUserDirectory( } stateRes := &api.QueryKnownUsersResponse{} if err := rsAPI.QueryKnownUsers(ctx, stateReq, stateRes); err != nil && err != sql.ErrNoRows { - errRes := util.ErrorResponse(fmt.Errorf("rsAPI.QueryKnownUsers: %w", err)) - return &errRes + return util.ErrorResponse(fmt.Errorf("rsAPI.QueryKnownUsers: %w", err)) } for _, user := range stateRes.Users { @@ -114,7 +112,7 @@ func SearchUserDirectory( response.Results = append(response.Results, result) } - return &util.JSONResponse{ + return util.JSONResponse{ Code: 200, JSON: response, } diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index 6b750199b..6e7426a7f 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -86,7 +86,7 @@ var ( func CheckAndProcessInvite( ctx context.Context, device *userapi.Device, body *MembershipRequest, cfg *config.ClientAPI, - rsAPI api.RoomserverInternalAPI, db userapi.UserProfileAPI, + rsAPI api.ClientRoomserverAPI, db userapi.ClientUserAPI, roomID string, evTime time.Time, ) (inviteStoredOnIDServer bool, err error) { @@ -136,7 +136,7 @@ func CheckAndProcessInvite( // Returns an error if a check or a request failed. func queryIDServer( ctx context.Context, - db userapi.UserProfileAPI, cfg *config.ClientAPI, device *userapi.Device, + userAPI userapi.ClientUserAPI, cfg *config.ClientAPI, device *userapi.Device, body *MembershipRequest, roomID string, ) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) { if err = isTrusted(body.IDServer, cfg); err != nil { @@ -152,7 +152,7 @@ func queryIDServer( if lookupRes.MXID == "" { // No Matrix ID matches with the given 3PID, ask the server to store the // invite and return a token - storeInviteRes, err = queryIDServerStoreInvite(ctx, db, cfg, device, body, roomID) + storeInviteRes, err = queryIDServerStoreInvite(ctx, userAPI, cfg, device, body, roomID) return } @@ -163,7 +163,7 @@ func queryIDServer( if lookupRes.NotBefore > now || now > lookupRes.NotAfter { // If the current timestamp isn't in the time frame in which the association // is known to be valid, re-run the query - return queryIDServer(ctx, db, cfg, device, body, roomID) + return queryIDServer(ctx, userAPI, cfg, device, body, roomID) } // Check the request signatures and send an error if one isn't valid @@ -205,7 +205,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe // Returns an error if the request failed to send or if the response couldn't be parsed. func queryIDServerStoreInvite( ctx context.Context, - db userapi.UserProfileAPI, cfg *config.ClientAPI, device *userapi.Device, + userAPI userapi.ClientUserAPI, cfg *config.ClientAPI, device *userapi.Device, body *MembershipRequest, roomID string, ) (*idServerStoreInviteResponse, error) { // Retrieve the sender's profile to get their display name @@ -217,7 +217,7 @@ func queryIDServerStoreInvite( var profile *authtypes.Profile if serverName == cfg.Matrix.ServerName { res := &userapi.QueryProfileResponse{} - err = db.QueryProfile(ctx, &userapi.QueryProfileRequest{UserID: device.UserID}, res) + err = userAPI.QueryProfile(ctx, &userapi.QueryProfileRequest{UserID: device.UserID}, res) if err != nil { return nil, err } @@ -231,7 +231,7 @@ func queryIDServerStoreInvite( profile = &authtypes.Profile{} } - client := http.Client{} + client := gomatrixserverlib.NewClient() data := url.Values{} data.Add("medium", body.Medium) @@ -253,7 +253,7 @@ func queryIDServerStoreInvite( } req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - resp, err := client.Do(req.WithContext(ctx)) + resp, err := client.DoHTTPRequest(ctx, req) if err != nil { return nil, err } @@ -337,7 +337,7 @@ func emit3PIDInviteEvent( ctx context.Context, body *MembershipRequest, res *idServerStoreInviteResponse, device *userapi.Device, roomID string, cfg *config.ClientAPI, - rsAPI api.RoomserverInternalAPI, + rsAPI api.ClientRoomserverAPI, evTime time.Time, ) error { builder := &gomatrixserverlib.EventBuilder{ diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index 2719f8680..7a5660522 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -25,8 +25,8 @@ import ( "strings" "github.com/matrix-org/dendrite/setup" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage" "github.com/sirupsen/logrus" "golang.org/x/term" ) @@ -99,8 +99,18 @@ func main() { } } - b := base.NewBaseDendrite(cfg, "Monolith") - accountDB := b.CreateAccountsDB() + accountDB, err := storage.NewUserAPIDatabase( + nil, + &cfg.UserAPI.AccountDatabase, + cfg.Global.ServerName, + cfg.UserAPI.BCryptCost, + cfg.UserAPI.OpenIDTokenLifetimeMS, + 0, // TODO + cfg.Global.ServerNotices.LocalPart, + ) + if err != nil { + logrus.WithError(err).Fatalln("Failed to connect to the database") + } accType := api.AccountTypeUser if *isAdmin { diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 785e7b460..703436051 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -149,7 +149,6 @@ func main() { base := base.NewBaseDendrite(cfg, "Monolith") defer base.Close() // nolint: errcheck - accountDB := base.CreateAccountsDB() federation := conn.CreateFederationClient(base, pQUIC) serverKeyAPI := &signing.YggdrasilKeys{} @@ -162,7 +161,7 @@ func main() { ) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) - userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) + userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(userAPI) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) @@ -174,7 +173,6 @@ func main() { monolith := setup.Monolith{ Config: base.Cfg, - AccountDB: accountDB, Client: conn.CreateClient(base, pQUIC), FedClient: federation, KeyRing: keyRing, @@ -187,16 +185,7 @@ func main() { ExtPublicRoomsProvider: roomProvider, ExtUserDirectoryProvider: userProvider, } - monolith.AddAllPublicRoutes( - base.ProcessContext, - base.PublicClientAPIMux, - base.PublicFederationAPIMux, - base.PublicKeyAPIMux, - base.PublicWellKnownAPIMux, - base.PublicMediaAPIMux, - base.SynapseAdminMux, - base.DendriteAdminMux, - ) + monolith.AddAllPublicRoutes(base) wsUpgrader := websocket.Upgrader{ CheckOrigin: func(_ *http.Request) bool { diff --git a/cmd/dendrite-demo-pinecone/users/users.go b/cmd/dendrite-demo-pinecone/users/users.go index ebfb5cbe3..fc66bf299 100644 --- a/cmd/dendrite-demo-pinecone/users/users.go +++ b/cmd/dendrite-demo-pinecone/users/users.go @@ -37,7 +37,7 @@ import ( type PineconeUserProvider struct { r *pineconeRouter.Router s *pineconeSessions.Sessions - userAPI userapi.UserProfileAPI + userAPI userapi.QuerySearchProfilesAPI fedClient *gomatrixserverlib.FederationClient } @@ -46,7 +46,7 @@ const PublicURL = "/_matrix/p2p/profiles" func NewPineconeUserProvider( r *pineconeRouter.Router, s *pineconeSessions.Sessions, - userAPI userapi.UserProfileAPI, + userAPI userapi.QuerySearchProfilesAPI, fedClient *gomatrixserverlib.FederationClient, ) *PineconeUserProvider { p := &PineconeUserProvider{ diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index f9234319a..619720d6c 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -104,7 +104,6 @@ func main() { base := base.NewBaseDendrite(cfg, "Monolith") defer base.Close() // nolint: errcheck - accountDB := base.CreateAccountsDB() federation := ygg.CreateFederationClient(base) serverKeyAPI := &signing.YggdrasilKeys{} @@ -117,7 +116,7 @@ func main() { ) rsAPI := rsComponent - userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) + userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(userAPI) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) @@ -130,7 +129,6 @@ func main() { monolith := setup.Monolith{ Config: base.Cfg, - AccountDB: accountDB, Client: ygg.CreateClient(base), FedClient: federation, KeyRing: keyRing, @@ -144,16 +142,7 @@ func main() { ygg, fsAPI, federation, ), } - monolith.AddAllPublicRoutes( - base.ProcessContext, - base.PublicClientAPIMux, - base.PublicFederationAPIMux, - base.PublicKeyAPIMux, - base.PublicWellKnownAPIMux, - base.PublicMediaAPIMux, - base.SynapseAdminMux, - base.DendriteAdminMux, - ) + monolith.AddAllPublicRoutes(base) if err := mscs.Enable(base, &monolith); err != nil { logrus.WithError(err).Fatalf("Failed to enable MSCs") } diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index 5fd5c0b5d..845b9e465 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -71,7 +71,6 @@ func main() { base := basepkg.NewBaseDendrite(cfg, "Monolith", options...) defer base.Close() // nolint: errcheck - accountDB := base.CreateAccountsDB() federation := base.CreateFederationClient() rsImpl := roomserver.NewInternalAPI(base) @@ -90,6 +89,7 @@ func main() { fsAPI := federationapi.NewInternalAPI( base, federation, rsAPI, base.Caches, nil, false, ) + fsImplAPI := fsAPI if base.UseHTTPAPIs { federationapi.AddInternalRoutes(base.InternalAPIMux, fsAPI) fsAPI = base.FederationAPIHTTPClient() @@ -104,7 +104,7 @@ func main() { } pgClient := base.PushGatewayHTTPClient() - userImpl := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, pgClient) + userImpl := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, pgClient) userAPI := userImpl if base.UseHTTPAPIs { userapi.AddInternalRoutes(base.InternalAPIMux, userAPI) @@ -135,26 +135,19 @@ func main() { monolith := setup.Monolith{ Config: base.Cfg, - AccountDB: accountDB, Client: base.CreateClient(), FedClient: federation, KeyRing: keyRing, - AppserviceAPI: asAPI, FederationAPI: fsAPI, + AppserviceAPI: asAPI, + // always use the concrete impl here even in -http mode because adding public routes + // must be done on the concrete impl not an HTTP client else fedapi will call itself + FederationAPI: fsImplAPI, RoomserverAPI: rsAPI, UserAPI: userAPI, KeyAPI: keyAPI, } - monolith.AddAllPublicRoutes( - base.ProcessContext, - base.PublicClientAPIMux, - base.PublicFederationAPIMux, - base.PublicKeyAPIMux, - base.PublicWellKnownAPIMux, - base.PublicMediaAPIMux, - base.SynapseAdminMux, - base.DendriteAdminMux, - ) + monolith.AddAllPublicRoutes(base) if len(base.Cfg.MSCs.MSCs) > 0 { if err := mscs.Enable(base, &monolith); err != nil { diff --git a/cmd/dendrite-polylith-multi/main.go b/cmd/dendrite-polylith-multi/main.go index 6226cc328..e4845f649 100644 --- a/cmd/dendrite-polylith-multi/main.go +++ b/cmd/dendrite-polylith-multi/main.go @@ -31,7 +31,7 @@ import ( type entrypoint func(base *base.BaseDendrite, cfg *config.Dendrite) func main() { - cfg := setup.ParseFlags(true) + cfg := setup.ParseFlags(false) component := "" if flag.NFlag() > 0 { @@ -71,8 +71,8 @@ func main() { logrus.Infof("Starting %q component", component) - base := base.NewBaseDendrite(cfg, component) // TODO - defer base.Close() // nolint: errcheck + base := base.NewBaseDendrite(cfg, component, base.PolylithMode) // TODO + defer base.Close() // nolint: errcheck go start(base, cfg) base.WaitForShutdown() diff --git a/cmd/dendrite-polylith-multi/personalities/clientapi.go b/cmd/dendrite-polylith-multi/personalities/clientapi.go index 7ed2075aa..a5d69d07c 100644 --- a/cmd/dendrite-polylith-multi/personalities/clientapi.go +++ b/cmd/dendrite-polylith-multi/personalities/clientapi.go @@ -31,11 +31,9 @@ func ClientAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { keyAPI := base.KeyServerHTTPClient() clientapi.AddPublicRoutes( - base.ProcessContext, base.PublicClientAPIMux, - base.SynapseAdminMux, base.DendriteAdminMux, - &base.Cfg.ClientAPI, federation, rsAPI, asQuery, + base, federation, rsAPI, asQuery, transactions.New(), fsAPI, userAPI, userAPI, - keyAPI, nil, &cfg.MSCs, + keyAPI, nil, ) base.SetupAndServeHTTP( diff --git a/cmd/dendrite-polylith-multi/personalities/federationapi.go b/cmd/dendrite-polylith-multi/personalities/federationapi.go index b82577ce3..6377ce9e3 100644 --- a/cmd/dendrite-polylith-multi/personalities/federationapi.go +++ b/cmd/dendrite-polylith-multi/personalities/federationapi.go @@ -29,10 +29,9 @@ func FederationAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { keyRing := fsAPI.KeyRing() federationapi.AddPublicRoutes( - base.ProcessContext, base.PublicFederationAPIMux, base.PublicKeyAPIMux, base.PublicWellKnownAPIMux, - &base.Cfg.FederationAPI, userAPI, federation, keyRing, - rsAPI, fsAPI, keyAPI, - &base.Cfg.MSCs, nil, + base, + userAPI, federation, keyRing, + rsAPI, fsAPI, keyAPI, nil, ) federationapi.AddInternalRoutes(base.InternalAPIMux, fsAPI) diff --git a/cmd/dendrite-polylith-multi/personalities/mediaapi.go b/cmd/dendrite-polylith-multi/personalities/mediaapi.go index fa9d36a38..69d5fd5a8 100644 --- a/cmd/dendrite-polylith-multi/personalities/mediaapi.go +++ b/cmd/dendrite-polylith-multi/personalities/mediaapi.go @@ -24,7 +24,9 @@ func MediaAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { userAPI := base.UserAPIClient() client := base.CreateClient() - mediaapi.AddPublicRoutes(base.PublicMediaAPIMux, &base.Cfg.MediaAPI, &base.Cfg.ClientAPI.RateLimiting, userAPI, client) + mediaapi.AddPublicRoutes( + base, userAPI, client, + ) base.SetupAndServeHTTP( base.Cfg.MediaAPI.InternalAPI.Listen, diff --git a/cmd/dendrite-polylith-multi/personalities/syncapi.go b/cmd/dendrite-polylith-multi/personalities/syncapi.go index 6fee8419b..41637fe1d 100644 --- a/cmd/dendrite-polylith-multi/personalities/syncapi.go +++ b/cmd/dendrite-polylith-multi/personalities/syncapi.go @@ -22,15 +22,13 @@ import ( func SyncAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { userAPI := base.UserAPIClient() - federation := base.CreateFederationClient() rsAPI := base.RoomserverHTTPClient() syncapi.AddPublicRoutes( - base.ProcessContext, - base.PublicClientAPIMux, userAPI, rsAPI, + base, + userAPI, rsAPI, base.KeyServerHTTPClient(), - federation, &cfg.SyncAPI, ) base.SetupAndServeHTTP( diff --git a/cmd/dendrite-polylith-multi/personalities/userapi.go b/cmd/dendrite-polylith-multi/personalities/userapi.go index f1fa379c7..3fe5a43d7 100644 --- a/cmd/dendrite-polylith-multi/personalities/userapi.go +++ b/cmd/dendrite-polylith-multi/personalities/userapi.go @@ -21,10 +21,8 @@ import ( ) func UserAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { - accountDB := base.CreateAccountsDB() - userAPI := userapi.NewInternalAPI( - base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, + base, &cfg.UserAPI, cfg.Derived.ApplicationServices, base.KeyServerHTTPClient(), base.RoomserverHTTPClient(), base.PushGatewayHTTPClient(), ) diff --git a/cmd/dendritejs-pinecone/main.go b/cmd/dendritejs-pinecone/main.go index 5ecf1f2fb..e070173aa 100644 --- a/cmd/dendritejs-pinecone/main.go +++ b/cmd/dendritejs-pinecone/main.go @@ -180,7 +180,6 @@ func startup() { base := base.NewBaseDendrite(cfg, "Monolith") defer base.Close() // nolint: errcheck - accountDB := base.CreateAccountsDB() federation := conn.CreateFederationClient(base, pSessions) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) @@ -189,7 +188,7 @@ func startup() { rsAPI := roomserver.NewInternalAPI(base) - userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) + userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(userAPI) asQuery := appservice.NewInternalAPI( @@ -201,7 +200,6 @@ func startup() { monolith := setup.Monolith{ Config: base.Cfg, - AccountDB: accountDB, Client: conn.CreateClient(base, pSessions), FedClient: federation, KeyRing: keyRing, @@ -214,16 +212,7 @@ func startup() { //ServerKeyAPI: serverKeyAPI, ExtPublicRoomsProvider: rooms.NewPineconeRoomProvider(pRouter, pSessions, fedSenderAPI, federation), } - monolith.AddAllPublicRoutes( - base.ProcessContext, - base.PublicClientAPIMux, - base.PublicFederationAPIMux, - base.PublicKeyAPIMux, - base.PublicWellKnownAPIMux, - base.PublicMediaAPIMux, - base.SynapseAdminMux, - base.DendriteAdminMux, - ) + monolith.AddAllPublicRoutes(base) httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go index 30331fbb3..c52fd6c42 100644 --- a/cmd/resolve-state/main.go +++ b/cmd/resolve-state/main.go @@ -45,7 +45,7 @@ func main() { panic(err) } - roomserverDB, err := storage.Open(&cfg.RoomServer.Database, cache) + roomserverDB, err := storage.Open(nil, &cfg.RoomServer.Database, cache) if err != nil { panic(err) } diff --git a/dendrite-config.yaml b/dendrite-config.yaml index 1c11ef96d..7709e0c87 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -54,6 +54,16 @@ global: # considered valid by other homeservers. key_validity_period: 168h0m0s + # Global database connection pool, for PostgreSQL monolith deployments only. If + # this section is populated then you can omit the "database" blocks in all other + # sections. For polylith deployments, or monolith deployments using SQLite databases, + # you must configure the "database" block for each component instead. + # database: + # connection_string: postgres://user:pass@hostname/database?sslmode=disable + # max_open_conns: 100 + # max_idle_conns: 5 + # conn_max_lifetime: -1 + # The server name to delegate server-server communications to, with optional port # e.g. localhost:443 well_known_server_name: "" @@ -75,6 +85,15 @@ global: # Whether outbound presence events are allowed, e.g. sending presence events to other servers enable_outbound: false + # Configures opt-in anonymous stats reporting. + report_stats: + # Whether this instance sends anonymous usage stats + enabled: false + + # The endpoint to report the anonymized homeserver usage statistics to. + # Defaults to https://matrix.org/report-usage-stats/push + endpoint: https://matrix.org/report-usage-stats/push + # Server notices allows server admins to send messages to all users. server_notices: enabled: false diff --git a/docs/FAQ.md b/docs/FAQ.md index 978212cce..47eaecf0f 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -74,3 +74,46 @@ If you are running with `GODEBUG=madvdontneed=1` and still see hugely inflated m ### Dendrite is running out of PostgreSQL database connections You may need to revisit the connection limit of your PostgreSQL server and/or make changes to the `max_connections` lines in your Dendrite configuration. Be aware that each Dendrite component opens its own database connections and has its own connection limit, even in monolith mode! + +### What is being reported when enabling anonymous stats? + +If anonymous stats reporting is enabled, the following data is send to the defined endpoint. + +```json +{ + "cpu_average": 0, + "daily_active_users": 97, + "daily_e2ee_messages": 0, + "daily_messages": 0, + "daily_sent_e2ee_messages": 0, + "daily_sent_messages": 0, + "daily_user_type_bridged": 2, + "daily_user_type_native": 97, + "database_engine": "Postgres", + "database_server_version": "11.14 (Debian 11.14-0+deb10u1)", + "federation_disabled": false, + "go_arch": "amd64", + "go_os": "linux", + "go_version": "go1.16.13", + "homeserver": "localhost:8800", + "log_level": "trace", + "memory_rss": 93452, + "monolith": true, + "monthly_active_users": 97, + "nats_embedded": true, + "nats_in_memory": true, + "num_cpu": 8, + "num_go_routine": 203, + "r30v2_users_all": 0, + "r30v2_users_android": 0, + "r30v2_users_electron": 0, + "r30v2_users_ios": 0, + "r30v2_users_web": 0, + "timestamp": 1651741851, + "total_nonbridged_users": 97, + "total_room_count": 0, + "total_users": 99, + "uptime_seconds": 30, + "version": "0.8.2" +} +``` \ No newline at end of file diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 4d6b0211c..fc25194e0 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -10,21 +10,67 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) +// FederationInternalAPI is used to query information from the federation sender. +type FederationInternalAPI interface { + FederationClient + gomatrixserverlib.KeyDatabase + ClientFederationAPI + RoomserverFederationAPI + + QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error + + // Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos. + PerformBroadcastEDU( + ctx context.Context, + request *PerformBroadcastEDURequest, + response *PerformBroadcastEDUResponse, + ) error +} + +type ClientFederationAPI interface { + // Query the server names of the joined hosts in a room. + // Unlike QueryJoinedHostsInRoom, this function returns a de-duplicated slice + // containing only the server names (without information for membership events). + // The response will include this server if they are joined to the room. + QueryJoinedHostServerNamesInRoom(ctx context.Context, request *QueryJoinedHostServerNamesInRoomRequest, response *QueryJoinedHostServerNamesInRoomResponse) error +} + +type RoomserverFederationAPI interface { + gomatrixserverlib.BackfillClient + gomatrixserverlib.FederatedStateClient + KeyRing() *gomatrixserverlib.KeyRing + + // PerformDirectoryLookup looks up a remote room ID from a room alias. + PerformDirectoryLookup(ctx context.Context, request *PerformDirectoryLookupRequest, response *PerformDirectoryLookupResponse) error + // Handle an instruction to make_join & send_join with a remote server. + PerformJoin(ctx context.Context, request *PerformJoinRequest, response *PerformJoinResponse) + // Handle an instruction to make_leave & send_leave with a remote server. + PerformLeave(ctx context.Context, request *PerformLeaveRequest, response *PerformLeaveResponse) error + // Handle sending an invite to a remote server. + PerformInvite(ctx context.Context, request *PerformInviteRequest, response *PerformInviteResponse) error + // Handle an instruction to peek a room on a remote server. + PerformOutboundPeek(ctx context.Context, request *PerformOutboundPeekRequest, response *PerformOutboundPeekResponse) error + // Query the server names of the joined hosts in a room. + // Unlike QueryJoinedHostsInRoom, this function returns a de-duplicated slice + // containing only the server names (without information for membership events). + // The response will include this server if they are joined to the room. + QueryJoinedHostServerNamesInRoom(ctx context.Context, request *QueryJoinedHostServerNamesInRoomRequest, response *QueryJoinedHostServerNamesInRoomResponse) error + GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) + GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) + LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) +} + // FederationClient is a subset of gomatrixserverlib.FederationClient functions which the fedsender // implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in // this interface are of type FederationClientError type FederationClient interface { - gomatrixserverlib.BackfillClient gomatrixserverlib.FederatedStateClient GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error) ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error) QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) - GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) - GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) - LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } // FederationClientError is returned from FederationClient methods in the event of a problem. @@ -38,68 +84,6 @@ func (e *FederationClientError) Error() string { return fmt.Sprintf("%s - (retry_after=%s, blacklisted=%v)", e.Err, e.RetryAfter.String(), e.Blacklisted) } -// FederationInternalAPI is used to query information from the federation sender. -type FederationInternalAPI interface { - FederationClient - gomatrixserverlib.KeyDatabase - - KeyRing() *gomatrixserverlib.KeyRing - - QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error - - // PerformDirectoryLookup looks up a remote room ID from a room alias. - PerformDirectoryLookup( - ctx context.Context, - request *PerformDirectoryLookupRequest, - response *PerformDirectoryLookupResponse, - ) error - // Query the server names of the joined hosts in a room. - // Unlike QueryJoinedHostsInRoom, this function returns a de-duplicated slice - // containing only the server names (without information for membership events). - // The response will include this server if they are joined to the room. - QueryJoinedHostServerNamesInRoom( - ctx context.Context, - request *QueryJoinedHostServerNamesInRoomRequest, - response *QueryJoinedHostServerNamesInRoomResponse, - ) error - // Handle an instruction to make_join & send_join with a remote server. - PerformJoin( - ctx context.Context, - request *PerformJoinRequest, - response *PerformJoinResponse, - ) - // Handle an instruction to peek a room on a remote server. - PerformOutboundPeek( - ctx context.Context, - request *PerformOutboundPeekRequest, - response *PerformOutboundPeekResponse, - ) error - // Handle an instruction to make_leave & send_leave with a remote server. - PerformLeave( - ctx context.Context, - request *PerformLeaveRequest, - response *PerformLeaveResponse, - ) error - // Handle sending an invite to a remote server. - PerformInvite( - ctx context.Context, - request *PerformInviteRequest, - response *PerformInviteResponse, - ) error - // Notifies the federation sender that these servers may be online and to retry sending messages. - PerformServersAlive( - ctx context.Context, - request *PerformServersAliveRequest, - response *PerformServersAliveResponse, - ) error - // Broadcasts an EDU to all servers in rooms we are joined to. - PerformBroadcastEDU( - ctx context.Context, - request *PerformBroadcastEDURequest, - response *PerformBroadcastEDUResponse, - ) error -} - type QueryServerKeysRequest struct { ServerName gomatrixserverlib.ServerName KeyIDToCriteria map[gomatrixserverlib.KeyID]gomatrixserverlib.PublicKeyNotaryQueryCriteria @@ -179,13 +163,6 @@ type PerformInviteResponse struct { Event *gomatrixserverlib.HeaderedEvent `json:"event"` } -type PerformServersAliveRequest struct { - Servers []gomatrixserverlib.ServerName -} - -type PerformServersAliveResponse struct { -} - // QueryJoinedHostServerNamesInRoomRequest is a request to QueryJoinedHostServerNames type QueryJoinedHostServerNamesInRoomRequest struct { RoomID string `json:"room_id"` diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index af027c3c3..24379154e 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -31,9 +31,7 @@ import ( keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/setup/process" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/sirupsen/logrus" @@ -49,20 +47,18 @@ func AddInternalRoutes(router *mux.Router, intAPI api.FederationInternalAPI) { // AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component. func AddPublicRoutes( - process *process.ProcessContext, - fedRouter, keyRouter, wellKnownRouter *mux.Router, - cfg *config.FederationAPI, + base *base.BaseDendrite, userAPI userapi.UserInternalAPI, federation *gomatrixserverlib.FederationClient, keyRing gomatrixserverlib.JSONVerifier, - rsAPI roomserverAPI.RoomserverInternalAPI, - federationAPI federationAPI.FederationInternalAPI, - keyAPI keyserverAPI.KeyInternalAPI, - mscCfg *config.MSCs, + rsAPI roomserverAPI.FederationRoomserverAPI, + fedAPI federationAPI.FederationInternalAPI, + keyAPI keyserverAPI.FederationKeyAPI, servers federationAPI.ServersInRoomProvider, ) { - - js, _ := jetstream.Prepare(process, &cfg.Matrix.JetStream) + cfg := &base.Cfg.FederationAPI + mscCfg := &base.Cfg.MSCs + js, _ := jetstream.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) producer := &producers.SyncAPIProducer{ JetStream: js, TopicReceiptEvent: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent), @@ -73,9 +69,23 @@ func AddPublicRoutes( UserAPI: userAPI, } + // the federationapi component is a bit unique in that it attaches public routes AND serves + // internal APIs (because it used to be 2 components: the 2nd being fedsender). As a result, + // the constructor shape is a bit wonky in that it is not valid to AddPublicRoutes without a + // concrete impl of FederationInternalAPI as the public routes and the internal API _should_ + // be the same thing now. + f, ok := fedAPI.(*internal.FederationInternalAPI) + if !ok { + panic("federationapi.AddPublicRoutes called with a FederationInternalAPI impl which was not " + + "FederationInternalAPI. This is a programming error.") + } + routing.Setup( - fedRouter, keyRouter, wellKnownRouter, cfg, rsAPI, - federationAPI, keyRing, + base.PublicFederationAPIMux, + base.PublicKeyAPIMux, + base.PublicWellKnownAPIMux, + cfg, + rsAPI, f, keyRing, federation, userAPI, keyAPI, mscCfg, servers, producer, ) @@ -93,7 +103,7 @@ func NewInternalAPI( ) api.FederationInternalAPI { cfg := &base.Cfg.FederationAPI - federationDB, err := storage.NewDatabase(&cfg.Database, base.Caches, base.Cfg.Global.ServerName) + federationDB, err := storage.NewDatabase(base, &cfg.Database, base.Caches, base.Cfg.Global.ServerName) if err != nil { logrus.WithError(err).Panic("failed to connect to federation sender db") } diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index 833359c11..eedebc6cd 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/matrix-org/dendrite/federationapi" + "github.com/matrix-org/dendrite/federationapi/internal" "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" @@ -27,10 +28,9 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { cfg.FederationAPI.Database.ConnectionString = config.DataSource("file::memory:") base := base.NewBaseDendrite(cfg, "Monolith") keyRing := &test.NopJSONVerifier{} - fsAPI := base.FederationAPIHTTPClient() // TODO: This is pretty fragile, as if anything calls anything on these nils this test will break. // Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing. - federationapi.AddPublicRoutes(base.ProcessContext, base.PublicFederationAPIMux, base.PublicKeyAPIMux, base.PublicWellKnownAPIMux, &cfg.FederationAPI, nil, nil, keyRing, nil, fsAPI, nil, &cfg.MSCs, nil) + federationapi.AddPublicRoutes(base, nil, nil, keyRing, nil, &internal.FederationInternalAPI{}, nil, nil) baseURL, cancel := test.ListenAndServe(t, base.PublicFederationAPIMux, true) defer cancel() serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://")) diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index aac36cc76..577cb70e0 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -563,20 +563,6 @@ func (r *FederationInternalAPI) PerformInvite( return nil } -// PerformServersAlive implements api.FederationInternalAPI -func (r *FederationInternalAPI) PerformServersAlive( - ctx context.Context, - request *api.PerformServersAliveRequest, - response *api.PerformServersAliveResponse, -) (err error) { - for _, srv := range request.Servers { - _ = r.db.RemoveServerFromBlacklist(srv) - r.queues.RetryServer(srv) - } - - return nil -} - // PerformServersAlive implements api.FederationInternalAPI func (r *FederationInternalAPI) PerformBroadcastEDU( ctx context.Context, @@ -600,18 +586,18 @@ func (r *FederationInternalAPI) PerformBroadcastEDU( if err = r.queues.SendEDU(edu, r.cfg.Matrix.ServerName, destinations); err != nil { return fmt.Errorf("r.queues.SendEDU: %w", err) } - - wakeReq := &api.PerformServersAliveRequest{ - Servers: destinations, - } - wakeRes := &api.PerformServersAliveResponse{} - if err := r.PerformServersAlive(ctx, wakeReq, wakeRes); err != nil { - return fmt.Errorf("r.PerformServersAlive: %w", err) - } + r.MarkServersAlive(destinations) return nil } +func (r *FederationInternalAPI) MarkServersAlive(destinations []gomatrixserverlib.ServerName) { + for _, srv := range destinations { + _ = r.db.RemoveServerFromBlacklist(srv) + r.queues.RetryServer(srv) + } +} + func sanityCheckAuthChain(authChain []*gomatrixserverlib.Event) error { // sanity check we have a create event and it has a known room version for _, ev := range authChain { diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index 01ca6595d..295ddc495 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -23,7 +23,6 @@ const ( FederationAPIPerformLeaveRequestPath = "/federationapi/performLeaveRequest" FederationAPIPerformInviteRequestPath = "/federationapi/performInviteRequest" FederationAPIPerformOutboundPeekRequestPath = "/federationapi/performOutboundPeekRequest" - FederationAPIPerformServersAlivePath = "/federationapi/performServersAlive" FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU" FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices" @@ -97,18 +96,6 @@ func (h *httpFederationInternalAPI) PerformOutboundPeek( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } -func (h *httpFederationInternalAPI) PerformServersAlive( - ctx context.Context, - request *api.PerformServersAliveRequest, - response *api.PerformServersAliveResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformServersAlive") - defer span.Finish() - - apiURL := h.federationAPIURL + FederationAPIPerformServersAlivePath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - // QueryJoinedHostServerNamesInRoom implements FederationInternalAPI func (h *httpFederationInternalAPI) QueryJoinedHostServerNamesInRoom( ctx context.Context, diff --git a/federationapi/inthttp/server.go b/federationapi/inthttp/server.go index ca4930f20..28e52b32d 100644 --- a/federationapi/inthttp/server.go +++ b/federationapi/inthttp/server.go @@ -81,20 +81,6 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - internalAPIMux.Handle( - FederationAPIPerformServersAlivePath, - httputil.MakeInternalAPI("PerformServersAliveRequest", func(req *http.Request) util.JSONResponse { - var request api.PerformServersAliveRequest - var response api.PerformServersAliveResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := intAPI.PerformServersAlive(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) internalAPIMux.Handle( FederationAPIPerformBroadcastEDUPath, httputil.MakeInternalAPI("PerformBroadcastEDU", func(req *http.Request) util.JSONResponse { diff --git a/federationapi/routing/backfill.go b/federationapi/routing/backfill.go index 82f6cbabf..7b9ca66f6 100644 --- a/federationapi/routing/backfill.go +++ b/federationapi/routing/backfill.go @@ -33,7 +33,7 @@ import ( func Backfill( httpReq *http.Request, request *gomatrixserverlib.FederationRequest, - rsAPI api.RoomserverInternalAPI, + rsAPI api.FederationRoomserverAPI, roomID string, cfg *config.FederationAPI, ) util.JSONResponse { diff --git a/federationapi/routing/devices.go b/federationapi/routing/devices.go index 57286fa90..1a092645f 100644 --- a/federationapi/routing/devices.go +++ b/federationapi/routing/devices.go @@ -26,7 +26,7 @@ import ( // GetUserDevices for the given user id func GetUserDevices( req *http.Request, - keyAPI keyapi.KeyInternalAPI, + keyAPI keyapi.FederationKeyAPI, userID string, ) util.JSONResponse { var res keyapi.QueryDeviceMessagesResponse diff --git a/federationapi/routing/eventauth.go b/federationapi/routing/eventauth.go index e83cb8ad2..868785a9b 100644 --- a/federationapi/routing/eventauth.go +++ b/federationapi/routing/eventauth.go @@ -26,7 +26,7 @@ import ( func GetEventAuth( ctx context.Context, request *gomatrixserverlib.FederationRequest, - rsAPI api.RoomserverInternalAPI, + rsAPI api.FederationRoomserverAPI, roomID string, eventID string, ) util.JSONResponse { diff --git a/federationapi/routing/events.go b/federationapi/routing/events.go index 312ef9f8e..23796edfa 100644 --- a/federationapi/routing/events.go +++ b/federationapi/routing/events.go @@ -29,7 +29,7 @@ import ( func GetEvent( ctx context.Context, request *gomatrixserverlib.FederationRequest, - rsAPI api.RoomserverInternalAPI, + rsAPI api.FederationRoomserverAPI, eventID string, origin gomatrixserverlib.ServerName, ) util.JSONResponse { @@ -56,7 +56,7 @@ func GetEvent( func allowedToSeeEvent( ctx context.Context, origin gomatrixserverlib.ServerName, - rsAPI api.RoomserverInternalAPI, + rsAPI api.FederationRoomserverAPI, eventID string, ) *util.JSONResponse { var authResponse api.QueryServerAllowedToSeeEventResponse @@ -82,7 +82,7 @@ func allowedToSeeEvent( } // fetchEvent fetches the event without auth checks. Returns an error if the event cannot be found. -func fetchEvent(ctx context.Context, rsAPI api.RoomserverInternalAPI, eventID string) (*gomatrixserverlib.Event, *util.JSONResponse) { +func fetchEvent(ctx context.Context, rsAPI api.FederationRoomserverAPI, eventID string) (*gomatrixserverlib.Event, *util.JSONResponse) { var eventsResponse api.QueryEventsByIDResponse err := rsAPI.QueryEventsByID( ctx, diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 58bf99f4a..a5797645e 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -35,7 +35,7 @@ func InviteV2( roomID string, eventID string, cfg *config.FederationAPI, - rsAPI api.RoomserverInternalAPI, + rsAPI api.FederationRoomserverAPI, keys gomatrixserverlib.JSONVerifier, ) util.JSONResponse { inviteReq := gomatrixserverlib.InviteV2Request{} @@ -72,7 +72,7 @@ func InviteV1( roomID string, eventID string, cfg *config.FederationAPI, - rsAPI api.RoomserverInternalAPI, + rsAPI api.FederationRoomserverAPI, keys gomatrixserverlib.JSONVerifier, ) util.JSONResponse { roomVer := gomatrixserverlib.RoomVersionV1 @@ -110,7 +110,7 @@ func processInvite( roomID string, eventID string, cfg *config.FederationAPI, - rsAPI api.RoomserverInternalAPI, + rsAPI api.FederationRoomserverAPI, keys gomatrixserverlib.JSONVerifier, ) util.JSONResponse { @@ -166,31 +166,36 @@ func processInvite( ) // Add the invite event to the roomserver. - err = api.SendInvite( - ctx, rsAPI, signedEvent.Headered(roomVer), strippedState, api.DoNotSendToOtherServers, nil, - ) - switch e := err.(type) { - case *api.PerformError: - return e.JSONResponse() - case nil: - // Return the signed event to the originating server, it should then tell - // the other servers in the room that we have been invited. - if isInviteV2 { - return util.JSONResponse{ - Code: http.StatusOK, - JSON: gomatrixserverlib.RespInviteV2{Event: signedEvent.JSON()}, - } - } else { - return util.JSONResponse{ - Code: http.StatusOK, - JSON: gomatrixserverlib.RespInvite{Event: signedEvent.JSON()}, - } - } - default: - util.GetLogger(ctx).WithError(err).Error("api.SendInvite failed") + inviteEvent := signedEvent.Headered(roomVer) + request := &api.PerformInviteRequest{ + Event: inviteEvent, + InviteRoomState: strippedState, + RoomVersion: inviteEvent.RoomVersion, + SendAsServer: string(api.DoNotSendToOtherServers), + TransactionID: nil, + } + response := &api.PerformInviteResponse{} + if err := rsAPI.PerformInvite(ctx, request, response); err != nil { + util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError(), } } + if response.Error != nil { + return response.Error.JSONResponse() + } + // Return the signed event to the originating server, it should then tell + // the other servers in the room that we have been invited. + if isInviteV2 { + return util.JSONResponse{ + Code: http.StatusOK, + JSON: gomatrixserverlib.RespInviteV2{Event: signedEvent.JSON()}, + } + } else { + return util.JSONResponse{ + Code: http.StatusOK, + JSON: gomatrixserverlib.RespInvite{Event: signedEvent.JSON()}, + } + } } diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 495b8c914..767699728 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -34,7 +34,7 @@ func MakeJoin( httpReq *http.Request, request *gomatrixserverlib.FederationRequest, cfg *config.FederationAPI, - rsAPI api.RoomserverInternalAPI, + rsAPI api.FederationRoomserverAPI, roomID, userID string, remoteVersions []gomatrixserverlib.RoomVersion, ) util.JSONResponse { @@ -165,7 +165,7 @@ func SendJoin( httpReq *http.Request, request *gomatrixserverlib.FederationRequest, cfg *config.FederationAPI, - rsAPI api.RoomserverInternalAPI, + rsAPI api.FederationRoomserverAPI, keys gomatrixserverlib.JSONVerifier, roomID, eventID string, ) util.JSONResponse { diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index 49a6c558f..b1a9b6710 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -37,7 +37,7 @@ type queryKeysRequest struct { // QueryDeviceKeys returns device keys for users on this server. // https://matrix.org/docs/spec/server_server/latest#post-matrix-federation-v1-user-keys-query func QueryDeviceKeys( - httpReq *http.Request, request *gomatrixserverlib.FederationRequest, keyAPI api.KeyInternalAPI, thisServer gomatrixserverlib.ServerName, + httpReq *http.Request, request *gomatrixserverlib.FederationRequest, keyAPI api.FederationKeyAPI, thisServer gomatrixserverlib.ServerName, ) util.JSONResponse { var qkr queryKeysRequest err := json.Unmarshal(request.Content(), &qkr) @@ -89,7 +89,7 @@ type claimOTKsRequest struct { // ClaimOneTimeKeys claims OTKs for users on this server. // https://matrix.org/docs/spec/server_server/latest#post-matrix-federation-v1-user-keys-claim func ClaimOneTimeKeys( - httpReq *http.Request, request *gomatrixserverlib.FederationRequest, keyAPI api.KeyInternalAPI, thisServer gomatrixserverlib.ServerName, + httpReq *http.Request, request *gomatrixserverlib.FederationRequest, keyAPI api.FederationKeyAPI, thisServer gomatrixserverlib.ServerName, ) util.JSONResponse { var cor claimOTKsRequest err := json.Unmarshal(request.Content(), &cor) diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index 0b83f04ae..54b2c3e84 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -30,7 +30,7 @@ func MakeLeave( httpReq *http.Request, request *gomatrixserverlib.FederationRequest, cfg *config.FederationAPI, - rsAPI api.RoomserverInternalAPI, + rsAPI api.FederationRoomserverAPI, roomID, userID string, ) util.JSONResponse { _, domain, err := gomatrixserverlib.SplitID('@', userID) @@ -122,7 +122,7 @@ func SendLeave( httpReq *http.Request, request *gomatrixserverlib.FederationRequest, cfg *config.FederationAPI, - rsAPI api.RoomserverInternalAPI, + rsAPI api.FederationRoomserverAPI, keys gomatrixserverlib.JSONVerifier, roomID, eventID string, ) util.JSONResponse { diff --git a/federationapi/routing/missingevents.go b/federationapi/routing/missingevents.go index b826d69c4..531cb9e28 100644 --- a/federationapi/routing/missingevents.go +++ b/federationapi/routing/missingevents.go @@ -34,7 +34,7 @@ type getMissingEventRequest struct { func GetMissingEvents( httpReq *http.Request, request *gomatrixserverlib.FederationRequest, - rsAPI api.RoomserverInternalAPI, + rsAPI api.FederationRoomserverAPI, roomID string, ) util.JSONResponse { var gme getMissingEventRequest diff --git a/federationapi/routing/openid.go b/federationapi/routing/openid.go index 829dbccad..cbc75a9a7 100644 --- a/federationapi/routing/openid.go +++ b/federationapi/routing/openid.go @@ -30,7 +30,7 @@ type openIDUserInfoResponse struct { // GetOpenIDUserInfo implements GET /_matrix/federation/v1/openid/userinfo func GetOpenIDUserInfo( httpReq *http.Request, - userAPI userapi.UserInternalAPI, + userAPI userapi.FederationUserAPI, ) util.JSONResponse { token := httpReq.URL.Query().Get("access_token") if len(token) == 0 { diff --git a/federationapi/routing/peek.go b/federationapi/routing/peek.go index 827d1116d..bc4dac90f 100644 --- a/federationapi/routing/peek.go +++ b/federationapi/routing/peek.go @@ -29,7 +29,7 @@ func Peek( httpReq *http.Request, request *gomatrixserverlib.FederationRequest, cfg *config.FederationAPI, - rsAPI api.RoomserverInternalAPI, + rsAPI api.FederationRoomserverAPI, roomID, peekID string, remoteVersions []gomatrixserverlib.RoomVersion, ) util.JSONResponse { diff --git a/federationapi/routing/profile.go b/federationapi/routing/profile.go index dbc209ce1..f672811af 100644 --- a/federationapi/routing/profile.go +++ b/federationapi/routing/profile.go @@ -29,7 +29,7 @@ import ( // GetProfile implements GET /_matrix/federation/v1/query/profile func GetProfile( httpReq *http.Request, - userAPI userapi.UserInternalAPI, + userAPI userapi.FederationUserAPI, cfg *config.FederationAPI, ) util.JSONResponse { userID, field := httpReq.FormValue("user_id"), httpReq.FormValue("field") diff --git a/federationapi/routing/publicrooms.go b/federationapi/routing/publicrooms.go index a253f86eb..1a54f5a7d 100644 --- a/federationapi/routing/publicrooms.go +++ b/federationapi/routing/publicrooms.go @@ -23,7 +23,7 @@ type filter struct { } // GetPostPublicRooms implements GET and POST /publicRooms -func GetPostPublicRooms(req *http.Request, rsAPI roomserverAPI.RoomserverInternalAPI) util.JSONResponse { +func GetPostPublicRooms(req *http.Request, rsAPI roomserverAPI.FederationRoomserverAPI) util.JSONResponse { var request PublicRoomReq if fillErr := fillPublicRoomsReq(req, &request); fillErr != nil { return *fillErr @@ -42,7 +42,7 @@ func GetPostPublicRooms(req *http.Request, rsAPI roomserverAPI.RoomserverInterna } func publicRooms( - ctx context.Context, request PublicRoomReq, rsAPI roomserverAPI.RoomserverInternalAPI, + ctx context.Context, request PublicRoomReq, rsAPI roomserverAPI.FederationRoomserverAPI, ) (*gomatrixserverlib.RespPublicRooms, error) { var response gomatrixserverlib.RespPublicRooms @@ -111,7 +111,7 @@ func fillPublicRoomsReq(httpReq *http.Request, request *PublicRoomReq) *util.JSO } // due to lots of switches -func fillInRooms(ctx context.Context, roomIDs []string, rsAPI roomserverAPI.RoomserverInternalAPI) ([]gomatrixserverlib.PublicRoom, error) { +func fillInRooms(ctx context.Context, roomIDs []string, rsAPI roomserverAPI.FederationRoomserverAPI) ([]gomatrixserverlib.PublicRoom, error) { avatarTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.avatar", StateKey: ""} nameTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.name", StateKey: ""} canonicalTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias, StateKey: ""} diff --git a/federationapi/routing/query.go b/federationapi/routing/query.go index 47d3b2df9..707b7b019 100644 --- a/federationapi/routing/query.go +++ b/federationapi/routing/query.go @@ -32,7 +32,7 @@ func RoomAliasToID( httpReq *http.Request, federation *gomatrixserverlib.FederationClient, cfg *config.FederationAPI, - rsAPI roomserverAPI.RoomserverInternalAPI, + rsAPI roomserverAPI.FederationRoomserverAPI, senderAPI federationAPI.FederationInternalAPI, ) util.JSONResponse { roomAlias := httpReq.FormValue("room_alias") diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 6d24c8b40..9f95ed07e 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -18,10 +18,14 @@ import ( "context" "fmt" "net/http" + "sync" + "time" + "github.com/getsentry/sentry-go" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/clientapi/jsonerror" federationAPI "github.com/matrix-org/dendrite/federationapi/api" + fedInternal "github.com/matrix-org/dendrite/federationapi/internal" "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/httputil" @@ -47,12 +51,12 @@ import ( func Setup( fedMux, keyMux, wkMux *mux.Router, cfg *config.FederationAPI, - rsAPI roomserverAPI.RoomserverInternalAPI, - fsAPI federationAPI.FederationInternalAPI, + rsAPI roomserverAPI.FederationRoomserverAPI, + fsAPI *fedInternal.FederationInternalAPI, keys gomatrixserverlib.JSONVerifier, federation *gomatrixserverlib.FederationClient, - userAPI userapi.UserInternalAPI, - keyAPI keyserverAPI.KeyInternalAPI, + userAPI userapi.FederationUserAPI, + keyAPI keyserverAPI.FederationKeyAPI, mscCfg *config.MSCs, servers federationAPI.ServersInRoomProvider, producer *producers.SyncAPIProducer, @@ -65,7 +69,7 @@ func Setup( v1fedmux := fedMux.PathPrefix("/v1").Subrouter() v2fedmux := fedMux.PathPrefix("/v2").Subrouter() - wakeup := &httputil.FederationWakeups{ + wakeup := &FederationWakeups{ FsAPI: fsAPI, } @@ -119,7 +123,7 @@ func Setup( v2keysmux.Handle("/query/{serverName}/{keyID}", notaryKeys).Methods(http.MethodGet) mu := internal.NewMutexByRoom() - v1fedmux.Handle("/send/{txnID}", httputil.MakeFedAPI( + v1fedmux.Handle("/send/{txnID}", MakeFedAPI( "federation_send", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return Send( @@ -129,7 +133,7 @@ func Setup( }, )).Methods(http.MethodPut, http.MethodOptions) - v1fedmux.Handle("/invite/{roomID}/{eventID}", httputil.MakeFedAPI( + v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI( "federation_invite", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { @@ -145,7 +149,7 @@ func Setup( }, )).Methods(http.MethodPut, http.MethodOptions) - v2fedmux.Handle("/invite/{roomID}/{eventID}", httputil.MakeFedAPI( + v2fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI( "federation_invite", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { @@ -167,7 +171,7 @@ func Setup( }, )).Methods(http.MethodPost, http.MethodOptions) - v1fedmux.Handle("/exchange_third_party_invite/{roomID}", httputil.MakeFedAPI( + v1fedmux.Handle("/exchange_third_party_invite/{roomID}", MakeFedAPI( "exchange_third_party_invite", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return ExchangeThirdPartyInvite( @@ -176,7 +180,7 @@ func Setup( }, )).Methods(http.MethodPut, http.MethodOptions) - v1fedmux.Handle("/event/{eventID}", httputil.MakeFedAPI( + v1fedmux.Handle("/event/{eventID}", MakeFedAPI( "federation_get_event", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetEvent( @@ -185,7 +189,7 @@ func Setup( }, )).Methods(http.MethodGet) - v1fedmux.Handle("/state/{roomID}", httputil.MakeFedAPI( + v1fedmux.Handle("/state/{roomID}", MakeFedAPI( "federation_get_state", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { @@ -200,7 +204,7 @@ func Setup( }, )).Methods(http.MethodGet) - v1fedmux.Handle("/state_ids/{roomID}", httputil.MakeFedAPI( + v1fedmux.Handle("/state_ids/{roomID}", MakeFedAPI( "federation_get_state_ids", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { @@ -215,7 +219,7 @@ func Setup( }, )).Methods(http.MethodGet) - v1fedmux.Handle("/event_auth/{roomID}/{eventID}", httputil.MakeFedAPI( + v1fedmux.Handle("/event_auth/{roomID}/{eventID}", MakeFedAPI( "federation_get_event_auth", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { @@ -230,7 +234,7 @@ func Setup( }, )).Methods(http.MethodGet) - v1fedmux.Handle("/query/directory", httputil.MakeFedAPI( + v1fedmux.Handle("/query/directory", MakeFedAPI( "federation_query_room_alias", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return RoomAliasToID( @@ -239,7 +243,7 @@ func Setup( }, )).Methods(http.MethodGet) - v1fedmux.Handle("/query/profile", httputil.MakeFedAPI( + v1fedmux.Handle("/query/profile", MakeFedAPI( "federation_query_profile", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetProfile( @@ -248,7 +252,7 @@ func Setup( }, )).Methods(http.MethodGet) - v1fedmux.Handle("/user/devices/{userID}", httputil.MakeFedAPI( + v1fedmux.Handle("/user/devices/{userID}", MakeFedAPI( "federation_user_devices", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetUserDevices( @@ -258,7 +262,7 @@ func Setup( )).Methods(http.MethodGet) if mscCfg.Enabled("msc2444") { - v1fedmux.Handle("/peek/{roomID}/{peekID}", httputil.MakeFedAPI( + v1fedmux.Handle("/peek/{roomID}/{peekID}", MakeFedAPI( "federation_peek", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { @@ -289,7 +293,7 @@ func Setup( )).Methods(http.MethodPut, http.MethodDelete) } - v1fedmux.Handle("/make_join/{roomID}/{userID}", httputil.MakeFedAPI( + v1fedmux.Handle("/make_join/{roomID}/{userID}", MakeFedAPI( "federation_make_join", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { @@ -320,7 +324,7 @@ func Setup( }, )).Methods(http.MethodGet) - v1fedmux.Handle("/send_join/{roomID}/{eventID}", httputil.MakeFedAPI( + v1fedmux.Handle("/send_join/{roomID}/{eventID}", MakeFedAPI( "federation_send_join", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { @@ -352,7 +356,7 @@ func Setup( }, )).Methods(http.MethodPut) - v2fedmux.Handle("/send_join/{roomID}/{eventID}", httputil.MakeFedAPI( + v2fedmux.Handle("/send_join/{roomID}/{eventID}", MakeFedAPI( "federation_send_join", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { @@ -369,7 +373,7 @@ func Setup( }, )).Methods(http.MethodPut) - v1fedmux.Handle("/make_leave/{roomID}/{eventID}", httputil.MakeFedAPI( + v1fedmux.Handle("/make_leave/{roomID}/{eventID}", MakeFedAPI( "federation_make_leave", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { @@ -386,7 +390,7 @@ func Setup( }, )).Methods(http.MethodGet) - v1fedmux.Handle("/send_leave/{roomID}/{eventID}", httputil.MakeFedAPI( + v1fedmux.Handle("/send_leave/{roomID}/{eventID}", MakeFedAPI( "federation_send_leave", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { @@ -418,7 +422,7 @@ func Setup( }, )).Methods(http.MethodPut) - v2fedmux.Handle("/send_leave/{roomID}/{eventID}", httputil.MakeFedAPI( + v2fedmux.Handle("/send_leave/{roomID}/{eventID}", MakeFedAPI( "federation_send_leave", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { @@ -442,7 +446,7 @@ func Setup( }, )).Methods(http.MethodGet) - v1fedmux.Handle("/get_missing_events/{roomID}", httputil.MakeFedAPI( + v1fedmux.Handle("/get_missing_events/{roomID}", MakeFedAPI( "federation_get_missing_events", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { @@ -455,7 +459,7 @@ func Setup( }, )).Methods(http.MethodPost) - v1fedmux.Handle("/backfill/{roomID}", httputil.MakeFedAPI( + v1fedmux.Handle("/backfill/{roomID}", MakeFedAPI( "federation_backfill", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { @@ -474,14 +478,14 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodPost) - v1fedmux.Handle("/user/keys/claim", httputil.MakeFedAPI( + v1fedmux.Handle("/user/keys/claim", MakeFedAPI( "federation_keys_claim", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return ClaimOneTimeKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) }, )).Methods(http.MethodPost) - v1fedmux.Handle("/user/keys/query", httputil.MakeFedAPI( + v1fedmux.Handle("/user/keys/query", MakeFedAPI( "federation_keys_query", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) @@ -497,7 +501,7 @@ func Setup( func ErrorIfLocalServerNotInRoom( ctx context.Context, - rsAPI api.RoomserverInternalAPI, + rsAPI api.FederationRoomserverAPI, roomID string, ) *util.JSONResponse { // Check if we think we're in this room. If we aren't then @@ -518,3 +522,67 @@ func ErrorIfLocalServerNotInRoom( } return nil } + +// MakeFedAPI makes an http.Handler that checks matrix federation authentication. +func MakeFedAPI( + metricsName string, + serverName gomatrixserverlib.ServerName, + keyRing gomatrixserverlib.JSONVerifier, + wakeup *FederationWakeups, + f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse, +) http.Handler { + h := func(req *http.Request) util.JSONResponse { + fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( + req, time.Now(), serverName, keyRing, + ) + if fedReq == nil { + return errResp + } + // add the user to Sentry, if enabled + hub := sentry.GetHubFromContext(req.Context()) + if hub != nil { + hub.Scope().SetTag("origin", string(fedReq.Origin())) + hub.Scope().SetTag("uri", fedReq.RequestURI()) + } + defer func() { + if r := recover(); r != nil { + if hub != nil { + hub.CaptureException(fmt.Errorf("%s panicked", req.URL.Path)) + } + // re-panic to return the 500 + panic(r) + } + }() + go wakeup.Wakeup(req.Context(), fedReq.Origin()) + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.MatrixErrorResponse(400, "M_UNRECOGNISED", "badly encoded query params") + } + + jsonRes := f(req, fedReq, vars) + // do not log 4xx as errors as they are client fails, not server fails + if hub != nil && jsonRes.Code >= 500 { + hub.Scope().SetExtra("response", jsonRes) + hub.CaptureException(fmt.Errorf("%s returned HTTP %d", req.URL.Path, jsonRes.Code)) + } + return jsonRes + } + return httputil.MakeExternalAPI(metricsName, h) +} + +type FederationWakeups struct { + FsAPI *fedInternal.FederationInternalAPI + origins sync.Map +} + +func (f *FederationWakeups) Wakeup(ctx context.Context, origin gomatrixserverlib.ServerName) { + key, keyok := f.origins.Load(origin) + if keyok { + lastTime, ok := key.(time.Time) + if ok && time.Since(lastTime) < time.Minute { + return + } + } + f.FsAPI.MarkServersAlive([]gomatrixserverlib.ServerName{origin}) + f.origins.Store(origin, time.Now()) +} diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 2c01afb1b..55a113675 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -82,8 +82,8 @@ func Send( request *gomatrixserverlib.FederationRequest, txnID gomatrixserverlib.TransactionID, cfg *config.FederationAPI, - rsAPI api.RoomserverInternalAPI, - keyAPI keyapi.KeyInternalAPI, + rsAPI api.FederationRoomserverAPI, + keyAPI keyapi.FederationKeyAPI, keys gomatrixserverlib.JSONVerifier, federation *gomatrixserverlib.FederationClient, mu *internal.MutexByRoom, @@ -182,8 +182,8 @@ func Send( type txnReq struct { gomatrixserverlib.Transaction - rsAPI api.RoomserverInternalAPI - keyAPI keyapi.KeyInternalAPI + rsAPI api.FederationRoomserverAPI + keyAPI keyapi.FederationKeyAPI ourServerName gomatrixserverlib.ServerName keys gomatrixserverlib.JSONVerifier federation txnFederationClient diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index 8d2d85040..011d4e342 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -183,7 +183,7 @@ func (c *txnFedClient) LookupMissingEvents(ctx context.Context, s gomatrixserver return c.getMissingEvents(missing) } -func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederationClient, pdus []json.RawMessage) *txnReq { +func mustCreateTransaction(rsAPI api.FederationRoomserverAPI, fedClient txnFederationClient, pdus []json.RawMessage) *txnReq { t := &txnReq{ rsAPI: rsAPI, keys: &test.NopJSONVerifier{}, diff --git a/federationapi/routing/state.go b/federationapi/routing/state.go index e2b67776a..6fdce20ce 100644 --- a/federationapi/routing/state.go +++ b/federationapi/routing/state.go @@ -27,7 +27,7 @@ import ( func GetState( ctx context.Context, request *gomatrixserverlib.FederationRequest, - rsAPI api.RoomserverInternalAPI, + rsAPI api.FederationRoomserverAPI, roomID string, ) util.JSONResponse { eventID, err := parseEventIDParam(request) @@ -50,7 +50,7 @@ func GetState( func GetStateIDs( ctx context.Context, request *gomatrixserverlib.FederationRequest, - rsAPI api.RoomserverInternalAPI, + rsAPI api.FederationRoomserverAPI, roomID string, ) util.JSONResponse { eventID, err := parseEventIDParam(request) @@ -97,7 +97,7 @@ func parseEventIDParam( func getState( ctx context.Context, request *gomatrixserverlib.FederationRequest, - rsAPI api.RoomserverInternalAPI, + rsAPI api.FederationRoomserverAPI, roomID string, eventID string, ) (stateEvents, authEvents []*gomatrixserverlib.HeaderedEvent, errRes *util.JSONResponse) { diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index 8ae7130c3..16f245cee 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -55,10 +55,10 @@ var ( // CreateInvitesFrom3PIDInvites implements POST /_matrix/federation/v1/3pid/onbind func CreateInvitesFrom3PIDInvites( - req *http.Request, rsAPI api.RoomserverInternalAPI, + req *http.Request, rsAPI api.FederationRoomserverAPI, cfg *config.FederationAPI, federation *gomatrixserverlib.FederationClient, - userAPI userapi.UserInternalAPI, + userAPI userapi.FederationUserAPI, ) util.JSONResponse { var body invites if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { @@ -105,7 +105,7 @@ func ExchangeThirdPartyInvite( httpReq *http.Request, request *gomatrixserverlib.FederationRequest, roomID string, - rsAPI api.RoomserverInternalAPI, + rsAPI api.FederationRoomserverAPI, cfg *config.FederationAPI, federation *gomatrixserverlib.FederationClient, ) util.JSONResponse { @@ -203,10 +203,10 @@ func ExchangeThirdPartyInvite( // Returns an error if there was a problem building the event or fetching the // necessary data to do so. func createInviteFrom3PIDInvite( - ctx context.Context, rsAPI api.RoomserverInternalAPI, + ctx context.Context, rsAPI api.FederationRoomserverAPI, cfg *config.FederationAPI, inv invite, federation *gomatrixserverlib.FederationClient, - userAPI userapi.UserInternalAPI, + userAPI userapi.FederationUserAPI, ) (*gomatrixserverlib.Event, error) { verReq := api.QueryRoomVersionForRoomRequest{RoomID: inv.RoomID} verRes := api.QueryRoomVersionForRoomResponse{} @@ -270,7 +270,7 @@ func createInviteFrom3PIDInvite( // Returns an error if something failed during the process. func buildMembershipEvent( ctx context.Context, - builder *gomatrixserverlib.EventBuilder, rsAPI api.RoomserverInternalAPI, + builder *gomatrixserverlib.EventBuilder, rsAPI api.FederationRoomserverAPI, cfg *config.FederationAPI, ) (*gomatrixserverlib.Event, error) { eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index 993bf7945..ebaeeb40e 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) @@ -35,13 +36,12 @@ type Database struct { } // NewDatabase opens a new database -func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (*Database, error) { +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (*Database, error) { var d Database var err error - if d.db, err = sqlutil.Open(dbProperties); err != nil { + if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { return nil, err } - d.writer = sqlutil.NewDummyWriter() joinedHosts, err := NewPostgresJoinedHostsTable(d.db) if err != nil { return nil, err diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index 63cebfaa1..b627f6868 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) @@ -34,13 +35,12 @@ type Database struct { } // NewDatabase opens a new database -func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (*Database, error) { +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (*Database, error) { var d Database var err error - if d.db, err = sqlutil.Open(dbProperties); err != nil { + if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { return nil, err } - d.writer = sqlutil.NewExclusiveWriter() joinedHosts, err := NewSQLiteJoinedHostsTable(d.db) if err != nil { return nil, err diff --git a/federationapi/storage/storage.go b/federationapi/storage/storage.go index 4b52ca206..f246b9bc9 100644 --- a/federationapi/storage/storage.go +++ b/federationapi/storage/storage.go @@ -23,17 +23,18 @@ import ( "github.com/matrix-org/dendrite/federationapi/storage/postgres" "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) // NewDatabase opens a new database -func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (Database, error) { +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, cache, serverName) + return sqlite3.NewDatabase(base, dbProperties, cache, serverName) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(dbProperties, cache, serverName) + return postgres.NewDatabase(base, dbProperties, cache, serverName) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/federationapi/storage/storage_wasm.go b/federationapi/storage/storage_wasm.go index 09abed63e..84d5a3a4c 100644 --- a/federationapi/storage/storage_wasm.go +++ b/federationapi/storage/storage_wasm.go @@ -19,15 +19,16 @@ import ( "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) // NewDatabase opens a new database -func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (Database, error) { +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, cache, serverName) + return sqlite3.NewDatabase(base, dbProperties, cache, serverName) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/go.mod b/go.mod index a7caadfb5..58cfbba11 100644 --- a/go.mod +++ b/go.mod @@ -25,12 +25,12 @@ require ( github.com/h2non/filetype v1.1.3 // indirect github.com/hashicorp/golang-lru v0.5.4 github.com/juju/testing v0.0.0-20220203020004-a0ff61f03494 // indirect - github.com/kardianos/minwinsvc v1.0.0 // indirect + github.com/kardianos/minwinsvc v1.0.0 github.com/lib/pq v1.10.5 github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f + github.com/matrix-org/gomatrixserverlib v0.0.0-20220506144035-8958f9dfdffd github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 diff --git a/go.sum b/go.sum index f8daca79e..047bbf5c3 100644 --- a/go.sum +++ b/go.sum @@ -795,8 +795,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f h1:MZrl4TgTnlaOn2Cu9gJCoJ3oyW5mT4/3QIZGgZXzKl4= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f/go.mod h1:V5eO8rn/C3rcxig37A/BCeKerLFS+9Avg/77FIeTZ48= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220506144035-8958f9dfdffd h1:11Wh+NMPDE5UDEx50RJnxeYj7zH5HEClzGfVMAjUy9U= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220506144035-8958f9dfdffd/go.mod h1:V5eO8rn/C3rcxig37A/BCeKerLFS+9Avg/77FIeTZ48= github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 h1:W0sjjC6yjskHX4mb0nk3p0fXAlbU5bAFUFeEtlrPASE= github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= diff --git a/internal/caching/cache_lazy_load_members.go b/internal/caching/cache_lazy_load_members.go index 71a317624..f0d495065 100644 --- a/internal/caching/cache_lazy_load_members.go +++ b/internal/caching/cache_lazy_load_members.go @@ -15,33 +15,14 @@ const ( LazyLoadCacheMaxAge = time.Minute * 30 ) -type LazyLoadCache struct { - // InMemoryLRUCachePartition containing other InMemoryLRUCachePartitions - // with the actual cached members - userCaches *InMemoryLRUCachePartition +type LazyLoadCache interface { + StoreLazyLoadedUser(device *userapi.Device, roomID, userID, eventID string) + IsLazyLoadedUserCached(device *userapi.Device, roomID, userID string) (string, bool) } -// NewLazyLoadCache creates a new LazyLoadCache. -func NewLazyLoadCache() (*LazyLoadCache, error) { - cache, err := NewInMemoryLRUCachePartition( - LazyLoadCacheName, - LazyLoadCacheMutable, - LazyLoadCacheMaxEntries, - LazyLoadCacheMaxAge, - true, - ) - if err != nil { - return nil, err - } - go cacheCleaner(cache) - return &LazyLoadCache{ - userCaches: cache, - }, nil -} - -func (c *LazyLoadCache) lazyLoadCacheForUser(device *userapi.Device) (*InMemoryLRUCachePartition, error) { +func (c Caches) lazyLoadCacheForUser(device *userapi.Device) (*InMemoryLRUCachePartition, error) { cacheName := fmt.Sprintf("%s/%s", device.UserID, device.ID) - userCache, ok := c.userCaches.Get(cacheName) + userCache, ok := c.LazyLoading.Get(cacheName) if ok && userCache != nil { if cache, ok := userCache.(*InMemoryLRUCachePartition); ok { return cache, nil @@ -57,12 +38,12 @@ func (c *LazyLoadCache) lazyLoadCacheForUser(device *userapi.Device) (*InMemoryL if err != nil { return nil, err } - c.userCaches.Set(cacheName, cache) + c.LazyLoading.Set(cacheName, cache) go cacheCleaner(cache) return cache, nil } -func (c *LazyLoadCache) StoreLazyLoadedUser(device *userapi.Device, roomID, userID, eventID string) { +func (c Caches) StoreLazyLoadedUser(device *userapi.Device, roomID, userID, eventID string) { cache, err := c.lazyLoadCacheForUser(device) if err != nil { return @@ -71,7 +52,7 @@ func (c *LazyLoadCache) StoreLazyLoadedUser(device *userapi.Device, roomID, user cache.Set(cacheKey, eventID) } -func (c *LazyLoadCache) IsLazyLoadedUserCached(device *userapi.Device, roomID, userID string) (string, bool) { +func (c Caches) IsLazyLoadedUserCached(device *userapi.Device, roomID, userID string) (string, bool) { cache, err := c.lazyLoadCacheForUser(device) if err != nil { return "", false diff --git a/internal/caching/caches.go b/internal/caching/caches.go index 722405de6..173e47e5b 100644 --- a/internal/caching/caches.go +++ b/internal/caching/caches.go @@ -1,6 +1,8 @@ package caching -import "time" +import ( + "time" +) // Caches contains a set of references to caches. They may be // different implementations as long as they satisfy the Cache @@ -13,6 +15,7 @@ type Caches struct { RoomInfos Cache // RoomInfoCache FederationEvents Cache // FederationEventsCache SpaceSummaryRooms Cache // SpaceSummaryRoomsCache + LazyLoading Cache // LazyLoadCache } // Cache is the interface that an implementation must satisfy. diff --git a/internal/caching/impl_inmemorylru.go b/internal/caching/impl_inmemorylru.go index 94fdd1a9b..594760892 100644 --- a/internal/caching/impl_inmemorylru.go +++ b/internal/caching/impl_inmemorylru.go @@ -70,9 +70,21 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { if err != nil { return nil, err } + + lazyLoadCache, err := NewInMemoryLRUCachePartition( + LazyLoadCacheName, + LazyLoadCacheMutable, + LazyLoadCacheMaxEntries, + LazyLoadCacheMaxAge, + enablePrometheus, + ) + if err != nil { + return nil, err + } + go cacheCleaner( roomVersions, serverKeys, roomServerRoomIDs, - roomInfos, federationEvents, spaceRooms, + roomInfos, federationEvents, spaceRooms, lazyLoadCache, ) return &Caches{ RoomVersions: roomVersions, @@ -81,6 +93,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { RoomInfos: roomInfos, FederationEvents: federationEvents, SpaceSummaryRooms: spaceRooms, + LazyLoading: lazyLoadCache, }, nil } diff --git a/internal/eventutil/events.go b/internal/eventutil/events.go index 47c83d515..ee67a6daf 100644 --- a/internal/eventutil/events.go +++ b/internal/eventutil/events.go @@ -39,7 +39,7 @@ var ErrRoomNoExists = errors.New("room does not exist") func QueryAndBuildEvent( ctx context.Context, builder *gomatrixserverlib.EventBuilder, cfg *config.Global, evTime time.Time, - rsAPI api.RoomserverInternalAPI, queryRes *api.QueryLatestEventsAndStateResponse, + rsAPI api.QueryLatestEventsAndStateAPI, queryRes *api.QueryLatestEventsAndStateResponse, ) (*gomatrixserverlib.HeaderedEvent, error) { if queryRes == nil { queryRes = &api.QueryLatestEventsAndStateResponse{} @@ -80,7 +80,7 @@ func BuildEvent( func queryRequiredEventsForBuilder( ctx context.Context, builder *gomatrixserverlib.EventBuilder, - rsAPI api.RoomserverInternalAPI, queryRes *api.QueryLatestEventsAndStateResponse, + rsAPI api.QueryLatestEventsAndStateAPI, queryRes *api.QueryLatestEventsAndStateResponse, ) (*gomatrixserverlib.StateNeeded, error) { eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) if err != nil { diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index 5fcacd2ad..aba50ae4d 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -15,7 +15,6 @@ package httputil import ( - "context" "fmt" "io" "net/http" @@ -23,15 +22,10 @@ import ( "net/http/httputil" "os" "strings" - "sync" - "time" "github.com/getsentry/sentry-go" - "github.com/gorilla/mux" "github.com/matrix-org/dendrite/clientapi/auth" - federationapiAPI "github.com/matrix-org/dendrite/federationapi/api" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" opentracing "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/ext" @@ -49,7 +43,7 @@ type BasicAuth struct { // MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request. func MakeAuthAPI( - metricsName string, userAPI userapi.UserInternalAPI, + metricsName string, userAPI userapi.QueryAcccessTokenAPI, f func(*http.Request, *userapi.Device) util.JSONResponse, ) http.Handler { h := func(req *http.Request) util.JSONResponse { @@ -226,79 +220,6 @@ func MakeInternalAPI(metricsName string, f func(*http.Request) util.JSONResponse ) } -// MakeFedAPI makes an http.Handler that checks matrix federation authentication. -func MakeFedAPI( - metricsName string, - serverName gomatrixserverlib.ServerName, - keyRing gomatrixserverlib.JSONVerifier, - wakeup *FederationWakeups, - f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse, -) http.Handler { - h := func(req *http.Request) util.JSONResponse { - fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( - req, time.Now(), serverName, keyRing, - ) - if fedReq == nil { - return errResp - } - // add the user to Sentry, if enabled - hub := sentry.GetHubFromContext(req.Context()) - if hub != nil { - hub.Scope().SetTag("origin", string(fedReq.Origin())) - hub.Scope().SetTag("uri", fedReq.RequestURI()) - } - defer func() { - if r := recover(); r != nil { - if hub != nil { - hub.CaptureException(fmt.Errorf("%s panicked", req.URL.Path)) - } - // re-panic to return the 500 - panic(r) - } - }() - go wakeup.Wakeup(req.Context(), fedReq.Origin()) - vars, err := URLDecodeMapValues(mux.Vars(req)) - if err != nil { - return util.MatrixErrorResponse(400, "M_UNRECOGNISED", "badly encoded query params") - } - - jsonRes := f(req, fedReq, vars) - // do not log 4xx as errors as they are client fails, not server fails - if hub != nil && jsonRes.Code >= 500 { - hub.Scope().SetExtra("response", jsonRes) - hub.CaptureException(fmt.Errorf("%s returned HTTP %d", req.URL.Path, jsonRes.Code)) - } - return jsonRes - } - return MakeExternalAPI(metricsName, h) -} - -type FederationWakeups struct { - FsAPI federationapiAPI.FederationInternalAPI - origins sync.Map -} - -func (f *FederationWakeups) Wakeup(ctx context.Context, origin gomatrixserverlib.ServerName) { - key, keyok := f.origins.Load(origin) - if keyok { - lastTime, ok := key.(time.Time) - if ok && time.Since(lastTime) < time.Minute { - return - } - } - aliveReq := federationapiAPI.PerformServersAliveRequest{ - Servers: []gomatrixserverlib.ServerName{origin}, - } - aliveRes := federationapiAPI.PerformServersAliveResponse{} - if err := f.FsAPI.PerformServersAlive(ctx, &aliveReq, &aliveRes); err != nil { - util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ - "origin": origin, - }).Warn("incoming federation request failed to notify server alive") - } else { - f.origins.Store(origin, time.Now()) - } -} - // WrapHandlerInBasicAuth adds basic auth to a handler. Only used for /metrics func WrapHandlerInBasicAuth(h http.Handler, b BasicAuth) http.HandlerFunc { if b.Username == "" || b.Password == "" { diff --git a/internal/pushgateway/client.go b/internal/pushgateway/client.go index 49907cee8..231327a1e 100644 --- a/internal/pushgateway/client.go +++ b/internal/pushgateway/client.go @@ -3,31 +3,28 @@ package pushgateway import ( "bytes" "context" - "crypto/tls" "encoding/json" "fmt" "net/http" "time" + "github.com/matrix-org/gomatrixserverlib" "github.com/opentracing/opentracing-go" ) type httpClient struct { - hc *http.Client + hc *gomatrixserverlib.Client } // NewHTTPClient creates a new Push Gateway client. func NewHTTPClient(disableTLSValidation bool) Client { - hc := &http.Client{ - Timeout: 30 * time.Second, - Transport: &http.Transport{ - DisableKeepAlives: true, - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: disableTLSValidation, - }, - }, + return &httpClient{ + hc: gomatrixserverlib.NewClient( + gomatrixserverlib.WithTimeout(time.Second*30), + gomatrixserverlib.WithKeepAlives(false), + gomatrixserverlib.WithSkipVerify(disableTLSValidation), + ), } - return &httpClient{hc: hc} } func (h *httpClient) Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error { @@ -44,7 +41,7 @@ func (h *httpClient) Notify(ctx context.Context, url string, req *NotifyRequest, } hreq.Header.Set("Content-Type", "application/json") - hresp, err := h.hc.Do(hreq) + hresp, err := h.hc.DoHTTPRequest(ctx, hreq) if err != nil { return err } diff --git a/internal/sqlutil/sqlutil.go b/internal/sqlutil/sqlutil.go new file mode 100644 index 000000000..0cdae6d30 --- /dev/null +++ b/internal/sqlutil/sqlutil.go @@ -0,0 +1,51 @@ +package sqlutil + +import ( + "database/sql" + "fmt" + "regexp" + + "github.com/matrix-org/dendrite/setup/config" + "github.com/sirupsen/logrus" +) + +// Open opens a database specified by its database driver name and a driver-specific data source name, +// usually consisting of at least a database name and connection information. Includes tracing driver +// if DENDRITE_TRACE_SQL=1 +func Open(dbProperties *config.DatabaseOptions, writer Writer) (*sql.DB, error) { + var err error + var driverName, dsn string + switch { + case dbProperties.ConnectionString.IsSQLite(): + driverName = "sqlite3" + dsn, err = ParseFileURI(dbProperties.ConnectionString) + if err != nil { + return nil, fmt.Errorf("ParseFileURI: %w", err) + } + case dbProperties.ConnectionString.IsPostgres(): + driverName = "postgres" + dsn = string(dbProperties.ConnectionString) + default: + return nil, fmt.Errorf("invalid database connection string %q", dbProperties.ConnectionString) + } + if tracingEnabled { + // install the wrapped driver + driverName += "-trace" + } + db, err := sql.Open(driverName, dsn) + if err != nil { + return nil, err + } + if driverName != "sqlite3" { + logrus.WithFields(logrus.Fields{ + "MaxOpenConns": dbProperties.MaxOpenConns(), + "MaxIdleConns": dbProperties.MaxIdleConns(), + "ConnMaxLifetime": dbProperties.ConnMaxLifetime(), + "dataSourceName": regexp.MustCompile(`://[^@]*@`).ReplaceAllLiteralString(dsn, "://"), + }).Debug("Setting DB connection limits") + db.SetMaxOpenConns(dbProperties.MaxOpenConns()) + db.SetMaxIdleConns(dbProperties.MaxIdleConns()) + db.SetConnMaxLifetime(dbProperties.ConnMaxLifetime()) + } + return db, nil +} diff --git a/internal/sqlutil/trace.go b/internal/sqlutil/trace.go index 51eaa1b45..c16738616 100644 --- a/internal/sqlutil/trace.go +++ b/internal/sqlutil/trace.go @@ -16,19 +16,16 @@ package sqlutil import ( "context" - "database/sql" "database/sql/driver" "fmt" "io" "os" - "regexp" "runtime" "strconv" "strings" "sync" "time" - "github.com/matrix-org/dendrite/setup/config" "github.com/ngrok/sqlmw" "github.com/sirupsen/logrus" ) @@ -96,47 +93,6 @@ func trackGoID(query string) { logrus.Warnf("unsafe goid %d: SQL executed not on an ExclusiveWriter: %s", thisGoID, q) } -// Open opens a database specified by its database driver name and a driver-specific data source name, -// usually consisting of at least a database name and connection information. Includes tracing driver -// if DENDRITE_TRACE_SQL=1 -func Open(dbProperties *config.DatabaseOptions) (*sql.DB, error) { - var err error - var driverName, dsn string - switch { - case dbProperties.ConnectionString.IsSQLite(): - driverName = "sqlite3" - dsn, err = ParseFileURI(dbProperties.ConnectionString) - if err != nil { - return nil, fmt.Errorf("ParseFileURI: %w", err) - } - case dbProperties.ConnectionString.IsPostgres(): - driverName = "postgres" - dsn = string(dbProperties.ConnectionString) - default: - return nil, fmt.Errorf("invalid database connection string %q", dbProperties.ConnectionString) - } - if tracingEnabled { - // install the wrapped driver - driverName += "-trace" - } - db, err := sql.Open(driverName, dsn) - if err != nil { - return nil, err - } - if driverName != "sqlite3" { - logrus.WithFields(logrus.Fields{ - "MaxOpenConns": dbProperties.MaxOpenConns(), - "MaxIdleConns": dbProperties.MaxIdleConns(), - "ConnMaxLifetime": dbProperties.ConnMaxLifetime(), - "dataSourceName": regexp.MustCompile(`://[^@]*@`).ReplaceAllLiteralString(dsn, "://"), - }).Debug("Setting DB connection limits") - db.SetMaxOpenConns(dbProperties.MaxOpenConns()) - db.SetMaxIdleConns(dbProperties.MaxIdleConns()) - db.SetConnMaxLifetime(dbProperties.ConnMaxLifetime()) - } - return db, nil -} - func init() { registerDrivers() } diff --git a/keyserver/api/api.go b/keyserver/api/api.go index 429617b10..140f03569 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -27,21 +27,45 @@ import ( ) type KeyInternalAPI interface { + SyncKeyAPI + ClientKeyAPI + FederationKeyAPI + UserKeyAPI + // SetUserAPI assigns a user API to query when extracting device names. - SetUserAPI(i userapi.UserInternalAPI) - // InputDeviceListUpdate from a federated server EDU - InputDeviceListUpdate(ctx context.Context, req *InputDeviceListUpdateRequest, res *InputDeviceListUpdateResponse) + SetUserAPI(i userapi.KeyserverUserAPI) +} + +// API functions required by the clientapi +type ClientKeyAPI interface { + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) - // PerformClaimKeys claims one-time keys for use in pre-key messages - PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) - PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse) PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) - QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) + // PerformClaimKeys claims one-time keys for use in pre-key messages + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) +} + +// API functions required by the userapi +type UserKeyAPI interface { + PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) + PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse) +} + +// API functions required by the syncapi +type SyncKeyAPI interface { QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) - QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) +} + +type FederationKeyAPI interface { + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) + QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) + // InputDeviceListUpdate from a federated server EDU + InputDeviceListUpdate(ctx context.Context, req *InputDeviceListUpdateRequest, res *InputDeviceListUpdateResponse) + PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) } // KeyError is returned if there was a problem performing/querying the server diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index e556f44b0..be71e5750 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -38,12 +38,12 @@ type KeyInternalAPI struct { DB storage.Database ThisServer gomatrixserverlib.ServerName FedClient fedsenderapi.FederationClient - UserAPI userapi.UserInternalAPI + UserAPI userapi.KeyserverUserAPI Producer *producers.KeyChange Updater *DeviceListUpdater } -func (a *KeyInternalAPI) SetUserAPI(i userapi.UserInternalAPI) { +func (a *KeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) { a.UserAPI = i } diff --git a/keyserver/inthttp/client.go b/keyserver/inthttp/client.go index f50789b82..abce81582 100644 --- a/keyserver/inthttp/client.go +++ b/keyserver/inthttp/client.go @@ -60,7 +60,7 @@ type httpKeyInternalAPI struct { httpClient *http.Client } -func (h *httpKeyInternalAPI) SetUserAPI(i userapi.UserInternalAPI) { +func (h *httpKeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) { // no-op: doesn't need it } func (h *httpKeyInternalAPI) InputDeviceListUpdate( diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index c557dfbaa..007a48a55 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -41,7 +41,7 @@ func NewInternalAPI( ) api.KeyInternalAPI { js, _ := jetstream.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) - db, err := storage.NewDatabase(&cfg.Database) + db, err := storage.NewDatabase(base, &cfg.Database) if err != nil { logrus.WithError(err).Panicf("failed to connect to key server database") } diff --git a/keyserver/storage/postgres/storage.go b/keyserver/storage/postgres/storage.go index d4c7e2cc7..b8f70acf8 100644 --- a/keyserver/storage/postgres/storage.go +++ b/keyserver/storage/postgres/storage.go @@ -18,13 +18,14 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas" "github.com/matrix-org/dendrite/keyserver/storage/shared" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) // NewDatabase creates a new sync server database -func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) { +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) { var err error - db, err := sqlutil.Open(dbProperties) + db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) if err != nil { return nil, err } @@ -63,7 +64,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) } d := &shared.Database{ DB: db, - Writer: sqlutil.NewDummyWriter(), + Writer: writer, OneTimeKeysTable: otk, DeviceKeysTable: dk, KeyChangesTable: kc, diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go index 84d4cdf55..aeea9eac6 100644 --- a/keyserver/storage/sqlite3/storage.go +++ b/keyserver/storage/sqlite3/storage.go @@ -18,11 +18,12 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/storage/shared" "github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) -func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) { - db, err := sqlutil.Open(dbProperties) +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) { + db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) if err != nil { return nil, err } @@ -62,7 +63,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) } d := &shared.Database{ DB: db, - Writer: sqlutil.NewExclusiveWriter(), + Writer: writer, OneTimeKeysTable: otk, DeviceKeysTable: dk, KeyChangesTable: kc, diff --git a/keyserver/storage/storage.go b/keyserver/storage/storage.go index 742e8463a..ab6a35401 100644 --- a/keyserver/storage/storage.go +++ b/keyserver/storage/storage.go @@ -22,17 +22,18 @@ import ( "github.com/matrix-org/dendrite/keyserver/storage/postgres" "github.com/matrix-org/dendrite/keyserver/storage/sqlite3" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) // NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // and sets postgres connection parameters -func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) { +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties) + return sqlite3.NewDatabase(base, dbProperties) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(dbProperties) + return postgres.NewDatabase(base, dbProperties) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go index 84d2098ad..9940eac60 100644 --- a/keyserver/storage/storage_test.go +++ b/keyserver/storage/storage_test.go @@ -22,7 +22,7 @@ func MustCreateDatabase(t *testing.T) (Database, func()) { log.Fatal(err) } t.Logf("Database %s", tmpfile.Name()) - db, err := NewDatabase(&config.DatabaseOptions{ + db, err := NewDatabase(nil, &config.DatabaseOptions{ ConnectionString: config.DataSource(fmt.Sprintf("file://%s", tmpfile.Name())), }) if err != nil { diff --git a/keyserver/storage/storage_wasm.go b/keyserver/storage/storage_wasm.go index 8b31bfd01..75c9053e8 100644 --- a/keyserver/storage/storage_wasm.go +++ b/keyserver/storage/storage_wasm.go @@ -18,13 +18,14 @@ import ( "fmt" "github.com/matrix-org/dendrite/keyserver/storage/sqlite3" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) -func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) { +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties) + return sqlite3.NewDatabase(base, dbProperties) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/mediaapi/mediaapi.go b/mediaapi/mediaapi.go index e5daf480d..4792c996d 100644 --- a/mediaapi/mediaapi.go +++ b/mediaapi/mediaapi.go @@ -15,10 +15,9 @@ package mediaapi import ( - "github.com/gorilla/mux" "github.com/matrix-org/dendrite/mediaapi/routing" "github.com/matrix-org/dendrite/mediaapi/storage" - "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/base" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" @@ -26,18 +25,19 @@ import ( // AddPublicRoutes sets up and registers HTTP handlers for the MediaAPI component. func AddPublicRoutes( - router *mux.Router, - cfg *config.MediaAPI, - rateLimit *config.RateLimiting, - userAPI userapi.UserInternalAPI, + base *base.BaseDendrite, + userAPI userapi.MediaUserAPI, client *gomatrixserverlib.Client, ) { - mediaDB, err := storage.NewMediaAPIDatasource(&cfg.Database) + cfg := &base.Cfg.MediaAPI + rateCfg := &base.Cfg.ClientAPI.RateLimiting + + mediaDB, err := storage.NewMediaAPIDatasource(base, &cfg.Database) if err != nil { logrus.WithError(err).Panicf("failed to connect to media db") } routing.Setup( - router, cfg, rateLimit, mediaDB, userAPI, client, + base.PublicMediaAPIMux, cfg, rateCfg, mediaDB, userAPI, client, ) } diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go index 5f22a9461..10b25a5cd 100644 --- a/mediaapi/routing/download.go +++ b/mediaapi/routing/download.go @@ -551,7 +551,7 @@ func (r *downloadRequest) getRemoteFile( // If we do not have a record, we need to fetch the remote file first and then respond from the local file err := r.fetchRemoteFileAndStoreMetadata( ctx, client, - cfg.AbsBasePath, *cfg.MaxFileSizeBytes, db, + cfg.AbsBasePath, cfg.MaxFileSizeBytes, db, cfg.ThumbnailSizes, activeThumbnailGeneration, cfg.MaxThumbnailGenerators, ) diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go index 0e1583991..76f07415b 100644 --- a/mediaapi/routing/routing.go +++ b/mediaapi/routing/routing.go @@ -35,7 +35,7 @@ import ( // configResponse is the response to GET /_matrix/media/r0/config // https://matrix.org/docs/spec/client_server/latest#get-matrix-media-r0-config type configResponse struct { - UploadSize config.FileSizeBytes `json:"m.upload.size"` + UploadSize *config.FileSizeBytes `json:"m.upload.size"` } // Setup registers the media API HTTP handlers @@ -48,7 +48,7 @@ func Setup( cfg *config.MediaAPI, rateLimit *config.RateLimiting, db storage.Database, - userAPI userapi.UserInternalAPI, + userAPI userapi.MediaUserAPI, client *gomatrixserverlib.Client, ) { rateLimits := httputil.NewRateLimits(rateLimit) @@ -73,9 +73,13 @@ func Setup( if r := rateLimits.Limit(req); r != nil { return *r } + respondSize := &cfg.MaxFileSizeBytes + if cfg.MaxFileSizeBytes == 0 { + respondSize = nil + } return util.JSONResponse{ Code: http.StatusOK, - JSON: configResponse{UploadSize: *cfg.MaxFileSizeBytes}, + JSON: configResponse{UploadSize: respondSize}, } }) diff --git a/mediaapi/routing/upload.go b/mediaapi/routing/upload.go index 972c52af0..2175648ea 100644 --- a/mediaapi/routing/upload.go +++ b/mediaapi/routing/upload.go @@ -90,7 +90,7 @@ func parseAndValidateRequest(req *http.Request, cfg *config.MediaAPI, dev *usera Logger: util.GetLogger(req.Context()).WithField("Origin", cfg.Matrix.ServerName), } - if resErr := r.Validate(*cfg.MaxFileSizeBytes); resErr != nil { + if resErr := r.Validate(cfg.MaxFileSizeBytes); resErr != nil { return nil, resErr } @@ -148,20 +148,20 @@ func (r *uploadRequest) doUpload( // r.storeFileAndMetadata(ctx, tmpDir, ...) // before you return from doUpload else we will leak a temp file. We could make this nicer with a `WithTransaction` style of // nested function to guarantee either storage or cleanup. - if *cfg.MaxFileSizeBytes > 0 { - if *cfg.MaxFileSizeBytes+1 <= 0 { + if cfg.MaxFileSizeBytes > 0 { + if cfg.MaxFileSizeBytes+1 <= 0 { r.Logger.WithFields(log.Fields{ - "MaxFileSizeBytes": *cfg.MaxFileSizeBytes, + "MaxFileSizeBytes": cfg.MaxFileSizeBytes, }).Warnf("Configured MaxFileSizeBytes overflows int64, defaulting to %d bytes", config.DefaultMaxFileSizeBytes) - cfg.MaxFileSizeBytes = &config.DefaultMaxFileSizeBytes + cfg.MaxFileSizeBytes = config.DefaultMaxFileSizeBytes } - reqReader = io.LimitReader(reqReader, int64(*cfg.MaxFileSizeBytes)+1) + reqReader = io.LimitReader(reqReader, int64(cfg.MaxFileSizeBytes)+1) } hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reqReader, cfg.AbsBasePath) if err != nil { r.Logger.WithError(err).WithFields(log.Fields{ - "MaxFileSizeBytes": *cfg.MaxFileSizeBytes, + "MaxFileSizeBytes": cfg.MaxFileSizeBytes, }).Warn("Error while transferring file") return &util.JSONResponse{ Code: http.StatusBadRequest, @@ -170,9 +170,9 @@ func (r *uploadRequest) doUpload( } // Check if temp file size exceeds max file size configuration - if *cfg.MaxFileSizeBytes > 0 && bytesWritten > types.FileSizeBytes(*cfg.MaxFileSizeBytes) { + if cfg.MaxFileSizeBytes > 0 && bytesWritten > types.FileSizeBytes(cfg.MaxFileSizeBytes) { fileutils.RemoveDir(tmpDir, r.Logger) // delete temp file - return requestEntityTooLargeJSONResponse(*cfg.MaxFileSizeBytes) + return requestEntityTooLargeJSONResponse(cfg.MaxFileSizeBytes) } // Look up the media by the file hash. If we already have the file but under a diff --git a/mediaapi/routing/upload_test.go b/mediaapi/routing/upload_test.go index b2c2f5a44..420d0eba9 100644 --- a/mediaapi/routing/upload_test.go +++ b/mediaapi/routing/upload_test.go @@ -36,12 +36,11 @@ func Test_uploadRequest_doUpload(t *testing.T) { } maxSize := config.FileSizeBytes(8) - unlimitedSize := config.FileSizeBytes(0) logger := log.New().WithField("mediaapi", "test") testdataPath := filepath.Join(wd, "./testdata") cfg := &config.MediaAPI{ - MaxFileSizeBytes: &maxSize, + MaxFileSizeBytes: maxSize, BasePath: config.Path(testdataPath), AbsBasePath: config.Path(testdataPath), DynamicThumbnails: false, @@ -51,7 +50,7 @@ func Test_uploadRequest_doUpload(t *testing.T) { _ = os.Mkdir(testdataPath, os.ModePerm) defer fileutils.RemoveDir(types.Path(testdataPath), nil) - db, err := storage.NewMediaAPIDatasource(&config.DatabaseOptions{ + db, err := storage.NewMediaAPIDatasource(nil, &config.DatabaseOptions{ ConnectionString: "file::memory:?cache=shared", MaxOpenConnections: 100, MaxIdleConnections: 2, @@ -124,7 +123,7 @@ func Test_uploadRequest_doUpload(t *testing.T) { ctx: context.Background(), reqReader: strings.NewReader("test test test"), cfg: &config.MediaAPI{ - MaxFileSizeBytes: &unlimitedSize, + MaxFileSizeBytes: config.FileSizeBytes(0), BasePath: config.Path(testdataPath), AbsBasePath: config.Path(testdataPath), DynamicThumbnails: false, diff --git a/mediaapi/storage/postgres/mediaapi.go b/mediaapi/storage/postgres/mediaapi.go index ea70e575b..30ec64f84 100644 --- a/mediaapi/storage/postgres/mediaapi.go +++ b/mediaapi/storage/postgres/mediaapi.go @@ -20,12 +20,13 @@ import ( _ "github.com/lib/pq" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage/shared" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) // NewDatabase opens a postgres database. -func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) { - db, err := sqlutil.Open(dbProperties) +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) { + db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) if err != nil { return nil, err } @@ -41,6 +42,6 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) MediaRepository: mediaRepo, Thumbnails: thumbnails, DB: db, - Writer: sqlutil.NewExclusiveWriter(), + Writer: writer, }, nil } diff --git a/mediaapi/storage/sqlite3/mediaapi.go b/mediaapi/storage/sqlite3/mediaapi.go index abf329367..c0ab10e9f 100644 --- a/mediaapi/storage/sqlite3/mediaapi.go +++ b/mediaapi/storage/sqlite3/mediaapi.go @@ -19,12 +19,13 @@ import ( // Import the postgres database driver. "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage/shared" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) // NewDatabase opens a SQLIte database. -func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) { - db, err := sqlutil.Open(dbProperties) +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) { + db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) if err != nil { return nil, err } @@ -40,6 +41,6 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) MediaRepository: mediaRepo, Thumbnails: thumbnails, DB: db, - Writer: sqlutil.NewExclusiveWriter(), + Writer: writer, }, nil } diff --git a/mediaapi/storage/storage.go b/mediaapi/storage/storage.go index baa242e57..f673ae7e6 100644 --- a/mediaapi/storage/storage.go +++ b/mediaapi/storage/storage.go @@ -22,16 +22,17 @@ import ( "github.com/matrix-org/dendrite/mediaapi/storage/postgres" "github.com/matrix-org/dendrite/mediaapi/storage/sqlite3" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) // NewMediaAPIDatasource opens a database connection. -func NewMediaAPIDatasource(dbProperties *config.DatabaseOptions) (Database, error) { +func NewMediaAPIDatasource(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties) + return sqlite3.NewDatabase(base, dbProperties) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(dbProperties) + return postgres.NewDatabase(base, dbProperties) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/mediaapi/storage/storage_test.go b/mediaapi/storage/storage_test.go index fa88cd8e7..81f0a5d24 100644 --- a/mediaapi/storage/storage_test.go +++ b/mediaapi/storage/storage_test.go @@ -13,7 +13,7 @@ import ( func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewMediaAPIDatasource(&config.DatabaseOptions{ + db, err := storage.NewMediaAPIDatasource(nil, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }) if err != nil { diff --git a/mediaapi/storage/storage_wasm.go b/mediaapi/storage/storage_wasm.go index f67f9d5e1..41e4a28c0 100644 --- a/mediaapi/storage/storage_wasm.go +++ b/mediaapi/storage/storage_wasm.go @@ -18,14 +18,15 @@ import ( "fmt" "github.com/matrix-org/dendrite/mediaapi/storage/sqlite3" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) // Open opens a postgres database. -func NewMediaAPIDatasource(dbProperties *config.DatabaseOptions) (Database, error) { +func NewMediaAPIDatasource(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties) + return sqlite3.NewDatabase(base, dbProperties) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/roomserver/api/alias.go b/roomserver/api/alias.go index baab27751..37892a44a 100644 --- a/roomserver/api/alias.go +++ b/roomserver/api/alias.go @@ -59,18 +59,6 @@ type GetAliasesForRoomIDResponse struct { Aliases []string `json:"aliases"` } -// GetCreatorIDForAliasRequest is a request to GetCreatorIDForAlias -type GetCreatorIDForAliasRequest struct { - // The alias we want to find the creator of - Alias string `json:"alias"` -} - -// GetCreatorIDForAliasResponse is a response to GetCreatorIDForAlias -type GetCreatorIDForAliasResponse struct { - // The user ID of the alias creator - UserID string `json:"user_id"` -} - // RemoveRoomAliasRequest is a request to RemoveRoomAlias type RemoveRoomAliasRequest struct { // ID of the user removing the alias diff --git a/roomserver/api/api.go b/roomserver/api/api.go index f0ca8a615..cbb4cebca 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -12,219 +12,179 @@ import ( // RoomserverInputAPI is used to write events to the room server. type RoomserverInternalAPI interface { + SyncRoomserverAPI + AppserviceRoomserverAPI + ClientRoomserverAPI + UserRoomserverAPI + FederationRoomserverAPI + // needed to avoid chicken and egg scenario when setting up the // interdependencies between the roomserver and other input APIs - SetFederationAPI(fsAPI fsAPI.FederationInternalAPI, keyRing *gomatrixserverlib.KeyRing) - SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI) - SetUserAPI(userAPI userapi.UserInternalAPI) + SetFederationAPI(fsAPI fsAPI.RoomserverFederationAPI, keyRing *gomatrixserverlib.KeyRing) + SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) + SetUserAPI(userAPI userapi.RoomserverUserAPI) + // QueryAuthChain returns the entire auth chain for the event IDs given. + // The response includes the events in the request. + // Omits without error for any missing auth events. There will be no duplicates. + // Used in MSC2836. + QueryAuthChain( + ctx context.Context, + req *QueryAuthChainRequest, + res *QueryAuthChainResponse, + ) error +} + +type InputRoomEventsAPI interface { InputRoomEvents( ctx context.Context, - request *InputRoomEventsRequest, - response *InputRoomEventsResponse, + req *InputRoomEventsRequest, + res *InputRoomEventsResponse, ) +} - PerformInvite( +// Query the latest events and state for a room from the room server. +type QueryLatestEventsAndStateAPI interface { + QueryLatestEventsAndState(ctx context.Context, req *QueryLatestEventsAndStateRequest, res *QueryLatestEventsAndStateResponse) error +} + +// QueryBulkStateContent does a bulk query for state event content in the given rooms. +type QueryBulkStateContentAPI interface { + QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error +} + +type QueryEventsAPI interface { + // Query a list of events by event ID. + QueryEventsByID( ctx context.Context, - req *PerformInviteRequest, - res *PerformInviteResponse, + req *QueryEventsByIDRequest, + res *QueryEventsByIDResponse, ) error + // QueryCurrentState retrieves the requested state events. If state events are not found, they will be missing from + // the response. + QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error +} - PerformJoin( +// API functions required by the syncapi +type SyncRoomserverAPI interface { + QueryLatestEventsAndStateAPI + QueryBulkStateContentAPI + // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. + QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error + // Query a list of events by event ID. + QueryEventsByID( ctx context.Context, - req *PerformJoinRequest, - res *PerformJoinResponse, - ) - - PerformLeave( - ctx context.Context, - req *PerformLeaveRequest, - res *PerformLeaveResponse, + req *QueryEventsByIDRequest, + res *QueryEventsByIDResponse, ) error - - PerformPeek( + // Query the membership event for an user for a room. + QueryMembershipForUser( ctx context.Context, - req *PerformPeekRequest, - res *PerformPeekResponse, - ) - - PerformUnpeek( - ctx context.Context, - req *PerformUnpeekRequest, - res *PerformUnpeekResponse, - ) - - PerformPublish( - ctx context.Context, - req *PerformPublishRequest, - res *PerformPublishResponse, - ) - - PerformInboundPeek( - ctx context.Context, - req *PerformInboundPeekRequest, - res *PerformInboundPeekResponse, - ) error - - PerformAdminEvacuateRoom( - ctx context.Context, - req *PerformAdminEvacuateRoomRequest, - res *PerformAdminEvacuateRoomResponse, - ) - - QueryPublishedRooms( - ctx context.Context, - req *QueryPublishedRoomsRequest, - res *QueryPublishedRoomsResponse, - ) error - - // Query the latest events and state for a room from the room server. - QueryLatestEventsAndState( - ctx context.Context, - request *QueryLatestEventsAndStateRequest, - response *QueryLatestEventsAndStateResponse, + req *QueryMembershipForUserRequest, + res *QueryMembershipForUserResponse, ) error // Query the state after a list of events in a room from the room server. QueryStateAfterEvents( ctx context.Context, - request *QueryStateAfterEventsRequest, - response *QueryStateAfterEventsResponse, + req *QueryStateAfterEventsRequest, + res *QueryStateAfterEventsResponse, ) error - // Query a list of events by event ID. - QueryEventsByID( - ctx context.Context, - request *QueryEventsByIDRequest, - response *QueryEventsByIDResponse, - ) error - - // Query the membership event for an user for a room. - QueryMembershipForUser( - ctx context.Context, - request *QueryMembershipForUserRequest, - response *QueryMembershipForUserResponse, - ) error - - // Query a list of membership events for a room - QueryMembershipsForRoom( - ctx context.Context, - request *QueryMembershipsForRoomRequest, - response *QueryMembershipsForRoomResponse, - ) error - - // Query if we think we're still in a room. - QueryServerJoinedToRoom( - ctx context.Context, - request *QueryServerJoinedToRoomRequest, - response *QueryServerJoinedToRoomResponse, - ) error - - // Query whether a server is allowed to see an event - QueryServerAllowedToSeeEvent( - ctx context.Context, - request *QueryServerAllowedToSeeEventRequest, - response *QueryServerAllowedToSeeEventResponse, - ) error - - // Query missing events for a room from roomserver - QueryMissingEvents( - ctx context.Context, - request *QueryMissingEventsRequest, - response *QueryMissingEventsResponse, - ) error - - // Query to get state and auth chain for a (potentially hypothetical) event. - // Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate - // the state and auth chain to return. - QueryStateAndAuthChain( - ctx context.Context, - request *QueryStateAndAuthChainRequest, - response *QueryStateAndAuthChainResponse, - ) error - - // QueryAuthChain returns the entire auth chain for the event IDs given. - // The response includes the events in the request. - // Omits without error for any missing auth events. There will be no duplicates. - QueryAuthChain( - ctx context.Context, - request *QueryAuthChainRequest, - response *QueryAuthChainResponse, - ) error - - // QueryCurrentState retrieves the requested state events. If state events are not found, they will be missing from - // the response. - QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error - // QueryRoomsForUser retrieves a list of room IDs matching the given query. - QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error - // QueryBulkStateContent does a bulk query for state event content in the given rooms. - QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error - // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. - QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error - // QueryKnownUsers returns a list of users that we know about from our joined rooms. - QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error - // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. - QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error - // Query a given amount (or less) of events prior to a given set of events. PerformBackfill( ctx context.Context, - request *PerformBackfillRequest, - response *PerformBackfillResponse, + req *PerformBackfillRequest, + res *PerformBackfillResponse, ) error +} - // PerformForget forgets a rooms history for a specific user - PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error - - // PerformRoomUpgrade upgrades a room to a newer version - PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse) - - // Asks for the default room version as preferred by the server. - QueryRoomVersionCapabilities( +type AppserviceRoomserverAPI interface { + // Query a list of events by event ID. + QueryEventsByID( ctx context.Context, - request *QueryRoomVersionCapabilitiesRequest, - response *QueryRoomVersionCapabilitiesResponse, + req *QueryEventsByIDRequest, + res *QueryEventsByIDResponse, ) error - - // Asks for the room version for a given room. - QueryRoomVersionForRoom( + // Query a list of membership events for a room + QueryMembershipsForRoom( ctx context.Context, - request *QueryRoomVersionForRoomRequest, - response *QueryRoomVersionForRoomResponse, + req *QueryMembershipsForRoomRequest, + res *QueryMembershipsForRoomResponse, ) error - - // Set a room alias - SetRoomAlias( - ctx context.Context, - req *SetRoomAliasRequest, - response *SetRoomAliasResponse, - ) error - - // Get the room ID for an alias - GetRoomIDForAlias( - ctx context.Context, - req *GetRoomIDForAliasRequest, - response *GetRoomIDForAliasResponse, - ) error - // Get all known aliases for a room ID GetAliasesForRoomID( ctx context.Context, req *GetAliasesForRoomIDRequest, - response *GetAliasesForRoomIDResponse, - ) error - - // Get the user ID of the creator of an alias - GetCreatorIDForAlias( - ctx context.Context, - req *GetCreatorIDForAliasRequest, - response *GetCreatorIDForAliasResponse, - ) error - - // Remove a room alias - RemoveRoomAlias( - ctx context.Context, - req *RemoveRoomAliasRequest, - response *RemoveRoomAliasResponse, + res *GetAliasesForRoomIDResponse, ) error } + +type ClientRoomserverAPI interface { + InputRoomEventsAPI + QueryLatestEventsAndStateAPI + QueryBulkStateContentAPI + QueryEventsAPI + QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error + QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error + QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error + QueryStateAfterEvents(ctx context.Context, req *QueryStateAfterEventsRequest, res *QueryStateAfterEventsResponse) error + // QueryKnownUsers returns a list of users that we know about from our joined rooms. + QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error + QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error + QueryPublishedRooms(ctx context.Context, req *QueryPublishedRoomsRequest, res *QueryPublishedRoomsResponse) error + QueryRoomVersionCapabilities(ctx context.Context, req *QueryRoomVersionCapabilitiesRequest, res *QueryRoomVersionCapabilitiesResponse) error + + GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error + GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error + + // PerformRoomUpgrade upgrades a room to a newer version + PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse) + PerformAdminEvacuateRoom( + ctx context.Context, + req *PerformAdminEvacuateRoomRequest, + res *PerformAdminEvacuateRoomResponse, + ) + PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse) + PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse) + PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error + PerformJoin(ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse) + PerformLeave(ctx context.Context, req *PerformLeaveRequest, res *PerformLeaveResponse) error + PerformPublish(ctx context.Context, req *PerformPublishRequest, res *PerformPublishResponse) + // PerformForget forgets a rooms history for a specific user + PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error + SetRoomAlias(ctx context.Context, req *SetRoomAliasRequest, res *SetRoomAliasResponse) error + RemoveRoomAlias(ctx context.Context, req *RemoveRoomAliasRequest, res *RemoveRoomAliasResponse) error +} + +type UserRoomserverAPI interface { + QueryLatestEventsAndStateAPI + QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error + QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error +} + +type FederationRoomserverAPI interface { + InputRoomEventsAPI + QueryLatestEventsAndStateAPI + QueryBulkStateContentAPI + // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. + QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error + QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error + GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error + QueryEventsByID(ctx context.Context, req *QueryEventsByIDRequest, res *QueryEventsByIDResponse) error + // Query to get state and auth chain for a (potentially hypothetical) event. + // Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate + // the state and auth chain to return. + QueryStateAndAuthChain(ctx context.Context, req *QueryStateAndAuthChainRequest, res *QueryStateAndAuthChainResponse) error + // Query if we think we're still in a room. + QueryServerJoinedToRoom(ctx context.Context, req *QueryServerJoinedToRoomRequest, res *QueryServerJoinedToRoomResponse) error + QueryPublishedRooms(ctx context.Context, req *QueryPublishedRoomsRequest, res *QueryPublishedRoomsResponse) error + // Query missing events for a room from roomserver + QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error + // Query whether a server is allowed to see an event + QueryServerAllowedToSeeEvent(ctx context.Context, req *QueryServerAllowedToSeeEventRequest, res *QueryServerAllowedToSeeEventResponse) error + PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error + PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error + // Query a given amount (or less) of events prior to a given set of events. + PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error +} diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 61c06e886..711324644 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -19,15 +19,15 @@ type RoomserverInternalAPITrace struct { Impl RoomserverInternalAPI } -func (t *RoomserverInternalAPITrace) SetFederationAPI(fsAPI fsAPI.FederationInternalAPI, keyRing *gomatrixserverlib.KeyRing) { +func (t *RoomserverInternalAPITrace) SetFederationAPI(fsAPI fsAPI.RoomserverFederationAPI, keyRing *gomatrixserverlib.KeyRing) { t.Impl.SetFederationAPI(fsAPI, keyRing) } -func (t *RoomserverInternalAPITrace) SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI) { +func (t *RoomserverInternalAPITrace) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) { t.Impl.SetAppserviceAPI(asAPI) } -func (t *RoomserverInternalAPITrace) SetUserAPI(userAPI userapi.UserInternalAPI) { +func (t *RoomserverInternalAPITrace) SetUserAPI(userAPI userapi.RoomserverUserAPI) { t.Impl.SetUserAPI(userAPI) } @@ -293,16 +293,6 @@ func (t *RoomserverInternalAPITrace) GetAliasesForRoomID( return err } -func (t *RoomserverInternalAPITrace) GetCreatorIDForAlias( - ctx context.Context, - req *GetCreatorIDForAliasRequest, - res *GetCreatorIDForAliasResponse, -) error { - err := t.Impl.GetCreatorIDForAlias(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("GetCreatorIDForAlias req=%+v res=%+v", js(req), js(res)) - return err -} - func (t *RoomserverInternalAPITrace) RemoveRoomAlias( ctx context.Context, req *RemoveRoomAliasRequest, diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 5491d36b3..344e9b079 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -16,7 +16,6 @@ package api import ( "context" - "fmt" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -24,7 +23,7 @@ import ( // SendEvents to the roomserver The events are written with KindNew. func SendEvents( - ctx context.Context, rsAPI RoomserverInternalAPI, + ctx context.Context, rsAPI InputRoomEventsAPI, kind Kind, events []*gomatrixserverlib.HeaderedEvent, origin gomatrixserverlib.ServerName, sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID, @@ -47,7 +46,7 @@ func SendEvents( // with the state at the event as KindOutlier before it. Will not send any event that is // marked as `true` in haveEventIDs. func SendEventWithState( - ctx context.Context, rsAPI RoomserverInternalAPI, kind Kind, + ctx context.Context, rsAPI InputRoomEventsAPI, kind Kind, state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent, origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool, ) error { @@ -83,7 +82,7 @@ func SendEventWithState( // SendInputRoomEvents to the roomserver. func SendInputRoomEvents( - ctx context.Context, rsAPI RoomserverInternalAPI, + ctx context.Context, rsAPI InputRoomEventsAPI, ires []InputRoomEvent, async bool, ) error { request := InputRoomEventsRequest{ @@ -95,37 +94,8 @@ func SendInputRoomEvents( return response.Err() } -// SendInvite event to the roomserver. -// This should only be needed for invite events that occur outside of a known room. -// If we are in the room then the event should be sent using the SendEvents method. -func SendInvite( - ctx context.Context, - rsAPI RoomserverInternalAPI, inviteEvent *gomatrixserverlib.HeaderedEvent, - inviteRoomState []gomatrixserverlib.InviteV2StrippedState, - sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID, -) error { - // Start by sending the invite request into the roomserver. This will - // trigger the federation request amongst other things if needed. - request := &PerformInviteRequest{ - Event: inviteEvent, - InviteRoomState: inviteRoomState, - RoomVersion: inviteEvent.RoomVersion, - SendAsServer: string(sendAsServer), - TransactionID: txnID, - } - response := &PerformInviteResponse{} - if err := rsAPI.PerformInvite(ctx, request, response); err != nil { - return fmt.Errorf("rsAPI.PerformInvite: %w", err) - } - if response.Error != nil { - return response.Error - } - - return nil -} - // GetEvent returns the event or nil, even on errors. -func GetEvent(ctx context.Context, rsAPI RoomserverInternalAPI, eventID string) *gomatrixserverlib.HeaderedEvent { +func GetEvent(ctx context.Context, rsAPI QueryEventsAPI, eventID string) *gomatrixserverlib.HeaderedEvent { var res QueryEventsByIDResponse err := rsAPI.QueryEventsByID(ctx, &QueryEventsByIDRequest{ EventIDs: []string{eventID}, @@ -141,7 +111,7 @@ func GetEvent(ctx context.Context, rsAPI RoomserverInternalAPI, eventID string) } // GetStateEvent returns the current state event in the room or nil. -func GetStateEvent(ctx context.Context, rsAPI RoomserverInternalAPI, roomID string, tuple gomatrixserverlib.StateKeyTuple) *gomatrixserverlib.HeaderedEvent { +func GetStateEvent(ctx context.Context, rsAPI QueryEventsAPI, roomID string, tuple gomatrixserverlib.StateKeyTuple) *gomatrixserverlib.HeaderedEvent { var res QueryCurrentStateResponse err := rsAPI.QueryCurrentState(ctx, &QueryCurrentStateRequest{ RoomID: roomID, @@ -159,7 +129,7 @@ func GetStateEvent(ctx context.Context, rsAPI RoomserverInternalAPI, roomID stri } // IsServerBannedFromRoom returns whether the server is banned from a room by server ACLs. -func IsServerBannedFromRoom(ctx context.Context, rsAPI RoomserverInternalAPI, roomID string, serverName gomatrixserverlib.ServerName) bool { +func IsServerBannedFromRoom(ctx context.Context, rsAPI FederationRoomserverAPI, roomID string, serverName gomatrixserverlib.ServerName) bool { req := &QueryServerBannedFromRoomRequest{ ServerName: serverName, RoomID: roomID, @@ -175,7 +145,7 @@ func IsServerBannedFromRoom(ctx context.Context, rsAPI RoomserverInternalAPI, ro // PopulatePublicRooms extracts PublicRoom information for all the provided room IDs. The IDs are not checked to see if they are visible in the // published room directory. // due to lots of switches -func PopulatePublicRooms(ctx context.Context, roomIDs []string, rsAPI RoomserverInternalAPI) ([]gomatrixserverlib.PublicRoom, error) { +func PopulatePublicRooms(ctx context.Context, roomIDs []string, rsAPI QueryBulkStateContentAPI) ([]gomatrixserverlib.PublicRoom, error) { avatarTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.avatar", StateKey: ""} nameTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.name", StateKey: ""} canonicalTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias, StateKey: ""} diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index 02fc4a5a7..f47ae47fe 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -41,9 +41,6 @@ type RoomserverInternalAPIDatabase interface { // Look up all aliases referring to a given room ID. // Returns an error if there was a problem talking to the database. GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) - // Get the user ID of the creator of an alias. - // Returns an error if there was a problem talking to the database. - GetCreatorIDForAlias(ctx context.Context, alias string) (string, error) // Remove a given room alias. // Returns an error if there was a problem talking to the database. RemoveRoomAlias(ctx context.Context, alias string) error @@ -134,22 +131,6 @@ func (r *RoomserverInternalAPI) GetAliasesForRoomID( return nil } -// GetCreatorIDForAlias implements alias.RoomserverInternalAPI -func (r *RoomserverInternalAPI) GetCreatorIDForAlias( - ctx context.Context, - request *api.GetCreatorIDForAliasRequest, - response *api.GetCreatorIDForAliasResponse, -) error { - // Look up the aliases in the database for the given RoomID - creatorID, err := r.DB.GetCreatorIDForAlias(ctx, request.Alias) - if err != nil { - return err - } - - response.UserID = creatorID - return nil -} - // RemoveRoomAlias implements alias.RoomserverInternalAPI func (r *RoomserverInternalAPI) RemoveRoomAlias( ctx context.Context, diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 267cd4099..afef52da4 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -43,8 +43,8 @@ type RoomserverInternalAPI struct { ServerName gomatrixserverlib.ServerName KeyRing gomatrixserverlib.JSONVerifier ServerACLs *acls.ServerACLs - fsAPI fsAPI.FederationInternalAPI - asAPI asAPI.AppServiceQueryAPI + fsAPI fsAPI.RoomserverFederationAPI + asAPI asAPI.AppServiceInternalAPI NATSClient *nats.Conn JetStream nats.JetStreamContext Durable string @@ -87,7 +87,7 @@ func NewRoomserverAPI( // SetFederationInputAPI passes in a federation input API reference so that we can // avoid the chicken-and-egg problem of both the roomserver input API and the // federation input API being interdependent. -func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalAPI, keyRing *gomatrixserverlib.KeyRing) { +func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederationAPI, keyRing *gomatrixserverlib.KeyRing) { r.fsAPI = fsAPI r.KeyRing = keyRing @@ -177,11 +177,11 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalA } } -func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.UserInternalAPI) { +func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) { r.Leaver.UserAPI = userAPI } -func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI) { +func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) { r.asAPI = asAPI } diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 1fea6ef06..600994c5a 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -82,7 +82,7 @@ type Inputer struct { JetStream nats.JetStreamContext Durable nats.SubOpt ServerName gomatrixserverlib.ServerName - FSAPI fedapi.FederationInternalAPI + FSAPI fedapi.RoomserverFederationAPI KeyRing gomatrixserverlib.JSONVerifier ACLs *acls.ServerACLs InputRoomEventTopic string @@ -202,7 +202,7 @@ func (w *worker) _next() { return } - case context.DeadlineExceeded: + case context.DeadlineExceeded, context.Canceled: // The context exceeded, so we've been waiting for more than a // minute for activity in this room. At this point we will shut // down the subscriber to free up resources. It'll get started diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index 7e58ef9d0..9ad8b0422 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -316,40 +316,30 @@ func (u *latestEventsUpdater) calculateLatest( // Then let's see if any of the existing forward extremities now // have entries in the previous events table. If they do then we // will no longer include them as forward extremities. - existingPrevs := make(map[string]struct{}) - for _, l := range existingRefs { + for k, l := range existingRefs { referenced, err := u.updater.IsReferenced(l.EventReference) if err != nil { return false, fmt.Errorf("u.updater.IsReferenced: %w", err) } else if referenced { - existingPrevs[l.EventID] = struct{}{} + delete(existingRefs, k) } } - // Include our new event in the extremities. - newLatest := []types.StateAtEventAndReference{newStateAndRef} + // Start off with our new unreferenced event. We're reusing the backing + // array here rather than allocating a new one. + u.latest = append(u.latest[:0], newStateAndRef) - // Then run through and see if the other extremities are still valid. - // If our new event references them then they are no longer good - // candidates. + // If our new event references any of the existing forward extremities + // then they are no longer forward extremities, so remove them. for _, prevEventID := range newEvent.PrevEventIDs() { delete(existingRefs, prevEventID) } - // Ensure that we don't add any candidate forward extremities from - // the old set that are, themselves, referenced by the old set of - // forward extremities. This shouldn't happen but guards against - // the possibility anyway. - for prevEventID := range existingPrevs { - delete(existingRefs, prevEventID) - } - // Then re-add any old extremities that are still valid after all. for _, old := range existingRefs { - newLatest = append(newLatest, *old) + u.latest = append(u.latest, *old) } - u.latest = newLatest return true, nil } diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 2c958335d..9c70076c2 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -44,7 +44,7 @@ type missingStateReq struct { roomInfo *types.RoomInfo inputer *Inputer keys gomatrixserverlib.JSONVerifier - federation fedapi.FederationInternalAPI + federation fedapi.RoomserverFederationAPI roomsMu *internal.MutexByRoom servers []gomatrixserverlib.ServerName hadEvents map[string]bool diff --git a/roomserver/internal/input/input_test.go b/roomserver/internal/input/input_test.go index 81c86ae38..5d34842bf 100644 --- a/roomserver/internal/input/input_test.go +++ b/roomserver/internal/input/input_test.go @@ -53,6 +53,7 @@ func TestSingleTransactionOnInput(t *testing.T) { t.Fatal(err) } db, err := storage.Open( + nil, &config.DatabaseOptions{ ConnectionString: "", MaxOpenConnections: 1, diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 081f694a1..1bc4c75ce 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -38,7 +38,7 @@ const maxBackfillServers = 5 type Backfiller struct { ServerName gomatrixserverlib.ServerName DB storage.Database - FSAPI federationAPI.FederationInternalAPI + FSAPI federationAPI.RoomserverFederationAPI KeyRing gomatrixserverlib.JSONVerifier // The servers which should be preferred above other servers when backfilling @@ -228,7 +228,7 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom // backfillRequester implements gomatrixserverlib.BackfillRequester type backfillRequester struct { db storage.Database - fsAPI federationAPI.FederationInternalAPI + fsAPI federationAPI.RoomserverFederationAPI thisServer gomatrixserverlib.ServerName preferServer map[gomatrixserverlib.ServerName]bool bwExtrems map[string][]string @@ -240,7 +240,7 @@ type backfillRequester struct { } func newBackfillRequester( - db storage.Database, fsAPI federationAPI.FederationInternalAPI, thisServer gomatrixserverlib.ServerName, + db storage.Database, fsAPI federationAPI.RoomserverFederationAPI, thisServer gomatrixserverlib.ServerName, bwExtrems map[string][]string, preferServers []gomatrixserverlib.ServerName, ) *backfillRequester { preferServer := make(map[gomatrixserverlib.ServerName]bool) diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 6111372d8..b0148a314 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -35,7 +35,7 @@ import ( type Inviter struct { DB storage.Database Cfg *config.RoomServer - FSAPI federationAPI.FederationInternalAPI + FSAPI federationAPI.RoomserverFederationAPI Inputer *input.Inputer } diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index a40f66d21..61a0206ef 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -38,7 +38,7 @@ import ( type Joiner struct { ServerName gomatrixserverlib.ServerName Cfg *config.RoomServer - FSAPI fsAPI.FederationInternalAPI + FSAPI fsAPI.RoomserverFederationAPI RSAPI rsAPI.RoomserverInternalAPI DB storage.Database diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 5b4cd3c6f..c5b62ac00 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -37,8 +37,8 @@ import ( type Leaver struct { Cfg *config.RoomServer DB storage.Database - FSAPI fsAPI.FederationInternalAPI - UserAPI userapi.UserInternalAPI + FSAPI fsAPI.RoomserverFederationAPI + UserAPI userapi.RoomserverUserAPI Inputer *input.Inputer } diff --git a/roomserver/internal/perform/perform_peek.go b/roomserver/internal/perform/perform_peek.go index 6a2c329b9..45e63888d 100644 --- a/roomserver/internal/perform/perform_peek.go +++ b/roomserver/internal/perform/perform_peek.go @@ -33,7 +33,7 @@ import ( type Peeker struct { ServerName gomatrixserverlib.ServerName Cfg *config.RoomServer - FSAPI fsAPI.FederationInternalAPI + FSAPI fsAPI.RoomserverFederationAPI DB storage.Database Inputer *input.Inputer diff --git a/roomserver/internal/perform/perform_unpeek.go b/roomserver/internal/perform/perform_unpeek.go index 16b4eeaed..1057499cb 100644 --- a/roomserver/internal/perform/perform_unpeek.go +++ b/roomserver/internal/perform/perform_unpeek.go @@ -30,7 +30,7 @@ import ( type Unpeeker struct { ServerName gomatrixserverlib.ServerName Cfg *config.RoomServer - FSAPI fsAPI.FederationInternalAPI + FSAPI fsAPI.RoomserverFederationAPI DB storage.Database Inputer *input.Inputer diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 3b29001e9..09358001b 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -87,15 +87,15 @@ func NewRoomserverClient( } // SetFederationInputAPI no-ops in HTTP client mode as there is no chicken/egg scenario -func (h *httpRoomserverInternalAPI) SetFederationAPI(fsAPI fsInputAPI.FederationInternalAPI, keyRing *gomatrixserverlib.KeyRing) { +func (h *httpRoomserverInternalAPI) SetFederationAPI(fsAPI fsInputAPI.RoomserverFederationAPI, keyRing *gomatrixserverlib.KeyRing) { } // SetAppserviceAPI no-ops in HTTP client mode as there is no chicken/egg scenario -func (h *httpRoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI) { +func (h *httpRoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) { } // SetUserAPI no-ops in HTTP client mode as there is no chicken/egg scenario -func (h *httpRoomserverInternalAPI) SetUserAPI(userAPI userapi.UserInternalAPI) { +func (h *httpRoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) { } // SetRoomAlias implements RoomserverAliasAPI @@ -137,19 +137,6 @@ func (h *httpRoomserverInternalAPI) GetAliasesForRoomID( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } -// GetCreatorIDForAlias implements RoomserverAliasAPI -func (h *httpRoomserverInternalAPI) GetCreatorIDForAlias( - ctx context.Context, - request *api.GetCreatorIDForAliasRequest, - response *api.GetCreatorIDForAliasResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "GetCreatorIDForAlias") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverGetCreatorIDForAliasPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - // RemoveRoomAlias implements RoomserverAliasAPI func (h *httpRoomserverInternalAPI) RemoveRoomAlias( ctx context.Context, diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index c5159a63c..9042e341b 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -353,20 +353,6 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - internalAPIMux.Handle( - RoomserverGetCreatorIDForAliasPath, - httputil.MakeInternalAPI("GetCreatorIDForAlias", func(req *http.Request) util.JSONResponse { - var request api.GetCreatorIDForAliasRequest - var response api.GetCreatorIDForAliasResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.GetCreatorIDForAlias(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) internalAPIMux.Handle( RoomserverGetAliasesForRoomIDPath, httputil.MakeInternalAPI("getAliasesForRoomID", func(req *http.Request) util.JSONResponse { diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go index 36e3c5269..46261eb3e 100644 --- a/roomserver/roomserver.go +++ b/roomserver/roomserver.go @@ -45,7 +45,7 @@ func NewInternalAPI( perspectiveServerNames = append(perspectiveServerNames, kp.ServerName) } - roomserverDB, err := storage.Open(&cfg.Database, base.Caches) + roomserverDB, err := storage.Open(base, &cfg.Database, base.Caches) if err != nil { logrus.WithError(err).Panicf("failed to connect to room server db") } diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index b5e05c982..da8d25848 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -26,6 +26,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas" "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) @@ -35,11 +36,11 @@ type Database struct { } // Open a postgres database. -func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { +func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { var d Database - var db *sql.DB var err error - if db, err = sqlutil.Open(dbProperties); err != nil { + db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) + if err != nil { return nil, fmt.Errorf("sqlutil.Open: %w", err) } @@ -59,7 +60,7 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) // Then prepare the statements. Now that the migrations have run, any columns referred // to in the database code should now exist. - if err := d.prepare(db, cache); err != nil { + if err := d.prepare(db, writer, cache); err != nil { return nil, err } @@ -110,7 +111,7 @@ func (d *Database) create(db *sql.DB) error { return nil } -func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error { +func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.RoomServerCaches) error { eventStateKeys, err := prepareEventStateKeysTable(db) if err != nil { return err @@ -166,7 +167,7 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error { d.Database = shared.Database{ DB: db, Cache: cache, - Writer: sqlutil.NewDummyWriter(), + Writer: writer, EventTypesTable: eventTypes, EventStateKeysTable: eventStateKeys, EventJSONTable: eventJSON, diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 325c253b5..e6cf1a53f 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -18,12 +18,14 @@ package sqlite3 import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) @@ -34,12 +36,12 @@ type Database struct { } // Open a sqlite database. -func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { +func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { var d Database - var db *sql.DB var err error - if db, err = sqlutil.Open(dbProperties); err != nil { - return nil, err + db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) + if err != nil { + return nil, fmt.Errorf("sqlutil.Open: %w", err) } //db.Exec("PRAGMA journal_mode=WAL;") @@ -49,7 +51,7 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) // cause the roomserver to be unresponsive to new events because something will // acquire the global mutex and never unlock it because it is waiting for a connection // which it will never obtain. - db.SetMaxOpenConns(20) + // db.SetMaxOpenConns(20) // Create the tables. if err := d.create(db); err != nil { @@ -67,7 +69,7 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) // Then prepare the statements. Now that the migrations have run, any columns referred // to in the database code should now exist. - if err := d.prepare(db, cache); err != nil { + if err := d.prepare(db, writer, cache); err != nil { return nil, err } @@ -118,7 +120,7 @@ func (d *Database) create(db *sql.DB) error { return nil } -func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error { +func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.RoomServerCaches) error { eventStateKeys, err := prepareEventStateKeysTable(db) if err != nil { return err @@ -174,7 +176,7 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error { d.Database = shared.Database{ DB: db, Cache: cache, - Writer: sqlutil.NewExclusiveWriter(), + Writer: writer, EventsTable: events, EventTypesTable: eventTypes, EventStateKeysTable: eventStateKeys, diff --git a/roomserver/storage/storage.go b/roomserver/storage/storage.go index 9f98ea3ed..8a87b7d7c 100644 --- a/roomserver/storage/storage.go +++ b/roomserver/storage/storage.go @@ -23,16 +23,17 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/roomserver/storage/postgres" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) // Open opens a database connection. -func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (Database, error) { +func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.Open(dbProperties, cache) + return sqlite3.Open(base, dbProperties, cache) case dbProperties.ConnectionString.IsPostgres(): - return postgres.Open(dbProperties, cache) + return postgres.Open(base, dbProperties, cache) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/roomserver/storage/storage_wasm.go b/roomserver/storage/storage_wasm.go index dfc374e6e..df5a56ac3 100644 --- a/roomserver/storage/storage_wasm.go +++ b/roomserver/storage/storage_wasm.go @@ -19,14 +19,15 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) // NewPublicRoomsServerDatabase opens a database connection. -func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (Database, error) { +func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.Open(dbProperties, cache) + return sqlite3.Open(base, dbProperties, cache) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/setup/base/base.go b/setup/base/base.go index 4b771aa36..ef449cc35 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -17,6 +17,7 @@ package base import ( "context" "crypto/tls" + "database/sql" "fmt" "io" "net" @@ -32,6 +33,7 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" "github.com/prometheus/client_golang/prometheus/promhttp" "go.uber.org/atomic" @@ -40,7 +42,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/setup/process" - userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/gorilla/mux" "github.com/kardianos/minwinsvc" @@ -81,6 +82,8 @@ type BaseDendrite struct { Cfg *config.Dendrite Caches *caching.Caches DNSCache *gomatrixserverlib.DNSCache + Database *sql.DB + DatabaseWriter sqlutil.Writer } const NoListener = "" @@ -93,6 +96,7 @@ type BaseDendriteOptions int const ( NoCacheMetrics BaseDendriteOptions = iota UseHTTPAPIs + PolylithMode ) // NewBaseDendrite creates a new instance to be used by a component. @@ -102,17 +106,21 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base platformSanityChecks() useHTTPAPIs := false cacheMetrics := true + isMonolith := true for _, opt := range options { switch opt { case NoCacheMetrics: cacheMetrics = false case UseHTTPAPIs: useHTTPAPIs = true + case PolylithMode: + isMonolith = false + useHTTPAPIs = true } } configErrors := &config.ConfigErrors{} - cfg.Verify(configErrors, componentName == "Monolith") // TODO: better way? + cfg.Verify(configErrors, isMonolith) if len(*configErrors) > 0 { for _, err := range *configErrors { logrus.Errorf("Configuration error: %s", err) @@ -185,6 +193,25 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base }, } + // If we're in monolith mode, we'll set up a global pool of database + // connections. A component is welcome to use this pool if they don't + // have a separate database config of their own. + var db *sql.DB + var writer sqlutil.Writer + if cfg.Global.DatabaseOptions.ConnectionString != "" { + if !isMonolith { + logrus.Panic("Using a global database connection pool is not supported in polylith deployments") + } + if cfg.Global.DatabaseOptions.ConnectionString.IsSQLite() { + logrus.Panic("Using a global database connection pool is not supported with SQLite databases") + } + writer = sqlutil.NewDummyWriter() + if db, err = sqlutil.Open(&cfg.Global.DatabaseOptions, writer); err != nil { + logrus.WithError(err).Panic("Failed to set up global database connections") + } + logrus.Debug("Using global database connection pool") + } + // Ideally we would only use SkipClean on routes which we know can allow '/' but due to // https://github.com/gorilla/mux/issues/460 we have to attach this at the top router. // When used in conjunction with UseEncodedPath() we get the behaviour we want when parsing @@ -214,6 +241,8 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base DendriteAdminMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.DendriteAdminPathPrefix).Subrouter().UseEncodedPath(), SynapseAdminMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.SynapseAdminPathPrefix).Subrouter().UseEncodedPath(), apiHttpClient: &apiClient, + Database: db, // set if monolith with global connection pool only + DatabaseWriter: writer, // set if monolith with global connection pool only } } @@ -222,8 +251,31 @@ func (b *BaseDendrite) Close() error { return b.tracerCloser.Close() } -// AppserviceHTTPClient returns the AppServiceQueryAPI for hitting the appservice component over HTTP. -func (b *BaseDendrite) AppserviceHTTPClient() appserviceAPI.AppServiceQueryAPI { +// DatabaseConnection assists in setting up a database connection. It accepts +// the database properties and a new writer for the given component. If we're +// running in monolith mode with a global connection pool configured then we +// will return that connection, along with the global writer, effectively +// ignoring the options provided. Otherwise we'll open a new database connection +// using the supplied options and writer. Note that it's possible for the pointer +// receiver to be nil here – that's deliberate as some of the unit tests don't +// have a BaseDendrite and just want a connection with the supplied config +// without any pooling stuff. +func (b *BaseDendrite) DatabaseConnection(dbProperties *config.DatabaseOptions, writer sqlutil.Writer) (*sql.DB, sqlutil.Writer, error) { + if dbProperties.ConnectionString != "" || b == nil { + // Open a new database connection using the supplied config. + db, err := sqlutil.Open(dbProperties, writer) + return db, writer, err + } + if b.Database != nil && b.DatabaseWriter != nil { + // Ignore the supplied config and return the global pool and + // writer. + return b.Database, b.DatabaseWriter, nil + } + return nil, nil, fmt.Errorf("no database connections configured") +} + +// AppserviceHTTPClient returns the AppServiceInternalAPI for hitting the appservice component over HTTP. +func (b *BaseDendrite) AppserviceHTTPClient() appserviceAPI.AppServiceInternalAPI { a, err := asinthttp.NewAppserviceClient(b.Cfg.AppServiceURL(), b.apiHttpClient) if err != nil { logrus.WithError(err).Panic("CreateHTTPAppServiceAPIs failed") @@ -273,24 +325,6 @@ func (b *BaseDendrite) PushGatewayHTTPClient() pushgateway.Client { return pushgateway.NewHTTPClient(b.Cfg.UserAPI.PushGatewayDisableTLSValidation) } -// CreateAccountsDB creates a new instance of the accounts database. Should only -// be called once per component. -func (b *BaseDendrite) CreateAccountsDB() userdb.Database { - db, err := userdb.NewUserAPIDatabase( - &b.Cfg.UserAPI.AccountDatabase, - b.Cfg.Global.ServerName, - b.Cfg.UserAPI.BCryptCost, - b.Cfg.UserAPI.OpenIDTokenLifetimeMS, - userapi.DefaultLoginTokenLifetime, - b.Cfg.Global.ServerNotices.LocalPart, - ) - if err != nil { - logrus.WithError(err).Panicf("failed to connect to accounts db") - } - - return db -} - // CreateClient creates a new client (normally used for media fetch requests). // Should only be called once per component. func (b *BaseDendrite) CreateClient() *gomatrixserverlib.Client { @@ -301,6 +335,7 @@ func (b *BaseDendrite) CreateClient() *gomatrixserverlib.Client { } opts := []gomatrixserverlib.ClientOption{ gomatrixserverlib.WithSkipVerify(b.Cfg.FederationAPI.DisableTLSValidation), + gomatrixserverlib.WithWellKnownSRVLookups(true), } if b.Cfg.Global.DNSCache.Enabled { opts = append(opts, gomatrixserverlib.WithDNSCache(b.DNSCache)) diff --git a/setup/config/config.go b/setup/config/config.go index e03518e24..9b9000a62 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -78,6 +78,8 @@ type Dendrite struct { // Any information derived from the configuration options for later use. Derived Derived `yaml:"-"` + + IsMonolith bool `yaml:"-"` } // TODO: Kill Derived @@ -210,6 +212,7 @@ func loadConfig( ) (*Dendrite, error) { var c Dendrite c.Defaults(false) + c.IsMonolith = monolithic var err error if err = yaml.Unmarshal(configData, &c); err != nil { diff --git a/setup/config/config_appservice.go b/setup/config/config_appservice.go index 3f4e1c917..d93b6ebe0 100644 --- a/setup/config/config_appservice.go +++ b/setup/config/config_appservice.go @@ -52,7 +52,9 @@ func (c *AppServiceAPI) Defaults(generate bool) { func (c *AppServiceAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { checkURL(configErrs, "app_service_api.internal_api.listen", string(c.InternalAPI.Listen)) checkURL(configErrs, "app_service_api.internal_api.bind", string(c.InternalAPI.Connect)) - checkNotEmpty(configErrs, "app_service_api.database.connection_string", string(c.Database.ConnectionString)) + if c.Matrix.DatabaseOptions.ConnectionString == "" { + checkNotEmpty(configErrs, "app_service_api.database.connection_string", string(c.Database.ConnectionString)) + } } // ApplicationServiceNamespace is the namespace that a specific application diff --git a/setup/config/config_federationapi.go b/setup/config/config_federationapi.go index 176334dd8..f62a23e1f 100644 --- a/setup/config/config_federationapi.go +++ b/setup/config/config_federationapi.go @@ -49,7 +49,9 @@ func (c *FederationAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { if !isMonolith { checkURL(configErrs, "federation_api.external_api.listen", string(c.ExternalAPI.Listen)) } - checkNotEmpty(configErrs, "federation_api.database.connection_string", string(c.Database.ConnectionString)) + if c.Matrix.DatabaseOptions.ConnectionString == "" { + checkNotEmpty(configErrs, "federation_api.database.connection_string", string(c.Database.ConnectionString)) + } } // The config for setting a proxy to use for server->server requests diff --git a/setup/config/config_global.go b/setup/config/config_global.go index c1650f077..9d4c1485e 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -34,6 +34,13 @@ type Global struct { // Defaults to 24 hours. KeyValidityPeriod time.Duration `yaml:"key_validity_period"` + // Global pool of database connections, which is used only in monolith mode. If a + // component does not specify any database options of its own, then this pool of + // connections will be used instead. This way we don't have to manage connection + // counts on a per-component basis, but can instead do it for the entire monolith. + // In a polylith deployment, this will be ignored. + DatabaseOptions DatabaseOptions `yaml:"database"` + // The server name to delegate server-server communications to, with optional port WellKnownServerName string `yaml:"well_known_server_name"` @@ -63,6 +70,9 @@ type Global struct { // ServerNotices configuration used for sending server notices ServerNotices ServerNotices `yaml:"server_notices"` + + // ReportStats configures opt-in anonymous stats reporting. + ReportStats ReportStats `yaml:"report_stats"` } func (c *Global) Defaults(generate bool) { @@ -79,6 +89,7 @@ func (c *Global) Defaults(generate bool) { c.DNSCache.Defaults() c.Sentry.Defaults() c.ServerNotices.Defaults(generate) + c.ReportStats.Defaults() } func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) { @@ -90,6 +101,7 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) { c.Sentry.Verify(configErrs, isMonolith) c.DNSCache.Verify(configErrs, isMonolith) c.ServerNotices.Verify(configErrs, isMonolith) + c.ReportStats.Verify(configErrs, isMonolith) } type OldVerifyKeys struct { @@ -156,6 +168,26 @@ func (c *ServerNotices) Defaults(generate bool) { func (c *ServerNotices) Verify(errors *ConfigErrors, isMonolith bool) {} +// ReportStats configures opt-in anonymous stats reporting. +type ReportStats struct { + // Enabled configures anonymous usage stats of the server + Enabled bool `yaml:"enabled"` + + // Endpoint the endpoint to report stats to + Endpoint string `yaml:"endpoint"` +} + +func (c *ReportStats) Defaults() { + c.Enabled = false + c.Endpoint = "https://matrix.org/report-usage-stats/push" +} + +func (c *ReportStats) Verify(configErrs *ConfigErrors, isMonolith bool) { + if c.Enabled { + checkNotEmpty(configErrs, "global.report_stats.endpoint", c.Endpoint) + } +} + // The configuration to use for Sentry error reporting type Sentry struct { Enabled bool `yaml:"enabled"` diff --git a/setup/config/config_keyserver.go b/setup/config/config_keyserver.go index 6180ccbc8..9e2d54cdc 100644 --- a/setup/config/config_keyserver.go +++ b/setup/config/config_keyserver.go @@ -20,5 +20,7 @@ func (c *KeyServer) Defaults(generate bool) { func (c *KeyServer) Verify(configErrs *ConfigErrors, isMonolith bool) { checkURL(configErrs, "key_server.internal_api.listen", string(c.InternalAPI.Listen)) checkURL(configErrs, "key_server.internal_api.bind", string(c.InternalAPI.Connect)) - checkNotEmpty(configErrs, "key_server.database.connection_string", string(c.Database.ConnectionString)) + if c.Matrix.DatabaseOptions.ConnectionString == "" { + checkNotEmpty(configErrs, "key_server.database.connection_string", string(c.Database.ConnectionString)) + } } diff --git a/setup/config/config_mediaapi.go b/setup/config/config_mediaapi.go index 9a7d84969..273de322a 100644 --- a/setup/config/config_mediaapi.go +++ b/setup/config/config_mediaapi.go @@ -23,7 +23,7 @@ type MediaAPI struct { // The maximum file size in bytes that is allowed to be stored on this server. // Note: if max_file_size_bytes is set to 0, the size is unlimited. // Note: if max_file_size_bytes is not set, it will default to 10485760 (10MB) - MaxFileSizeBytes *FileSizeBytes `yaml:"max_file_size_bytes,omitempty"` + MaxFileSizeBytes FileSizeBytes `yaml:"max_file_size_bytes,omitempty"` // Whether to dynamically generate thumbnails on-the-fly if the requested resolution is not already generated DynamicThumbnails bool `yaml:"dynamic_thumbnails"` @@ -48,7 +48,7 @@ func (c *MediaAPI) Defaults(generate bool) { c.BasePath = "./media_store" } - c.MaxFileSizeBytes = &DefaultMaxFileSizeBytes + c.MaxFileSizeBytes = DefaultMaxFileSizeBytes c.MaxThumbnailGenerators = 10 } @@ -58,10 +58,12 @@ func (c *MediaAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { if !isMonolith { checkURL(configErrs, "media_api.external_api.listen", string(c.ExternalAPI.Listen)) } - checkNotEmpty(configErrs, "media_api.database.connection_string", string(c.Database.ConnectionString)) + if c.Matrix.DatabaseOptions.ConnectionString == "" { + checkNotEmpty(configErrs, "media_api.database.connection_string", string(c.Database.ConnectionString)) + } checkNotEmpty(configErrs, "media_api.base_path", string(c.BasePath)) - checkPositive(configErrs, "media_api.max_file_size_bytes", int64(*c.MaxFileSizeBytes)) + checkPositive(configErrs, "media_api.max_file_size_bytes", int64(c.MaxFileSizeBytes)) checkPositive(configErrs, "media_api.max_thumbnail_generators", int64(c.MaxThumbnailGenerators)) for i, size := range c.ThumbnailSizes { diff --git a/setup/config/config_mscs.go b/setup/config/config_mscs.go index 66a4c80c9..b992f7152 100644 --- a/setup/config/config_mscs.go +++ b/setup/config/config_mscs.go @@ -31,5 +31,7 @@ func (c *MSCs) Enabled(msc string) bool { } func (c *MSCs) Verify(configErrs *ConfigErrors, isMonolith bool) { - checkNotEmpty(configErrs, "mscs.database.connection_string", string(c.Database.ConnectionString)) + if c.Matrix.DatabaseOptions.ConnectionString == "" { + checkNotEmpty(configErrs, "mscs.database.connection_string", string(c.Database.ConnectionString)) + } } diff --git a/setup/config/config_roomserver.go b/setup/config/config_roomserver.go index 73abb4f47..8a3227349 100644 --- a/setup/config/config_roomserver.go +++ b/setup/config/config_roomserver.go @@ -20,5 +20,7 @@ func (c *RoomServer) Defaults(generate bool) { func (c *RoomServer) Verify(configErrs *ConfigErrors, isMonolith bool) { checkURL(configErrs, "room_server.internal_api.listen", string(c.InternalAPI.Listen)) checkURL(configErrs, "room_server.internal_ap.bind", string(c.InternalAPI.Connect)) - checkNotEmpty(configErrs, "room_server.database.connection_string", string(c.Database.ConnectionString)) + if c.Matrix.DatabaseOptions.ConnectionString == "" { + checkNotEmpty(configErrs, "room_server.database.connection_string", string(c.Database.ConnectionString)) + } } diff --git a/setup/config/config_syncapi.go b/setup/config/config_syncapi.go index dc813cb7d..48fd9f506 100644 --- a/setup/config/config_syncapi.go +++ b/setup/config/config_syncapi.go @@ -27,5 +27,7 @@ func (c *SyncAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { if !isMonolith { checkURL(configErrs, "sync_api.external_api.listen", string(c.ExternalAPI.Listen)) } - checkNotEmpty(configErrs, "sync_api.database", string(c.Database.ConnectionString)) + if c.Matrix.DatabaseOptions.ConnectionString == "" { + checkNotEmpty(configErrs, "sync_api.database", string(c.Database.ConnectionString)) + } } diff --git a/setup/config/config_userapi.go b/setup/config/config_userapi.go index 570dc6030..4aa3b57bb 100644 --- a/setup/config/config_userapi.go +++ b/setup/config/config_userapi.go @@ -37,6 +37,8 @@ func (c *UserAPI) Defaults(generate bool) { func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { checkURL(configErrs, "user_api.internal_api.listen", string(c.InternalAPI.Listen)) checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect)) - checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString)) + if c.Matrix.DatabaseOptions.ConnectionString == "" { + checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString)) + } checkPositive(configErrs, "user_api.openid_token_lifetime_ms", c.OpenIDTokenLifetimeMS) } diff --git a/setup/monolith.go b/setup/monolith.go index c86ec7b69..41a897024 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -15,7 +15,6 @@ package setup import ( - "github.com/gorilla/mux" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi" "github.com/matrix-org/dendrite/clientapi/api" @@ -25,11 +24,10 @@ import ( keyAPI "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/mediaapi" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi" userapi "github.com/matrix-org/dendrite/userapi/api" - userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" ) @@ -37,12 +35,11 @@ import ( // all components of Dendrite, for use in monolith mode. type Monolith struct { Config *config.Dendrite - AccountDB userdb.Database KeyRing *gomatrixserverlib.KeyRing Client *gomatrixserverlib.Client FedClient *gomatrixserverlib.FederationClient - AppserviceAPI appserviceAPI.AppServiceQueryAPI + AppserviceAPI appserviceAPI.AppServiceInternalAPI FederationAPI federationAPI.FederationInternalAPI RoomserverAPI roomserverAPI.RoomserverInternalAPI UserAPI userapi.UserInternalAPI @@ -50,30 +47,28 @@ type Monolith struct { // Optional ExtPublicRoomsProvider api.ExtraPublicRoomsProvider - ExtUserDirectoryProvider userapi.UserDirectoryProvider + ExtUserDirectoryProvider userapi.QuerySearchProfilesAPI } // AddAllPublicRoutes attaches all public paths to the given router -func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ssMux, keyMux, wkMux, mediaMux, synapseMux, dendriteMux *mux.Router) { +func (m *Monolith) AddAllPublicRoutes(base *base.BaseDendrite) { userDirectoryProvider := m.ExtUserDirectoryProvider if userDirectoryProvider == nil { userDirectoryProvider = m.UserAPI } clientapi.AddPublicRoutes( - process, csMux, synapseMux, dendriteMux, &m.Config.ClientAPI, - m.FedClient, m.RoomserverAPI, - m.AppserviceAPI, transactions.New(), + base, m.FedClient, m.RoomserverAPI, m.AppserviceAPI, transactions.New(), m.FederationAPI, m.UserAPI, userDirectoryProvider, m.KeyAPI, - m.ExtPublicRoomsProvider, &m.Config.MSCs, + m.ExtPublicRoomsProvider, ) federationapi.AddPublicRoutes( - process, ssMux, keyMux, wkMux, &m.Config.FederationAPI, m.UserAPI, m.FedClient, - m.KeyRing, m.RoomserverAPI, m.FederationAPI, - m.KeyAPI, &m.Config.MSCs, nil, + base, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, + m.KeyAPI, nil, + ) + mediaapi.AddPublicRoutes( + base, m.UserAPI, m.Client, ) - mediaapi.AddPublicRoutes(mediaMux, &m.Config.MediaAPI, &m.Config.ClientAPI.RateLimiting, m.UserAPI, m.Client) syncapi.AddPublicRoutes( - process, csMux, m.UserAPI, m.RoomserverAPI, - m.KeyAPI, m.FedClient, &m.Config.SyncAPI, + base, m.UserAPI, m.RoomserverAPI, m.KeyAPI, ) } diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 29c781a88..452b14580 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -102,7 +102,7 @@ func Enable( base *base.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, userAPI userapi.UserInternalAPI, keyRing gomatrixserverlib.JSONVerifier, ) error { - db, err := NewDatabase(&base.Cfg.MSCs.Database) + db, err := NewDatabase(base, &base.Cfg.MSCs.Database) if err != nil { return fmt.Errorf("cannot enable MSC2836: %w", err) } diff --git a/setup/mscs/msc2836/storage.go b/setup/mscs/msc2836/storage.go index 72523916b..827e82f70 100644 --- a/setup/mscs/msc2836/storage.go +++ b/setup/mscs/msc2836/storage.go @@ -8,6 +8,7 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -58,19 +59,17 @@ type DB struct { } // NewDatabase loads the database for msc2836 -func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) { +func NewDatabase(base *base.BaseDendrite, dbOpts *config.DatabaseOptions) (Database, error) { if dbOpts.ConnectionString.IsPostgres() { - return newPostgresDatabase(dbOpts) + return newPostgresDatabase(base, dbOpts) } - return newSQLiteDatabase(dbOpts) + return newSQLiteDatabase(base, dbOpts) } -func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) { - d := DB{ - writer: sqlutil.NewDummyWriter(), - } +func newPostgresDatabase(base *base.BaseDendrite, dbOpts *config.DatabaseOptions) (Database, error) { + d := DB{} var err error - if d.db, err = sqlutil.Open(dbOpts); err != nil { + if d.db, d.writer, err = base.DatabaseConnection(dbOpts, sqlutil.NewDummyWriter()); err != nil { return nil, err } _, err = d.db.Exec(` @@ -145,12 +144,10 @@ func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) { return &d, err } -func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) { - d := DB{ - writer: sqlutil.NewExclusiveWriter(), - } +func newSQLiteDatabase(base *base.BaseDendrite, dbOpts *config.DatabaseOptions) (Database, error) { + d := DB{} var err error - if d.db, err = sqlutil.Open(dbOpts); err != nil { + if d.db, d.writer, err = base.DatabaseConnection(dbOpts, sqlutil.NewExclusiveWriter()); err != nil { return nil, err } _, err = d.db.Exec(` diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index e806f76e6..c8d88ddac 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -42,8 +42,7 @@ type OutputKeyChangeEventConsumer struct { notifier *notifier.Notifier stream types.StreamProvider serverName gomatrixserverlib.ServerName // our server name - rsAPI roomserverAPI.RoomserverInternalAPI - keyAPI api.KeyInternalAPI + rsAPI roomserverAPI.SyncRoomserverAPI } // NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer. @@ -53,8 +52,7 @@ func NewOutputKeyChangeEventConsumer( cfg *config.SyncAPI, topic string, js nats.JetStreamContext, - keyAPI api.KeyInternalAPI, - rsAPI roomserverAPI.RoomserverInternalAPI, + rsAPI roomserverAPI.SyncRoomserverAPI, store storage.Database, notifier *notifier.Notifier, stream types.StreamProvider, @@ -66,7 +64,6 @@ func NewOutputKeyChangeEventConsumer( topic: topic, db: store, serverName: cfg.Matrix.ServerName, - keyAPI: keyAPI, rsAPI: rsAPI, notifier: notifier, stream: stream, diff --git a/syncapi/consumers/presence.go b/syncapi/consumers/presence.go index 6bcca48f4..388c08ff4 100644 --- a/syncapi/consumers/presence.go +++ b/syncapi/consumers/presence.go @@ -41,7 +41,7 @@ type PresenceConsumer struct { db storage.Database stream types.StreamProvider notifier *notifier.Notifier - deviceAPI api.UserDeviceAPI + deviceAPI api.SyncUserAPI cfg *config.SyncAPI } @@ -55,7 +55,7 @@ func NewPresenceConsumer( db storage.Database, notifier *notifier.Notifier, stream types.StreamProvider, - deviceAPI api.UserDeviceAPI, + deviceAPI api.SyncUserAPI, ) *PresenceConsumer { return &PresenceConsumer{ ctx: process.Context(), diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 5bdc0fad7..7712c8403 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -38,7 +38,7 @@ import ( type OutputRoomEventConsumer struct { ctx context.Context cfg *config.SyncAPI - rsAPI api.RoomserverInternalAPI + rsAPI api.SyncRoomserverAPI jetstream nats.JetStreamContext durable string topic string @@ -58,7 +58,7 @@ func NewOutputRoomEventConsumer( notifier *notifier.Notifier, pduStream types.StreamProvider, inviteStream types.StreamProvider, - rsAPI api.RoomserverInternalAPI, + rsAPI api.SyncRoomserverAPI, producer *producers.UserAPIStreamEventProducer, ) *OutputRoomEventConsumer { return &OutputRoomEventConsumer{ diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index dc4acd8da..d96718d20 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -29,7 +29,7 @@ import ( const DeviceListLogName = "dl" // DeviceOTKCounts adds one-time key counts to the /sync response -func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.KeyInternalAPI, userID, deviceID string, res *types.Response) error { +func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, deviceID string, res *types.Response) error { var queryRes keyapi.QueryOneTimeKeysResponse keyAPI.QueryOneTimeKeys(ctx, &keyapi.QueryOneTimeKeysRequest{ UserID: userID, @@ -46,7 +46,7 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.KeyInternalAPI, userID, // was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST // be already filled in with join/leave information. func DeviceListCatchup( - ctx context.Context, keyAPI keyapi.KeyInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI, + ctx context.Context, keyAPI keyapi.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI, userID string, res *types.Response, from, to types.StreamPosition, ) (newPos types.StreamPosition, hasNew bool, err error) { @@ -130,7 +130,7 @@ func DeviceListCatchup( // TrackChangedUsers calculates the values of device_lists.changed|left in the /sync response. func TrackChangedUsers( - ctx context.Context, rsAPI roomserverAPI.RoomserverInternalAPI, userID string, newlyJoinedRooms, newlyLeftRooms []string, + ctx context.Context, rsAPI roomserverAPI.SyncRoomserverAPI, userID string, newlyJoinedRooms, newlyLeftRooms []string, ) (changed, left []string, err error) { // process leaves first, then joins afterwards so if we join/leave/join/leave we err on the side of including users. @@ -216,7 +216,7 @@ func TrackChangedUsers( } func filterSharedUsers( - ctx context.Context, rsAPI roomserverAPI.RoomserverInternalAPI, userID string, usersWithChangedKeys []string, + ctx context.Context, rsAPI roomserverAPI.SyncRoomserverAPI, userID string, usersWithChangedKeys []string, ) (map[string]int, []string) { var result []string var sharedUsersRes roomserverAPI.QuerySharedUsersResponse diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 17215b669..87cc2aae0 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -42,10 +42,10 @@ type ContextRespsonse struct { func Context( req *http.Request, device *userapi.Device, - rsAPI roomserver.RoomserverInternalAPI, + rsAPI roomserver.SyncRoomserverAPI, syncDB storage.Database, roomID, eventID string, - lazyLoadCache *caching.LazyLoadCache, + lazyLoadCache caching.LazyLoadCache, ) util.JSONResponse { filter, err := parseRoomEventFilter(req) if err != nil { @@ -155,7 +155,7 @@ func applyLazyLoadMembers( filter *gomatrixserverlib.RoomEventFilter, eventsAfter, eventsBefore []gomatrixserverlib.ClientEvent, state []*gomatrixserverlib.HeaderedEvent, - lazyLoadCache *caching.LazyLoadCache, + lazyLoadCache caching.LazyLoadCache, ) []*gomatrixserverlib.HeaderedEvent { if filter == nil || !filter.LazyLoadMembers { return state diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index f34901bf2..b0c990ec0 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -36,8 +36,7 @@ import ( type messagesReq struct { ctx context.Context db storage.Database - rsAPI api.RoomserverInternalAPI - federation *gomatrixserverlib.FederationClient + rsAPI api.SyncRoomserverAPI cfg *config.SyncAPI roomID string from *types.TopologyToken @@ -61,11 +60,10 @@ type messagesResp struct { // See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages func OnIncomingMessagesRequest( req *http.Request, db storage.Database, roomID string, device *userapi.Device, - federation *gomatrixserverlib.FederationClient, - rsAPI api.RoomserverInternalAPI, + rsAPI api.SyncRoomserverAPI, cfg *config.SyncAPI, srp *sync.RequestPool, - lazyLoadCache *caching.LazyLoadCache, + lazyLoadCache caching.LazyLoadCache, ) util.JSONResponse { var err error @@ -180,7 +178,6 @@ func OnIncomingMessagesRequest( ctx: req.Context(), db: db, rsAPI: rsAPI, - federation: federation, cfg: cfg, roomID: roomID, from: &from, @@ -247,7 +244,7 @@ func OnIncomingMessagesRequest( } } -func checkIsRoomForgotten(ctx context.Context, roomID, userID string, rsAPI api.RoomserverInternalAPI) (bool, error) { +func checkIsRoomForgotten(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (bool, error) { req := api.QueryMembershipForUserRequest{ RoomID: roomID, UserID: userID, diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 4102cf073..6bc495d8d 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -36,10 +36,10 @@ import ( // nolint: gocyclo func Setup( csMux *mux.Router, srp *sync.RequestPool, syncDB storage.Database, - userAPI userapi.UserInternalAPI, federation *gomatrixserverlib.FederationClient, - rsAPI api.RoomserverInternalAPI, + userAPI userapi.SyncUserAPI, + rsAPI api.SyncRoomserverAPI, cfg *config.SyncAPI, - lazyLoadCache *caching.LazyLoadCache, + lazyLoadCache caching.LazyLoadCache, ) { v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter() @@ -53,7 +53,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], device, federation, rsAPI, cfg, srp, lazyLoadCache) + return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], device, rsAPI, cfg, srp, lazyLoadCache) })).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/user/{userId}/filter", diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index b0382512a..9cfe7c070 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -21,6 +21,7 @@ import ( // Import the postgres database driver. _ "github.com/lib/pq" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/syncapi/storage/shared" @@ -35,13 +36,12 @@ type SyncServerDatasource struct { } // NewDatabase creates a new sync server database -func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) { +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) { var d SyncServerDatasource var err error - if d.db, err = sqlutil.Open(dbProperties); err != nil { + if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { return nil, err } - d.writer = sqlutil.NewDummyWriter() accountData, err := NewPostgresAccountDataTable(d.db) if err != nil { return nil, err diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index dfc289482..e08a0ba82 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -19,6 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage/shared" "github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas" @@ -35,13 +36,12 @@ type SyncServerDatasource struct { // NewDatabase creates a new sync server database // nolint: gocyclo -func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) { +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) { var d SyncServerDatasource var err error - if d.db, err = sqlutil.Open(dbProperties); err != nil { + if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { return nil, err } - d.writer = sqlutil.NewExclusiveWriter() if err = d.prepare(dbProperties); err != nil { return nil, err } diff --git a/syncapi/storage/storage.go b/syncapi/storage/storage.go index 7f9c28e9d..5b20c6cc2 100644 --- a/syncapi/storage/storage.go +++ b/syncapi/storage/storage.go @@ -20,18 +20,19 @@ package storage import ( "fmt" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage/postgres" "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" ) // NewSyncServerDatasource opens a database connection. -func NewSyncServerDatasource(dbProperties *config.DatabaseOptions) (Database, error) { +func NewSyncServerDatasource(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties) + return sqlite3.NewDatabase(base, dbProperties) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(dbProperties) + return postgres.NewDatabase(base, dbProperties) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 15bb769a2..1150c2f3d 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -17,7 +17,7 @@ var ctx = context.Background() func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewSyncServerDatasource(&config.DatabaseOptions{ + db, err := storage.NewSyncServerDatasource(nil, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }) if err != nil { diff --git a/syncapi/storage/storage_wasm.go b/syncapi/storage/storage_wasm.go index f7fef962b..c15444743 100644 --- a/syncapi/storage/storage_wasm.go +++ b/syncapi/storage/storage_wasm.go @@ -17,15 +17,16 @@ package storage import ( "fmt" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" ) // NewPublicRoomsServerDatabase opens a database connection. -func NewSyncServerDatasource(dbProperties *config.DatabaseOptions) (Database, error) { +func NewSyncServerDatasource(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties) + return sqlite3.NewDatabase(base, dbProperties) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/syncapi/storage/tables/output_room_events_test.go b/syncapi/storage/tables/output_room_events_test.go index a143e5ecd..8bbf879d4 100644 --- a/syncapi/storage/tables/output_room_events_test.go +++ b/syncapi/storage/tables/output_room_events_test.go @@ -21,7 +21,7 @@ func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, connStr, close := test.PrepareDBConnectionString(t, dbType) db, err := sqlutil.Open(&config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), - }) + }, sqlutil.NewExclusiveWriter()) if err != nil { t.Fatalf("failed to open db: %s", err) } diff --git a/syncapi/storage/tables/topology_test.go b/syncapi/storage/tables/topology_test.go index b6ece0b0d..2334aae2e 100644 --- a/syncapi/storage/tables/topology_test.go +++ b/syncapi/storage/tables/topology_test.go @@ -20,7 +20,7 @@ func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.D connStr, close := test.PrepareDBConnectionString(t, dbType) db, err := sqlutil.Open(&config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), - }) + }, sqlutil.NewExclusiveWriter()) if err != nil { t.Fatalf("failed to open db: %s", err) } diff --git a/syncapi/streams/stream_accountdata.go b/syncapi/streams/stream_accountdata.go index 2cddbcf04..9c19b846b 100644 --- a/syncapi/streams/stream_accountdata.go +++ b/syncapi/streams/stream_accountdata.go @@ -10,7 +10,7 @@ import ( type AccountDataStreamProvider struct { StreamProvider - userAPI userapi.UserInternalAPI + userAPI userapi.SyncUserAPI } func (p *AccountDataStreamProvider) Setup() { diff --git a/syncapi/streams/stream_devicelist.go b/syncapi/streams/stream_devicelist.go index 6ff8a7fd5..f42099510 100644 --- a/syncapi/streams/stream_devicelist.go +++ b/syncapi/streams/stream_devicelist.go @@ -11,8 +11,8 @@ import ( type DeviceListStreamProvider struct { StreamProvider - rsAPI api.RoomserverInternalAPI - keyAPI keyapi.KeyInternalAPI + rsAPI api.SyncRoomserverAPI + keyAPI keyapi.SyncKeyAPI } func (p *DeviceListStreamProvider) CompleteSync( diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 0d033095d..00b3dfe3b 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -32,8 +32,8 @@ type PDUStreamProvider struct { tasks chan func() workers atomic.Int32 // userID+deviceID -> lazy loading cache - lazyLoadCache *caching.LazyLoadCache - rsAPI roomserverAPI.RoomserverInternalAPI + lazyLoadCache caching.LazyLoadCache + rsAPI roomserverAPI.SyncRoomserverAPI } func (p *PDUStreamProvider) worker() { diff --git a/syncapi/streams/stream_presence.go b/syncapi/streams/stream_presence.go index a84d19878..35ce53cb6 100644 --- a/syncapi/streams/stream_presence.go +++ b/syncapi/streams/stream_presence.go @@ -145,6 +145,10 @@ func (p *PresenceStreamProvider) IncrementalSync( p.cache.Store(cacheKey, presence) } + if len(req.Response.Presence.Events) == 0 { + return to + } + return lastPos } diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index a18a0cc41..1ca4ee8c3 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -25,9 +25,9 @@ type Streams struct { } func NewSyncStreamProviders( - d storage.Database, userAPI userapi.UserInternalAPI, - rsAPI rsapi.RoomserverInternalAPI, keyAPI keyapi.KeyInternalAPI, - eduCache *caching.EDUCache, lazyLoadCache *caching.LazyLoadCache, notifier *notifier.Notifier, + d storage.Database, userAPI userapi.SyncUserAPI, + rsAPI rsapi.SyncRoomserverAPI, keyAPI keyapi.SyncKeyAPI, + eduCache *caching.EDUCache, lazyLoadCache caching.LazyLoadCache, notifier *notifier.Notifier, ) *Streams { streams := &Streams{ PDUStreamProvider: &PDUStreamProvider{ diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 76d550a65..99d1e40c3 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -45,9 +45,9 @@ import ( type RequestPool struct { db storage.Database cfg *config.SyncAPI - userAPI userapi.UserInternalAPI - keyAPI keyapi.KeyInternalAPI - rsAPI roomserverAPI.RoomserverInternalAPI + userAPI userapi.SyncUserAPI + keyAPI keyapi.SyncKeyAPI + rsAPI roomserverAPI.SyncRoomserverAPI lastseen *sync.Map presence *sync.Map streams *streams.Streams @@ -62,8 +62,8 @@ type PresencePublisher interface { // NewRequestPool makes a new RequestPool func NewRequestPool( db storage.Database, cfg *config.SyncAPI, - userAPI userapi.UserInternalAPI, keyAPI keyapi.KeyInternalAPI, - rsAPI roomserverAPI.RoomserverInternalAPI, + userAPI userapi.SyncUserAPI, keyAPI keyapi.SyncKeyAPI, + rsAPI roomserverAPI.SyncRoomserverAPI, streams *streams.Streams, notifier *notifier.Notifier, producer PresencePublisher, ) *RequestPool { @@ -182,6 +182,7 @@ func (rp *RequestPool) updateLastSeen(req *http.Request, device *userapi.Device) UserID: device.UserID, DeviceID: device.ID, RemoteAddr: remoteAddr, + UserAgent: req.UserAgent(), } lsres := &userapi.PerformLastSeenUpdateResponse{} go rp.userAPI.PerformLastSeenUpdate(req.Context(), lsreq, lsres) // nolint:errcheck diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index b2d333f74..6da8ce6d1 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -17,17 +17,14 @@ package syncapi import ( "context" - "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/caching" "github.com/sirupsen/logrus" keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/setup/process" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/syncapi/consumers" "github.com/matrix-org/dendrite/syncapi/notifier" @@ -41,28 +38,23 @@ import ( // AddPublicRoutes sets up and registers HTTP handlers for the SyncAPI // component. func AddPublicRoutes( - process *process.ProcessContext, - router *mux.Router, - userAPI userapi.UserInternalAPI, - rsAPI api.RoomserverInternalAPI, - keyAPI keyapi.KeyInternalAPI, - federation *gomatrixserverlib.FederationClient, - cfg *config.SyncAPI, + base *base.BaseDendrite, + userAPI userapi.SyncUserAPI, + rsAPI api.SyncRoomserverAPI, + keyAPI keyapi.SyncKeyAPI, ) { - js, natsClient := jetstream.Prepare(process, &cfg.Matrix.JetStream) + cfg := &base.Cfg.SyncAPI - syncDB, err := storage.NewSyncServerDatasource(&cfg.Database) + js, natsClient := jetstream.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + + syncDB, err := storage.NewSyncServerDatasource(base, &cfg.Database) if err != nil { logrus.WithError(err).Panicf("failed to connect to sync db") } eduCache := caching.NewTypingCache() - lazyLoadCache, err := caching.NewLazyLoadCache() - if err != nil { - logrus.WithError(err).Panicf("failed to create lazy loading cache") - } notifier := notifier.NewNotifier() - streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, lazyLoadCache, notifier) + streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, base.Caches, notifier) notifier.SetCurrentPosition(streams.Latest(context.Background())) if err = notifier.Load(context.Background(), syncDB); err != nil { logrus.WithError(err).Panicf("failed to load notifier ") @@ -86,8 +78,8 @@ func AddPublicRoutes( } keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( - process, cfg, cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), - js, keyAPI, rsAPI, syncDB, notifier, + base.ProcessContext, cfg, cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), + js, rsAPI, syncDB, notifier, streams.DeviceListStreamProvider, ) if err = keyChangeConsumer.Start(); err != nil { @@ -95,7 +87,7 @@ func AddPublicRoutes( } roomConsumer := consumers.NewOutputRoomEventConsumer( - process, cfg, js, syncDB, notifier, streams.PDUStreamProvider, + base.ProcessContext, cfg, js, syncDB, notifier, streams.PDUStreamProvider, streams.InviteStreamProvider, rsAPI, userAPIStreamEventProducer, ) if err = roomConsumer.Start(); err != nil { @@ -103,7 +95,7 @@ func AddPublicRoutes( } clientConsumer := consumers.NewOutputClientDataConsumer( - process, cfg, js, syncDB, notifier, streams.AccountDataStreamProvider, + base.ProcessContext, cfg, js, syncDB, notifier, streams.AccountDataStreamProvider, userAPIReadUpdateProducer, ) if err = clientConsumer.Start(); err != nil { @@ -111,28 +103,28 @@ func AddPublicRoutes( } notificationConsumer := consumers.NewOutputNotificationDataConsumer( - process, cfg, js, syncDB, notifier, streams.NotificationDataStreamProvider, + base.ProcessContext, cfg, js, syncDB, notifier, streams.NotificationDataStreamProvider, ) if err = notificationConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start notification data consumer") } typingConsumer := consumers.NewOutputTypingEventConsumer( - process, cfg, js, eduCache, notifier, streams.TypingStreamProvider, + base.ProcessContext, cfg, js, eduCache, notifier, streams.TypingStreamProvider, ) if err = typingConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start typing consumer") } sendToDeviceConsumer := consumers.NewOutputSendToDeviceEventConsumer( - process, cfg, js, syncDB, notifier, streams.SendToDeviceStreamProvider, + base.ProcessContext, cfg, js, syncDB, notifier, streams.SendToDeviceStreamProvider, ) if err = sendToDeviceConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start send-to-device consumer") } receiptConsumer := consumers.NewOutputReceiptEventConsumer( - process, cfg, js, syncDB, notifier, streams.ReceiptStreamProvider, + base.ProcessContext, cfg, js, syncDB, notifier, streams.ReceiptStreamProvider, userAPIReadUpdateProducer, ) if err = receiptConsumer.Start(); err != nil { @@ -140,7 +132,7 @@ func AddPublicRoutes( } presenceConsumer := consumers.NewPresenceConsumer( - process, cfg, js, natsClient, syncDB, + base.ProcessContext, cfg, js, natsClient, syncDB, notifier, streams.PresenceStreamProvider, userAPI, ) @@ -148,5 +140,8 @@ func AddPublicRoutes( logrus.WithError(err).Panicf("failed to start presence consumer") } - routing.Setup(router, requestPool, syncDB, userAPI, federation, rsAPI, cfg, lazyLoadCache) + routing.Setup( + base.PublicClientAPIMux, requestPool, syncDB, userAPI, + rsAPI, cfg, base.Caches, + ) } diff --git a/test/db.go b/test/db.go index 6412feaa6..a1754cd08 100644 --- a/test/db.go +++ b/test/db.go @@ -33,10 +33,19 @@ var DBTypeSQLite DBType = 1 var DBTypePostgres DBType = 2 var Quiet = false +var Required = os.Getenv("DENDRITE_TEST_SKIP_NODB") == "" -func createLocalDB(dbName string) { +func fatalError(t *testing.T, format string, args ...interface{}) { + if Required { + t.Fatalf(format, args...) + } else { + t.Skipf(format, args...) + } +} + +func createLocalDB(t *testing.T, dbName string) { if !Quiet { - fmt.Println("Note: tests require a postgres install accessible to the current user") + t.Log("Note: tests require a postgres install accessible to the current user") } createDB := exec.Command("createdb", dbName) if !Quiet { @@ -52,7 +61,7 @@ func createLocalDB(dbName string) { func createRemoteDB(t *testing.T, dbName, user, connStr string) { db, err := sql.Open("postgres", connStr+" dbname=postgres") if err != nil { - t.Fatalf("failed to open postgres conn with connstr=%s : %s", connStr, err) + fatalError(t, "failed to open postgres conn with connstr=%s : %s", connStr, err) } _, err = db.Exec(fmt.Sprintf(`CREATE DATABASE %s;`, dbName)) if err != nil { @@ -133,7 +142,7 @@ func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, clo hash := sha256.Sum256([]byte(wd)) dbName := fmt.Sprintf("dendrite_test_%s", hex.EncodeToString(hash[:16])) if postgresDB == "" { // local server, use createdb - createLocalDB(dbName) + createLocalDB(t, dbName) } else { // remote server, shell into the postgres user and CREATE DATABASE createRemoteDB(t, dbName, user, connStr) } diff --git a/userapi/api/api.go b/userapi/api/api.go index 6aa6a6842..df9408acb 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -26,73 +26,102 @@ import ( // UserInternalAPI is the internal API for information about users and devices. type UserInternalAPI interface { - LoginTokenInternalAPI - UserProfileAPI - UserRegisterAPI - UserAccountAPI - UserThreePIDAPI - UserDeviceAPI + AppserviceUserAPI + SyncUserAPI + ClientUserAPI + MediaUserAPI + FederationUserAPI + RoomserverUserAPI + KeyserverUserAPI - InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error - - PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error - PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error - PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error - PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error - PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error - - QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) - QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error - QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error - QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error - QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error - QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error - QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error + QuerySearchProfilesAPI // used by p2p demos } -type UserDeviceAPI interface { - PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error +// api functions required by the appservice api +type AppserviceUserAPI interface { + PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error + PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error +} + +type KeyserverUserAPI interface { + QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error + QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error +} + +type RoomserverUserAPI interface { + QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error +} + +// api functions required by the media api +type MediaUserAPI interface { + QueryAcccessTokenAPI +} + +// api functions required by the federation api +type FederationUserAPI interface { + QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error + QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error +} + +// api functions required by the sync api +type SyncUserAPI interface { + QueryAcccessTokenAPI + QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error } -type UserDirectoryProvider interface { - QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error -} - -// UserProfileAPI provides functions for getting user profiles -type UserProfileAPI interface { - QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error - QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error - SetAvatarURL(ctx context.Context, req *PerformSetAvatarURLRequest, res *PerformSetAvatarURLResponse) error - SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error -} - -// UserRegisterAPI defines functions for registering accounts -type UserRegisterAPI interface { +// api functions required by the client api +type ClientUserAPI interface { + QueryAcccessTokenAPI + LoginTokenInternalAPI + UserLoginAPI QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error + QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error + QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error + QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error + QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error + QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error -} - -// UserAccountAPI defines functions for changing an account -type UserAccountAPI interface { + PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error + PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error + PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error + PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error + PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error - QueryAccountByPassword(ctx context.Context, req *QueryAccountByPasswordRequest, res *QueryAccountByPasswordResponse) error -} + PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error + SetAvatarURL(ctx context.Context, req *PerformSetAvatarURLRequest, res *PerformSetAvatarURLResponse) error + SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error + QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error + InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error + PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error + QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) -// UserThreePIDAPI defines functions for 3PID -type UserThreePIDAPI interface { - QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error QueryThreePIDsForLocalpart(ctx context.Context, req *QueryThreePIDsForLocalpartRequest, res *QueryThreePIDsForLocalpartResponse) error + QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error PerformForgetThreePID(ctx context.Context, req *PerformForgetThreePIDRequest, res *struct{}) error PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error } +// custom api functions required by pinecone / p2p demos +type QuerySearchProfilesAPI interface { + QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error +} + +// common function for creating authenticated endpoints (used in client/media/sync api) +type QueryAcccessTokenAPI interface { + QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error +} + +type UserLoginAPI interface { + QueryAccountByPassword(ctx context.Context, req *QueryAccountByPasswordRequest, res *QueryAccountByPasswordResponse) error +} + type PerformKeyBackupRequest struct { UserID string Version string // optional if modifying a key backup @@ -320,6 +349,7 @@ type PerformLastSeenUpdateRequest struct { UserID string DeviceID string RemoteAddr string + UserAgent string } // PerformLastSeenUpdateResponse is the response for PerformLastSeenUpdate. diff --git a/userapi/consumers/syncapi_streamevent.go b/userapi/consumers/syncapi_streamevent.go index 9ef7b5083..7807c7637 100644 --- a/userapi/consumers/syncapi_streamevent.go +++ b/userapi/consumers/syncapi_streamevent.go @@ -29,7 +29,7 @@ type OutputStreamEventConsumer struct { ctx context.Context cfg *config.UserAPI userAPI api.UserInternalAPI - rsAPI rsapi.RoomserverInternalAPI + rsAPI rsapi.UserRoomserverAPI jetstream nats.JetStreamContext durable string db storage.Database @@ -45,7 +45,7 @@ func NewOutputStreamEventConsumer( store storage.Database, pgClient pushgateway.Client, userAPI api.UserInternalAPI, - rsAPI rsapi.RoomserverInternalAPI, + rsAPI rsapi.UserRoomserverAPI, syncProducer *producers.SyncAPI, ) *OutputStreamEventConsumer { return &OutputStreamEventConsumer{ @@ -455,7 +455,7 @@ func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event type ruleSetEvalContext struct { ctx context.Context - rsAPI rsapi.RoomserverInternalAPI + rsAPI rsapi.UserRoomserverAPI mem *localMembership roomID string roomSize int diff --git a/userapi/internal/api.go b/userapi/internal/api.go index be58e2d8d..9d2f63c72 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -48,7 +48,7 @@ type UserInternalAPI struct { ServerName gomatrixserverlib.ServerName // AppServices is the list of all registered AS AppServices []config.ApplicationService - KeyAPI keyapi.KeyInternalAPI + KeyAPI keyapi.UserKeyAPI } func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { @@ -210,7 +210,7 @@ func (a *UserInternalAPI) PerformLastSeenUpdate( if err != nil { return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } - if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil { + if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil { return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err) } return nil diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index a4562cf19..f7cd1810a 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" ) type Profile interface { @@ -67,7 +68,7 @@ type Device interface { // Returns the device on success. CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error - UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error + UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error RemoveDevices(ctx context.Context, localpart string, devices []string) error // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) @@ -135,9 +136,14 @@ type Database interface { OpenID Profile Pusher + Statistics ThreePID } +type Statistics interface { + UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error) +} + // Err3PIDInUse is the error returned when trying to save an association involving // a third-party identifier which is already associated to a local user. var Err3PIDInUse = errors.New("this third-party identifier is already in use") diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index 6c777982f..ccb776672 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -96,7 +96,7 @@ const selectDevicesByIDSQL = "" + "SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + - "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4" + "UPDATE device_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" type devicesStatements struct { insertDeviceStmt *sql.Stmt @@ -304,9 +304,9 @@ func (s *devicesStatements) SelectDevicesByLocalpart( return devices, rows.Err() } -func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) - _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID) + _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID) return err } diff --git a/userapi/storage/postgres/stats_table.go b/userapi/storage/postgres/stats_table.go new file mode 100644 index 000000000..c0b317503 --- /dev/null +++ b/userapi/storage/postgres/stats_table.go @@ -0,0 +1,437 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +const userDailyVisitsSchema = ` +CREATE TABLE IF NOT EXISTS userapi_daily_visits ( + localpart TEXT NOT NULL, + device_id TEXT NOT NULL, + timestamp BIGINT NOT NULL, + user_agent TEXT +); + +-- Device IDs and timestamp must be unique for a given user per day +CREATE UNIQUE INDEX IF NOT EXISTS userapi_daily_visits_localpart_device_timestamp_idx ON userapi_daily_visits(localpart, device_id, timestamp); +CREATE INDEX IF NOT EXISTS userapi_daily_visits_timestamp_idx ON userapi_daily_visits(timestamp); +CREATE INDEX IF NOT EXISTS userapi_daily_visits_localpart_timestamp_idx ON userapi_daily_visits(localpart, timestamp); +` + +const countUsersLastSeenAfterSQL = "" + + "SELECT COUNT(*) FROM (" + + " SELECT localpart FROM device_devices WHERE last_seen_ts > $1 " + + " GROUP BY localpart" + + " ) u" + +// Note on the following countR30UsersSQL and countR30UsersV2SQL: The different checks are intentional. +// This is to ensure the values reported by Dendrite are the same as by Synapse. +// Queries are taken from: https://github.com/matrix-org/synapse/blob/9ce51a47f6e37abd0a1275281806399d874eb026/synapse/storage/databases/main/stats.py + +/* +R30Users counts the number of 30 day retained users, defined as: +- Users who have created their accounts more than 30 days ago +- Where last seen at most 30 days ago +- Where account creation and last_seen are > 30 days apart +*/ +const countR30UsersSQL = ` +SELECT platform, COUNT(*) FROM ( + SELECT users.localpart, platform, users.created_ts, MAX(uip.last_seen_ts) + FROM account_accounts users + INNER JOIN + (SELECT + localpart, last_seen_ts, + CASE + WHEN user_agent LIKE '%%Android%%' THEN 'android' + WHEN user_agent LIKE '%%iOS%%' THEN 'ios' + WHEN user_agent LIKE '%%Electron%%' THEN 'electron' + WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' + WHEN user_agent LIKE '%%Gecko%%' THEN 'web' + ELSE 'unknown' + END + AS platform + FROM device_devices + ) uip + ON users.localpart = uip.localpart + AND users.account_type <> 4 + AND users.created_ts < $1 + AND uip.last_seen_ts > $1 + AND (uip.last_seen_ts) - users.created_ts > $2 + GROUP BY users.localpart, platform, users.created_ts + ) u GROUP BY PLATFORM +` + +/* +R30UsersV2 counts the number of 30 day retained users, defined as users that: +- Appear more than once in the past 60 days +- Have more than 30 days between the most and least recent appearances that occurred in the past 60 days. +*/ +const countR30UsersV2SQL = ` +SELECT + client_type, + count(client_type) +FROM + ( + SELECT + localpart, + CASE + WHEN + LOWER(user_agent) LIKE '%%riot%%' OR + LOWER(user_agent) LIKE '%%element%%' + THEN CASE + WHEN LOWER(user_agent) LIKE '%%electron%%' THEN 'electron' + WHEN LOWER(user_agent) LIKE '%%android%%' THEN 'android' + WHEN LOWER(user_agent) LIKE '%%ios%%' THEN 'ios' + ELSE 'unknown' + END + WHEN LOWER(user_agent) LIKE '%%mozilla%%' OR LOWER(user_agent) LIKE '%%gecko%%' THEN 'web' + ELSE 'unknown' + END as client_type + FROM userapi_daily_visits + WHERE timestamp > $1 AND timestamp < $2 + GROUP BY localpart, client_type + HAVING max(timestamp) - min(timestamp) > $3 + ) AS temp +GROUP BY client_type +` + +const countUserByAccountTypeSQL = ` +SELECT COUNT(*) FROM account_accounts WHERE account_type = ANY($1) +` + +// $1 = All non guest AccountType IDs +// $2 = Guest AccountType +const countRegisteredUserByTypeStmt = ` +SELECT user_type, COUNT(*) AS count FROM ( + SELECT + CASE + WHEN account_type = ANY($1) AND appservice_id IS NULL THEN 'native' + WHEN account_type = $2 AND appservice_id IS NULL THEN 'guest' + WHEN account_type = ANY($1) AND appservice_id IS NOT NULL THEN 'bridged' + END AS user_type + FROM account_accounts + WHERE created_ts > $3 +) AS t GROUP BY user_type +` + +// account_type 1 = users; 3 = admins +const updateUserDailyVisitsSQL = ` +INSERT INTO userapi_daily_visits(localpart, device_id, timestamp, user_agent) + SELECT u.localpart, u.device_id, $1, MAX(u.user_agent) + FROM device_devices AS u + LEFT JOIN ( + SELECT localpart, device_id, timestamp FROM userapi_daily_visits + WHERE timestamp = $1 + ) udv + ON u.localpart = udv.localpart AND u.device_id = udv.device_id + INNER JOIN device_devices d ON d.localpart = u.localpart + INNER JOIN account_accounts a ON a.localpart = u.localpart + WHERE $2 <= d.last_seen_ts AND d.last_seen_ts < $3 + AND a.account_type in (1, 3) + GROUP BY u.localpart, u.device_id +ON CONFLICT (localpart, device_id, timestamp) DO NOTHING +; +` + +const queryDBEngineVersion = "SHOW server_version;" + +type statsStatements struct { + serverName gomatrixserverlib.ServerName + lastUpdate time.Time + countUsersLastSeenAfterStmt *sql.Stmt + countR30UsersStmt *sql.Stmt + countR30UsersV2Stmt *sql.Stmt + updateUserDailyVisitsStmt *sql.Stmt + countUserByAccountTypeStmt *sql.Stmt + countRegisteredUserByTypeStmt *sql.Stmt + dbEngineVersionStmt *sql.Stmt +} + +func NewPostgresStatsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.StatsTable, error) { + s := &statsStatements{ + serverName: serverName, + lastUpdate: time.Now(), + } + + _, err := db.Exec(userDailyVisitsSchema) + if err != nil { + return nil, err + } + go s.startTimers() + return s, sqlutil.StatementList{ + {&s.countUsersLastSeenAfterStmt, countUsersLastSeenAfterSQL}, + {&s.countR30UsersStmt, countR30UsersSQL}, + {&s.countR30UsersV2Stmt, countR30UsersV2SQL}, + {&s.updateUserDailyVisitsStmt, updateUserDailyVisitsSQL}, + {&s.countUserByAccountTypeStmt, countUserByAccountTypeSQL}, + {&s.countRegisteredUserByTypeStmt, countRegisteredUserByTypeStmt}, + {&s.dbEngineVersionStmt, queryDBEngineVersion}, + }.Prepare(db) +} + +func (s *statsStatements) startTimers() { + var updateStatsFunc func() + updateStatsFunc = func() { + logrus.Infof("Executing UpdateUserDailyVisits") + if err := s.UpdateUserDailyVisits(context.Background(), nil, time.Now(), s.lastUpdate); err != nil { + logrus.WithError(err).Error("failed to update daily user visits") + } + time.AfterFunc(time.Hour*3, updateStatsFunc) + } + time.AfterFunc(time.Minute*5, updateStatsFunc) +} + +func (s *statsStatements) allUsers(ctx context.Context, txn *sql.Tx) (result int64, err error) { + stmt := sqlutil.TxStmt(txn, s.countUserByAccountTypeStmt) + err = stmt.QueryRowContext(ctx, + pq.Int64Array{ + int64(api.AccountTypeUser), + int64(api.AccountTypeGuest), + int64(api.AccountTypeAdmin), + int64(api.AccountTypeAppService), + }, + ).Scan(&result) + return +} + +func (s *statsStatements) nonBridgedUsers(ctx context.Context, txn *sql.Tx) (result int64, err error) { + stmt := sqlutil.TxStmt(txn, s.countUserByAccountTypeStmt) + err = stmt.QueryRowContext(ctx, + pq.Int64Array{ + int64(api.AccountTypeUser), + int64(api.AccountTypeGuest), + int64(api.AccountTypeAdmin), + }, + ).Scan(&result) + return +} + +func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx) (map[string]int64, error) { + stmt := sqlutil.TxStmt(txn, s.countRegisteredUserByTypeStmt) + registeredAfter := time.Now().AddDate(0, 0, -30) + + rows, err := stmt.QueryContext(ctx, + pq.Int64Array{ + int64(api.AccountTypeUser), + int64(api.AccountTypeAdmin), + int64(api.AccountTypeAppService), + }, + api.AccountTypeGuest, + gomatrixserverlib.AsTimestamp(registeredAfter), + ) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "RegisteredUserByType: failed to close rows") + + var userType string + var count int64 + var result = make(map[string]int64) + for rows.Next() { + if err = rows.Scan(&userType, &count); err != nil { + return nil, err + } + result[userType] = count + } + + return result, rows.Err() +} + +func (s *statsStatements) dailyUsers(ctx context.Context, txn *sql.Tx) (result int64, err error) { + stmt := sqlutil.TxStmt(txn, s.countUsersLastSeenAfterStmt) + lastSeenAfter := time.Now().AddDate(0, 0, -1) + err = stmt.QueryRowContext(ctx, + gomatrixserverlib.AsTimestamp(lastSeenAfter), + ).Scan(&result) + return +} + +func (s *statsStatements) monthlyUsers(ctx context.Context, txn *sql.Tx) (result int64, err error) { + stmt := sqlutil.TxStmt(txn, s.countUsersLastSeenAfterStmt) + lastSeenAfter := time.Now().AddDate(0, 0, -30) + err = stmt.QueryRowContext(ctx, + gomatrixserverlib.AsTimestamp(lastSeenAfter), + ).Scan(&result) + return +} + +/* +R30Users counts the number of 30 day retained users, defined as: +- Users who have created their accounts more than 30 days ago +- Where last seen at most 30 days ago +- Where account creation and last_seen are > 30 days apart +*/ +func (s *statsStatements) r30Users(ctx context.Context, txn *sql.Tx) (map[string]int64, error) { + stmt := sqlutil.TxStmt(txn, s.countR30UsersStmt) + lastSeenAfter := time.Now().AddDate(0, 0, -30) + diff := time.Hour * 24 * 30 + + rows, err := stmt.QueryContext(ctx, + gomatrixserverlib.AsTimestamp(lastSeenAfter), + diff.Milliseconds(), + ) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "R30Users: failed to close rows") + + var platform string + var count int64 + var result = make(map[string]int64) + for rows.Next() { + if err = rows.Scan(&platform, &count); err != nil { + return nil, err + } + if platform == "unknown" { + continue + } + result["all"] += count + result[platform] = count + } + + return result, rows.Err() +} + +/* +R30UsersV2 counts the number of 30 day retained users, defined as users that: +- Appear more than once in the past 60 days +- Have more than 30 days between the most and least recent appearances that occurred in the past 60 days. +*/ +func (s *statsStatements) r30UsersV2(ctx context.Context, txn *sql.Tx) (map[string]int64, error) { + stmt := sqlutil.TxStmt(txn, s.countR30UsersV2Stmt) + sixtyDaysAgo := time.Now().AddDate(0, 0, -60) + diff := time.Hour * 24 * 30 + tomorrow := time.Now().Add(time.Hour * 24) + + rows, err := stmt.QueryContext(ctx, + gomatrixserverlib.AsTimestamp(sixtyDaysAgo), + gomatrixserverlib.AsTimestamp(tomorrow), + diff.Milliseconds(), + ) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "R30UsersV2: failed to close rows") + + var platform string + var count int64 + var result = map[string]int64{ + "ios": 0, + "android": 0, + "web": 0, + "electron": 0, + "all": 0, + } + for rows.Next() { + if err = rows.Scan(&platform, &count); err != nil { + return nil, err + } + if _, ok := result[platform]; !ok { + continue + } + result["all"] += count + result[platform] = count + } + + return result, rows.Err() +} + +// UserStatistics collects some information about users on this instance. +// Returns the stats itself as well as the database engine version and type. +// On error, returns the stats collected up to the error. +func (s *statsStatements) UserStatistics(ctx context.Context, txn *sql.Tx) (*types.UserStatistics, *types.DatabaseEngine, error) { + var ( + stats = &types.UserStatistics{ + R30UsersV2: map[string]int64{ + "ios": 0, + "android": 0, + "web": 0, + "electron": 0, + "all": 0, + }, + R30Users: map[string]int64{}, + RegisteredUsersByType: map[string]int64{}, + } + dbEngine = &types.DatabaseEngine{Engine: "Postgres", Version: "unknown"} + err error + ) + stats.AllUsers, err = s.allUsers(ctx, txn) + if err != nil { + return stats, dbEngine, err + } + stats.DailyUsers, err = s.dailyUsers(ctx, txn) + if err != nil { + return stats, dbEngine, err + } + stats.MonthlyUsers, err = s.monthlyUsers(ctx, txn) + if err != nil { + return stats, dbEngine, err + } + stats.R30Users, err = s.r30Users(ctx, txn) + if err != nil { + return stats, dbEngine, err + } + stats.R30UsersV2, err = s.r30UsersV2(ctx, txn) + if err != nil { + return stats, dbEngine, err + } + stats.NonBridgedUsers, err = s.nonBridgedUsers(ctx, txn) + if err != nil { + return stats, dbEngine, err + } + stats.RegisteredUsersByType, err = s.registeredUserByType(ctx, txn) + if err != nil { + return stats, dbEngine, err + } + + stmt := sqlutil.TxStmt(txn, s.dbEngineVersionStmt) + err = stmt.QueryRowContext(ctx).Scan(&dbEngine.Version) + return stats, dbEngine, err +} + +func (s *statsStatements) UpdateUserDailyVisits( + ctx context.Context, txn *sql.Tx, + startTime, lastUpdate time.Time, +) error { + stmt := sqlutil.TxStmt(txn, s.updateUserDailyVisitsStmt) + startTime = startTime.Truncate(time.Hour * 24) + + // edge case + if startTime.After(s.lastUpdate) { + startTime = startTime.AddDate(0, 0, -1) + } + _, err := stmt.ExecContext(ctx, + gomatrixserverlib.AsTimestamp(startTime), + gomatrixserverlib.AsTimestamp(lastUpdate), + gomatrixserverlib.AsTimestamp(time.Now()), + ) + if err == nil { + s.lastUpdate = time.Now() + } + return err +} diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index b2a517605..b9afb5a56 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/userapi/storage/shared" @@ -30,8 +31,8 @@ import ( ) // NewDatabase creates a new accounts and profiles database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { - db, err := sqlutil.Open(dbProperties) +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { + db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) if err != nil { return nil, err } @@ -93,6 +94,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err != nil { return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err) } + statsTable, err := NewPostgresStatsTable(db, serverName) + if err != nil { + return nil, fmt.Errorf("NewPostgresStatsTable: %w", err) + } return &shared.Database{ AccountDatas: accountDataTable, Accounts: accountsTable, @@ -105,9 +110,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver ThreePIDs: threePIDTable, Pushers: pusherTable, Notifications: notificationsTable, + Stats: statsTable, ServerName: serverName, DB: db, - Writer: sqlutil.NewDummyWriter(), + Writer: writer, LoginTokenLifetime: loginTokenLifetime, BcryptCost: bcryptCost, OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index f7212e030..0cf713dac 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -26,6 +26,7 @@ import ( "strings" "time" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" "golang.org/x/crypto/bcrypt" @@ -51,6 +52,7 @@ type Database struct { LoginTokens tables.LoginTokenTable Notifications tables.NotificationTable Pushers tables.PusherTable + Stats tables.StatsTable LoginTokenLifetime time.Duration ServerName gomatrixserverlib.ServerName BcryptCost int @@ -611,10 +613,10 @@ func (d *Database) RemoveAllDevices( return } -// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address -func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { +// UpdateDeviceLastSeen updates a last seen timestamp and the ip address. +func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) + return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr, userAgent) }) } @@ -756,3 +758,8 @@ func (d *Database) RemovePushers( return d.Pushers.DeletePushers(ctx, txn, appid, pushkey) }) } + +// UserStatistics populates types.UserStatistics, used in reports. +func (d *Database) UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error) { + return d.Stats.UserStatistics(ctx, nil) +} diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index b86ed1cc2..93291e6ad 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -81,7 +81,7 @@ const selectDevicesByIDSQL = "" + "SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + - "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4" + "UPDATE device_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" type devicesStatements struct { db *sql.DB @@ -306,9 +306,9 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s return devices, rows.Err() } -func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) - _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID) + _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID) return err } diff --git a/userapi/storage/sqlite3/stats_table.go b/userapi/storage/sqlite3/stats_table.go new file mode 100644 index 000000000..e00ed417b --- /dev/null +++ b/userapi/storage/sqlite3/stats_table.go @@ -0,0 +1,452 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "strings" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +const userDailyVisitsSchema = ` +CREATE TABLE IF NOT EXISTS userapi_daily_visits ( + localpart TEXT NOT NULL, + device_id TEXT NOT NULL, + timestamp BIGINT NOT NULL, + user_agent TEXT +); + +-- Device IDs and timestamp must be unique for a given user per day +CREATE UNIQUE INDEX IF NOT EXISTS userapi_daily_visits_localpart_device_timestamp_idx ON userapi_daily_visits(localpart, device_id, timestamp); +CREATE INDEX IF NOT EXISTS userapi_daily_visits_timestamp_idx ON userapi_daily_visits(timestamp); +CREATE INDEX IF NOT EXISTS userapi_daily_visits_localpart_timestamp_idx ON userapi_daily_visits(localpart, timestamp); +` + +const countUsersLastSeenAfterSQL = "" + + "SELECT COUNT(*) FROM (" + + " SELECT localpart FROM device_devices WHERE last_seen_ts > $1 " + + " GROUP BY localpart" + + " ) u" + +// Note on the following countR30UsersSQL and countR30UsersV2SQL: The different checks are intentional. +// This is to ensure the values reported by Dendrite are the same as by Synapse. +// Queries are taken from: https://github.com/matrix-org/synapse/blob/9ce51a47f6e37abd0a1275281806399d874eb026/synapse/storage/databases/main/stats.py + +/* +R30Users counts the number of 30 day retained users, defined as: +- Users who have created their accounts more than 30 days ago +- Where last seen at most 30 days ago +- Where account creation and last_seen are > 30 days apart +*/ +const countR30UsersSQL = ` +SELECT platform, COUNT(*) FROM ( + SELECT users.localpart, platform, users.created_ts, MAX(uip.last_seen_ts) + FROM account_accounts users + INNER JOIN + (SELECT + localpart, last_seen_ts, + CASE + WHEN user_agent LIKE '%%Android%%' THEN 'android' + WHEN user_agent LIKE '%%iOS%%' THEN 'ios' + WHEN user_agent LIKE '%%Electron%%' THEN 'electron' + WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' + WHEN user_agent LIKE '%%Gecko%%' THEN 'web' + ELSE 'unknown' + END + AS platform + FROM device_devices + ) uip + ON users.localpart = uip.localpart + AND users.account_type <> 4 + AND users.created_ts < $1 + AND uip.last_seen_ts > $2 + AND (uip.last_seen_ts) - users.created_ts > $3 + GROUP BY users.localpart, platform, users.created_ts + ) u GROUP BY PLATFORM +` + +// Note on the following countR30UsersSQL and countR30UsersV2SQL: The different checks are intentional. +// This is to ensure the values reported are the same as Synapse reports. +// Queries are taken from: https://github.com/matrix-org/synapse/blob/9ce51a47f6e37abd0a1275281806399d874eb026/synapse/storage/databases/main/stats.py + +/* +R30UsersV2 counts the number of 30 day retained users, defined as users that: +- Appear more than once in the past 60 days +- Have more than 30 days between the most and least recent appearances that occurred in the past 60 days. +*/ +const countR30UsersV2SQL = ` +SELECT + client_type, + count(client_type) +FROM + ( + SELECT + localpart, + CASE + WHEN + LOWER(user_agent) LIKE '%%riot%%' OR + LOWER(user_agent) LIKE '%%element%%' + THEN CASE + WHEN LOWER(user_agent) LIKE '%%electron%%' THEN 'electron' + WHEN LOWER(user_agent) LIKE '%%android%%' THEN 'android' + WHEN LOWER(user_agent) LIKE '%%ios%%' THEN 'ios' + ELSE 'unknown' + END + WHEN LOWER(user_agent) LIKE '%%mozilla%%' OR LOWER(user_agent) LIKE '%%gecko%%' THEN 'web' + ELSE 'unknown' + END as client_type + FROM userapi_daily_visits + WHERE timestamp > $1 AND timestamp < $2 + GROUP BY localpart, client_type + HAVING max(timestamp) - min(timestamp) > $3 + ) AS temp +GROUP BY client_type +` + +const countUserByAccountTypeSQL = ` +SELECT COUNT(*) FROM account_accounts WHERE account_type IN ($1) +` + +// $1 = Guest AccountType +// $3 & $4 = All non guest AccountType IDs +const countRegisteredUserByTypeSQL = ` +SELECT user_type, COUNT(*) AS count FROM ( + SELECT + CASE + WHEN account_type IN ($1) AND appservice_id IS NULL THEN 'native' + WHEN account_type = $4 AND appservice_id IS NULL THEN 'guest' + WHEN account_type IN ($5) AND appservice_id IS NOT NULL THEN 'bridged' + END AS user_type + FROM account_accounts + WHERE created_ts > $8 +) AS t GROUP BY user_type +` + +// account_type 1 = users; 3 = admins +const updateUserDailyVisitsSQL = ` +INSERT INTO userapi_daily_visits(localpart, device_id, timestamp, user_agent) + SELECT u.localpart, u.device_id, $1, MAX(u.user_agent) + FROM device_devices AS u + LEFT JOIN ( + SELECT localpart, device_id, timestamp FROM userapi_daily_visits + WHERE timestamp = $1 + ) udv + ON u.localpart = udv.localpart AND u.device_id = udv.device_id + INNER JOIN device_devices d ON d.localpart = u.localpart + INNER JOIN account_accounts a ON a.localpart = u.localpart + WHERE $2 <= d.last_seen_ts AND d.last_seen_ts < $3 + AND a.account_type in (1, 3) + GROUP BY u.localpart, u.device_id +ON CONFLICT (localpart, device_id, timestamp) DO NOTHING +; +` + +const queryDBEngineVersion = "select sqlite_version();" + +type statsStatements struct { + serverName gomatrixserverlib.ServerName + db *sql.DB + lastUpdate time.Time + countUsersLastSeenAfterStmt *sql.Stmt + countR30UsersStmt *sql.Stmt + countR30UsersV2Stmt *sql.Stmt + updateUserDailyVisitsStmt *sql.Stmt + countUserByAccountTypeStmt *sql.Stmt + countRegisteredUserByTypeStmt *sql.Stmt + dbEngineVersionStmt *sql.Stmt +} + +func NewSQLiteStatsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.StatsTable, error) { + s := &statsStatements{ + serverName: serverName, + lastUpdate: time.Now(), + db: db, + } + + _, err := db.Exec(userDailyVisitsSchema) + if err != nil { + return nil, err + } + go s.startTimers() + return s, sqlutil.StatementList{ + {&s.countUsersLastSeenAfterStmt, countUsersLastSeenAfterSQL}, + {&s.countR30UsersStmt, countR30UsersSQL}, + {&s.countR30UsersV2Stmt, countR30UsersV2SQL}, + {&s.updateUserDailyVisitsStmt, updateUserDailyVisitsSQL}, + {&s.countUserByAccountTypeStmt, countUserByAccountTypeSQL}, + {&s.countRegisteredUserByTypeStmt, countRegisteredUserByTypeSQL}, + {&s.dbEngineVersionStmt, queryDBEngineVersion}, + }.Prepare(db) +} + +func (s *statsStatements) startTimers() { + var updateStatsFunc func() + updateStatsFunc = func() { + logrus.Infof("Executing UpdateUserDailyVisits") + if err := s.UpdateUserDailyVisits(context.Background(), nil, time.Now(), s.lastUpdate); err != nil { + logrus.WithError(err).Error("failed to update daily user visits") + } + time.AfterFunc(time.Hour*3, updateStatsFunc) + } + time.AfterFunc(time.Minute*5, updateStatsFunc) +} + +func (s *statsStatements) allUsers(ctx context.Context, txn *sql.Tx) (result int64, err error) { + query := strings.Replace(countUserByAccountTypeSQL, "($1)", sqlutil.QueryVariadic(4), 1) + queryStmt, err := s.db.Prepare(query) + if err != nil { + return 0, err + } + stmt := sqlutil.TxStmt(txn, queryStmt) + err = stmt.QueryRowContext(ctx, + 1, 2, 3, 4, + ).Scan(&result) + return +} + +func (s *statsStatements) nonBridgedUsers(ctx context.Context, txn *sql.Tx) (result int64, err error) { + query := strings.Replace(countUserByAccountTypeSQL, "($1)", sqlutil.QueryVariadic(3), 1) + queryStmt, err := s.db.Prepare(query) + if err != nil { + return 0, err + } + stmt := sqlutil.TxStmt(txn, queryStmt) + err = stmt.QueryRowContext(ctx, + 1, 2, 3, + ).Scan(&result) + return +} + +func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx) (map[string]int64, error) { + // $1 = Guest AccountType; $2 = timestamp + // $3 & $4 = All non guest AccountType IDs + nonGuests := []api.AccountType{api.AccountTypeUser, api.AccountTypeAdmin, api.AccountTypeAppService} + countSQL := strings.Replace(countRegisteredUserByTypeSQL, "($1)", sqlutil.QueryVariadicOffset(len(nonGuests), 0), 1) + countSQL = strings.Replace(countSQL, "($5)", sqlutil.QueryVariadicOffset(len(nonGuests), 1+len(nonGuests)), 1) + queryStmt, err := s.db.Prepare(countSQL) + if err != nil { + return nil, err + } + stmt := sqlutil.TxStmt(txn, queryStmt) + registeredAfter := time.Now().AddDate(0, 0, -30) + + params := make([]interface{}, len(nonGuests)*2+2) + // nonGuests is used twice + for i, v := range nonGuests { + params[i] = v // i: 0 1 2 => ($1, $2, $3) + params[i+1+len(nonGuests)] = v // i: 4 5 6 => ($5, $6, $7) + } + params[3] = api.AccountTypeGuest // $4 + params[7] = gomatrixserverlib.AsTimestamp(registeredAfter) // $8 + + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "RegisteredUserByType: failed to close rows") + + var userType string + var count int64 + var result = make(map[string]int64) + for rows.Next() { + if err = rows.Scan(&userType, &count); err != nil { + return nil, err + } + result[userType] = count + } + + return result, rows.Err() +} + +func (s *statsStatements) dailyUsers(ctx context.Context, txn *sql.Tx) (result int64, err error) { + stmt := sqlutil.TxStmt(txn, s.countUsersLastSeenAfterStmt) + lastSeenAfter := time.Now().AddDate(0, 0, -1) + err = stmt.QueryRowContext(ctx, + gomatrixserverlib.AsTimestamp(lastSeenAfter), + ).Scan(&result) + return +} + +func (s *statsStatements) monthlyUsers(ctx context.Context, txn *sql.Tx) (result int64, err error) { + stmt := sqlutil.TxStmt(txn, s.countUsersLastSeenAfterStmt) + lastSeenAfter := time.Now().AddDate(0, 0, -30) + err = stmt.QueryRowContext(ctx, + gomatrixserverlib.AsTimestamp(lastSeenAfter), + ).Scan(&result) + return +} + +/* R30Users counts the number of 30 day retained users, defined as: +- Users who have created their accounts more than 30 days ago +- Where last seen at most 30 days ago +- Where account creation and last_seen are > 30 days apart +*/ +func (s *statsStatements) r30Users(ctx context.Context, txn *sql.Tx) (map[string]int64, error) { + stmt := sqlutil.TxStmt(txn, s.countR30UsersStmt) + lastSeenAfter := time.Now().AddDate(0, 0, -30) + diff := time.Hour * 24 * 30 + + rows, err := stmt.QueryContext(ctx, + gomatrixserverlib.AsTimestamp(lastSeenAfter), + gomatrixserverlib.AsTimestamp(lastSeenAfter), + diff.Milliseconds(), + ) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "R30Users: failed to close rows") + + var platform string + var count int64 + var result = make(map[string]int64) + for rows.Next() { + if err = rows.Scan(&platform, &count); err != nil { + return nil, err + } + if platform == "unknown" { + continue + } + result["all"] += count + result[platform] = count + } + + return result, rows.Err() +} + +/* R30UsersV2 counts the number of 30 day retained users, defined as users that: +- Appear more than once in the past 60 days +- Have more than 30 days between the most and least recent appearances that occurred in the past 60 days. +*/ +func (s *statsStatements) r30UsersV2(ctx context.Context, txn *sql.Tx) (map[string]int64, error) { + stmt := sqlutil.TxStmt(txn, s.countR30UsersV2Stmt) + sixtyDaysAgo := time.Now().AddDate(0, 0, -60) + diff := time.Hour * 24 * 30 + tomorrow := time.Now().Add(time.Hour * 24) + + rows, err := stmt.QueryContext(ctx, + gomatrixserverlib.AsTimestamp(sixtyDaysAgo), + gomatrixserverlib.AsTimestamp(tomorrow), + diff.Milliseconds(), + ) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "R30UsersV2: failed to close rows") + + var platform string + var count int64 + var result = map[string]int64{ + "ios": 0, + "android": 0, + "web": 0, + "electron": 0, + "all": 0, + } + for rows.Next() { + if err = rows.Scan(&platform, &count); err != nil { + return nil, err + } + if _, ok := result[platform]; !ok { + continue + } + result["all"] += count + result[platform] = count + } + return result, rows.Err() +} + +// UserStatistics collects some information about users on this instance. +// Returns the stats itself as well as the database engine version and type. +// On error, returns the stats collected up to the error. +func (s *statsStatements) UserStatistics(ctx context.Context, txn *sql.Tx) (*types.UserStatistics, *types.DatabaseEngine, error) { + var ( + stats = &types.UserStatistics{ + R30UsersV2: map[string]int64{ + "ios": 0, + "android": 0, + "web": 0, + "electron": 0, + "all": 0, + }, + R30Users: map[string]int64{}, + RegisteredUsersByType: map[string]int64{}, + } + dbEngine = &types.DatabaseEngine{Engine: "SQLite", Version: "unknown"} + err error + ) + stats.AllUsers, err = s.allUsers(ctx, txn) + if err != nil { + return stats, dbEngine, err + } + stats.DailyUsers, err = s.dailyUsers(ctx, txn) + if err != nil { + return stats, dbEngine, err + } + stats.MonthlyUsers, err = s.monthlyUsers(ctx, txn) + if err != nil { + return stats, dbEngine, err + } + stats.R30Users, err = s.r30Users(ctx, txn) + if err != nil { + return stats, dbEngine, err + } + stats.R30UsersV2, err = s.r30UsersV2(ctx, txn) + if err != nil { + return stats, dbEngine, err + } + stats.NonBridgedUsers, err = s.nonBridgedUsers(ctx, txn) + if err != nil { + return stats, dbEngine, err + } + stats.RegisteredUsersByType, err = s.registeredUserByType(ctx, txn) + if err != nil { + return stats, dbEngine, err + } + + stmt := sqlutil.TxStmt(txn, s.dbEngineVersionStmt) + err = stmt.QueryRowContext(ctx).Scan(&dbEngine.Version) + return stats, dbEngine, err +} + +func (s *statsStatements) UpdateUserDailyVisits( + ctx context.Context, txn *sql.Tx, + startTime, lastUpdate time.Time, +) error { + stmt := sqlutil.TxStmt(txn, s.updateUserDailyVisitsStmt) + startTime = startTime.Truncate(time.Hour * 24) + + // edge case + if startTime.After(s.lastUpdate) { + startTime = startTime.AddDate(0, 0, -1) + } + _, err := stmt.ExecContext(ctx, + gomatrixserverlib.AsTimestamp(startTime), + gomatrixserverlib.AsTimestamp(lastUpdate), + gomatrixserverlib.AsTimestamp(time.Now()), + ) + if err == nil { + s.lastUpdate = time.Now() + } + return err +} diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 03c013f00..a822f687d 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/shared" @@ -31,8 +32,8 @@ import ( ) // NewDatabase creates a new accounts and profiles database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { - db, err := sqlutil.Open(dbProperties) +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { + db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) if err != nil { return nil, err } @@ -94,6 +95,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err != nil { return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err) } + statsTable, err := NewSQLiteStatsTable(db, serverName) + if err != nil { + return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err) + } return &shared.Database{ AccountDatas: accountDataTable, Accounts: accountsTable, @@ -106,9 +111,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver ThreePIDs: threePIDTable, Pushers: pusherTable, Notifications: notificationsTable, + Stats: statsTable, ServerName: serverName, DB: db, - Writer: sqlutil.NewExclusiveWriter(), + Writer: writer, LoginTokenLifetime: loginTokenLifetime, BcryptCost: bcryptCost, OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, diff --git a/userapi/storage/storage.go b/userapi/storage/storage.go index faf1ce75c..42221e752 100644 --- a/userapi/storage/storage.go +++ b/userapi/storage/storage.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/postgres" "github.com/matrix-org/dendrite/userapi/storage/sqlite3" @@ -30,12 +31,12 @@ import ( // NewUserAPIDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // and sets postgres connection parameters -func NewUserAPIDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) { +func NewUserAPIDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) + return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) + return postgres.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 2eb57d0bc..5683fe067 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -29,7 +29,7 @@ var ( func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserAPIDatabase(&config.DatabaseOptions{ + db, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server") if err != nil { @@ -168,13 +168,13 @@ func Test_Devices(t *testing.T) { devices2, err := db.GetDevicesByID(ctx, deviceIDs) assert.NoError(t, err, "unable to get devices by id") - assert.Equal(t, devices, devices2) + assert.ElementsMatch(t, devices, devices2) // Update device newName := "new display name" err = db.UpdateDevice(ctx, localpart, deviceWithID.ID, &newName) assert.NoError(t, err, "unable to update device displayname") - err = db.UpdateDeviceLastSeen(ctx, localpart, deviceWithID.ID, "127.0.0.1") + err = db.UpdateDeviceLastSeen(ctx, localpart, deviceWithID.ID, "127.0.0.1", "Element Web") assert.NoError(t, err, "unable to update device last seen") deviceWithID.DisplayName = newName diff --git a/userapi/storage/storage_wasm.go b/userapi/storage/storage_wasm.go index a8e6f031c..5d5d292e6 100644 --- a/userapi/storage/storage_wasm.go +++ b/userapi/storage/storage_wasm.go @@ -18,12 +18,14 @@ import ( "fmt" "time" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/sqlite3" "github.com/matrix-org/gomatrixserverlib" ) func NewUserAPIDatabase( + base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, @@ -33,7 +35,7 @@ func NewUserAPIDatabase( ) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) + return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index eb0cae314..2fe955670 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -18,9 +18,11 @@ import ( "context" "database/sql" "encoding/json" + "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/types" ) type AccountDataTable interface { @@ -48,7 +50,7 @@ type DevicesTable interface { SelectDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) ([]api.Device, error) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) - UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error + UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error } type KeyBackupTable interface { @@ -111,6 +113,11 @@ type NotificationTable interface { SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) } +type StatsTable interface { + UserStatistics(ctx context.Context, txn *sql.Tx) (*types.UserStatistics, *types.DatabaseEngine, error) + UpdateUserDailyVisits(ctx context.Context, txn *sql.Tx, startTime, lastUpdate time.Time) error +} + type NotificationFilter uint32 const ( diff --git a/userapi/storage/tables/stats_table_test.go b/userapi/storage/tables/stats_table_test.go new file mode 100644 index 000000000..11521c8b0 --- /dev/null +++ b/userapi/storage/tables/stats_table_test.go @@ -0,0 +1,319 @@ +package tables_test + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "testing" + "time" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/postgres" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +func mustMakeDBs(t *testing.T, dbType test.DBType) ( + *sql.DB, tables.AccountsTable, tables.DevicesTable, tables.StatsTable, func(), +) { + t.Helper() + + var ( + accTable tables.AccountsTable + devTable tables.DevicesTable + statsTable tables.StatsTable + err error + ) + + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, nil) + if err != nil { + t.Fatalf("failed to open db: %s", err) + } + + switch dbType { + case test.DBTypeSQLite: + accTable, err = sqlite3.NewSQLiteAccountsTable(db, "localhost") + if err != nil { + t.Fatalf("unable to create acc db: %v", err) + } + devTable, err = sqlite3.NewSQLiteDevicesTable(db, "localhost") + if err != nil { + t.Fatalf("unable to open device db: %v", err) + } + statsTable, err = sqlite3.NewSQLiteStatsTable(db, "localhost") + if err != nil { + t.Fatalf("unable to open stats db: %v", err) + } + case test.DBTypePostgres: + accTable, err = postgres.NewPostgresAccountsTable(db, "localhost") + if err != nil { + t.Fatalf("unable to create acc db: %v", err) + } + devTable, err = postgres.NewPostgresDevicesTable(db, "localhost") + if err != nil { + t.Fatalf("unable to open device db: %v", err) + } + statsTable, err = postgres.NewPostgresStatsTable(db, "localhost") + if err != nil { + t.Fatalf("unable to open stats db: %v", err) + } + } + + return db, accTable, devTable, statsTable, close +} + +func mustMakeAccountAndDevice( + t *testing.T, + ctx context.Context, + accDB tables.AccountsTable, + devDB tables.DevicesTable, + localpart string, + accType api.AccountType, + userAgent string, +) { + t.Helper() + + appServiceID := "" + if accType == api.AccountTypeAppService { + appServiceID = util.RandomString(16) + } + + _, err := accDB.InsertAccount(ctx, nil, localpart, "", appServiceID, accType) + if err != nil { + t.Fatalf("unable to create account: %v", err) + } + _, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, util.RandomString(16), nil, "", userAgent) + if err != nil { + t.Fatalf("unable to create device: %v", err) + } +} + +func mustUpdateDeviceLastSeen( + t *testing.T, + ctx context.Context, + db *sql.DB, + localpart string, + timestamp time.Time, +) { + t.Helper() + _, err := db.ExecContext(ctx, "UPDATE device_devices SET last_seen_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart) + if err != nil { + t.Fatalf("unable to update device last seen") + } +} + +func mustUserUpdateRegistered( + t *testing.T, + ctx context.Context, + db *sql.DB, + localpart string, + timestamp time.Time, +) { + _, err := db.ExecContext(ctx, "UPDATE account_accounts SET created_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart) + if err != nil { + t.Fatalf("unable to update device last seen") + } +} + +// These tests must run sequentially, as they build up on each other +func Test_UserStatistics(t *testing.T) { + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, accDB, devDB, statsDB, close := mustMakeDBs(t, dbType) + defer close() + wantType := "SQLite" + if dbType == test.DBTypePostgres { + wantType = "Postgres" + } + + t.Run(fmt.Sprintf("want %s database engine", wantType), func(t *testing.T) { + _, gotDB, err := statsDB.UserStatistics(ctx, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if wantType != gotDB.Engine { // can't use DeepEqual, as the Version might differ + t.Errorf("UserStatistics() got DB engine = %+v, want %s", gotDB.Engine, wantType) + } + }) + + t.Run("Want Users", func(t *testing.T) { + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", api.AccountTypeUser, "Element Android") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", api.AccountTypeUser, "Element iOS") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", api.AccountTypeUser, "Element web") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", api.AccountTypeGuest, "Element Electron") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", api.AccountTypeAdmin, "gecko") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", api.AccountTypeAppService, "gecko") + gotStats, _, err := statsDB.UserStatistics(ctx, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + wantStats := &types.UserStatistics{ + RegisteredUsersByType: map[string]int64{ + "native": 4, + "guest": 1, + "bridged": 1, + }, + R30Users: map[string]int64{}, + R30UsersV2: map[string]int64{ + "ios": 0, + "android": 0, + "web": 0, + "electron": 0, + "all": 0, + }, + AllUsers: 6, + NonBridgedUsers: 5, + DailyUsers: 6, + MonthlyUsers: 6, + } + if !reflect.DeepEqual(gotStats, wantStats) { + t.Errorf("UserStatistics() gotStats = \n%+v\nwant\n%+v", gotStats, wantStats) + } + }) + + t.Run("Users not active for one/two month", func(t *testing.T) { + mustUpdateDeviceLastSeen(t, ctx, db, "user1", time.Now().AddDate(0, -2, 0)) + mustUpdateDeviceLastSeen(t, ctx, db, "user2", time.Now().AddDate(0, -1, 0)) + gotStats, _, err := statsDB.UserStatistics(ctx, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + wantStats := &types.UserStatistics{ + RegisteredUsersByType: map[string]int64{ + "native": 4, + "guest": 1, + "bridged": 1, + }, + R30Users: map[string]int64{}, + R30UsersV2: map[string]int64{ + "ios": 0, + "android": 0, + "web": 0, + "electron": 0, + "all": 0, + }, + AllUsers: 6, + NonBridgedUsers: 5, + DailyUsers: 4, + MonthlyUsers: 4, + } + if !reflect.DeepEqual(gotStats, wantStats) { + t.Errorf("UserStatistics() gotStats = \n%+v\nwant\n%+v", gotStats, wantStats) + } + }) + + /* R30Users counts the number of 30 day retained users, defined as: + - Users who have created their accounts more than 30 days ago + - Where last seen at most 30 days ago + - Where account creation and last_seen are > 30 days apart + */ + t.Run("R30Users tests", func(t *testing.T) { + mustUserUpdateRegistered(t, ctx, db, "user1", time.Now().AddDate(0, -2, 0)) + mustUpdateDeviceLastSeen(t, ctx, db, "user1", time.Now()) + mustUserUpdateRegistered(t, ctx, db, "user4", time.Now().AddDate(0, -2, 0)) + mustUpdateDeviceLastSeen(t, ctx, db, "user4", time.Now()) + startTime := time.Now().AddDate(0, 0, -2) + err := statsDB.UpdateUserDailyVisits(ctx, nil, startTime, startTime.Truncate(time.Hour*24).Add(time.Hour)) + if err != nil { + t.Fatalf("unable to update daily visits stats: %v", err) + } + + gotStats, _, err := statsDB.UserStatistics(ctx, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + wantStats := &types.UserStatistics{ + RegisteredUsersByType: map[string]int64{ + "native": 3, + "bridged": 1, + }, + R30Users: map[string]int64{ + "all": 2, + "android": 1, + "electron": 1, + }, + R30UsersV2: map[string]int64{ + "ios": 0, + "android": 0, + "web": 0, + "electron": 0, + "all": 0, + }, + AllUsers: 6, + NonBridgedUsers: 5, + DailyUsers: 5, + MonthlyUsers: 5, + } + if !reflect.DeepEqual(gotStats, wantStats) { + t.Errorf("UserStatistics() gotStats = \n%+v\nwant\n%+v", gotStats, wantStats) + } + }) + + /* + R30UsersV2 counts the number of 30 day retained users, defined as users that: + - Appear more than once in the past 60 days + - Have more than 30 days between the most and least recent appearances that occurred in the past 60 days. + most recent -> neueste + least recent -> älteste + + */ + t.Run("R30UsersV2 tests", func(t *testing.T) { + // generate some data + for i := 100; i > 0; i-- { + mustUpdateDeviceLastSeen(t, ctx, db, "user1", time.Now().AddDate(0, 0, -i)) + mustUpdateDeviceLastSeen(t, ctx, db, "user5", time.Now().AddDate(0, 0, -i)) + startTime := time.Now().AddDate(0, 0, -i) + err := statsDB.UpdateUserDailyVisits(ctx, nil, startTime, startTime.Truncate(time.Hour*24).Add(time.Hour)) + if err != nil { + t.Fatalf("unable to update daily visits stats: %v", err) + } + } + gotStats, _, err := statsDB.UserStatistics(ctx, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + wantStats := &types.UserStatistics{ + RegisteredUsersByType: map[string]int64{ + "native": 3, + "bridged": 1, + }, + R30Users: map[string]int64{ + "all": 2, + "android": 1, + "electron": 1, + }, + R30UsersV2: map[string]int64{ + "ios": 0, + "android": 1, + "web": 1, + "electron": 0, + "all": 2, + }, + AllUsers: 6, + NonBridgedUsers: 5, + DailyUsers: 3, + MonthlyUsers: 5, + } + if !reflect.DeepEqual(gotStats, wantStats) { + t.Errorf("UserStatistics() gotStats = \n%+v\nwant\n%+v", gotStats, wantStats) + } + }) + }) + +} diff --git a/userapi/types/statistics.go b/userapi/types/statistics.go new file mode 100644 index 000000000..09564f78f --- /dev/null +++ b/userapi/types/statistics.go @@ -0,0 +1,30 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +type UserStatistics struct { + RegisteredUsersByType map[string]int64 + R30Users map[string]int64 + R30UsersV2 map[string]int64 + AllUsers int64 + NonBridgedUsers int64 + DailyUsers int64 + MonthlyUsers int64 +} + +type DatabaseEngine struct { + Engine string + Version string +} diff --git a/userapi/userapi.go b/userapi/userapi.go index e91ce3a7a..03a46807f 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -30,6 +30,7 @@ import ( "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/dendrite/userapi/util" "github.com/sirupsen/logrus" ) @@ -42,12 +43,25 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - base *base.BaseDendrite, db storage.Database, cfg *config.UserAPI, - appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, - rsAPI rsapi.RoomserverInternalAPI, pgClient pushgateway.Client, + base *base.BaseDendrite, cfg *config.UserAPI, + appServices []config.ApplicationService, keyAPI keyapi.UserKeyAPI, + rsAPI rsapi.UserRoomserverAPI, pgClient pushgateway.Client, ) api.UserInternalAPI { js, _ := jetstream.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + db, err := storage.NewUserAPIDatabase( + base, + &cfg.AccountDatabase, + cfg.Matrix.ServerName, + cfg.BCryptCost, + cfg.OpenIDTokenLifetimeMS, + api.DefaultLoginTokenLifetime, + cfg.Matrix.ServerNotices.LocalPart, + ) + if err != nil { + logrus.WithError(err).Panicf("failed to connect to accounts db") + } + syncProducer := producers.NewSyncAPI( db, js, // TODO: user API should handle syncs for account data. Right now, @@ -91,5 +105,9 @@ func NewInternalAPI( } time.AfterFunc(time.Minute, cleanOldNotifs) + if base.Cfg.Global.ReportStats.Enabled { + go util.StartPhoneHomeCollector(time.Now(), base.Cfg, db) + } + return userAPI } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 076b4f3c6..e614765a2 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package userapi +package userapi_test import ( "context" @@ -23,15 +23,17 @@ import ( "time" "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" + internalTest "github.com/matrix-org/dendrite/internal/test" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/userapi" + "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/gomatrixserverlib" "golang.org/x/crypto/bcrypt" - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/internal" - "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/storage" ) @@ -43,16 +45,15 @@ type apiTestOpts struct { loginTokenLifetime time.Duration } -func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, storage.Database) { +func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.Database, func()) { if opts.loginTokenLifetime == 0 { opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond } - dbopts := &config.DatabaseOptions{ - ConnectionString: "file::memory:", - MaxOpenConnections: 1, - MaxIdleConnections: 1, - } - accountDB, err := storage.NewUserAPIDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") + connStr, close := test.PrepareDBConnectionString(t, dbType) + + accountDB, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") if err != nil { t.Fatalf("failed to create account DB: %s", err) } @@ -66,13 +67,15 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, s return &internal.UserInternalAPI{ DB: accountDB, ServerName: cfg.Matrix.ServerName, - }, accountDB + }, accountDB, close } func TestQueryProfile(t *testing.T) { aliceAvatarURL := "mxc://example.com/alice" aliceDisplayName := "Alice" - userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{}) + // only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added + userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite) + defer close() _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", api.AccountTypeUser) if err != nil { t.Fatalf("failed to make account: %s", err) @@ -131,8 +134,8 @@ func TestQueryProfile(t *testing.T) { t.Run("HTTP API", func(t *testing.T) { router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() - AddInternalRoutes(router, userAPI) - apiURL, cancel := test.ListenAndServe(t, router, false) + userapi.AddInternalRoutes(router, userAPI) + apiURL, cancel := internalTest.ListenAndServe(t, router, false) defer cancel() httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{}) if err != nil { @@ -149,110 +152,120 @@ func TestLoginToken(t *testing.T) { ctx := context.Background() t.Run("tokenLoginFlow", func(t *testing.T) { - userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{}) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + defer close() + _, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", api.AccountTypeUser) + if err != nil { + t.Fatalf("failed to make account: %s", err) + } - _, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", api.AccountTypeUser) - if err != nil { - t.Fatalf("failed to make account: %s", err) - } + t.Log("Creating a login token like the SSO callback would...") - t.Log("Creating a login token like the SSO callback would...") + creq := api.PerformLoginTokenCreationRequest{ + Data: api.LoginTokenData{UserID: "@auser:example.com"}, + } + var cresp api.PerformLoginTokenCreationResponse + if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil { + t.Fatalf("PerformLoginTokenCreation failed: %v", err) + } - creq := api.PerformLoginTokenCreationRequest{ - Data: api.LoginTokenData{UserID: "@auser:example.com"}, - } - var cresp api.PerformLoginTokenCreationResponse - if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil { - t.Fatalf("PerformLoginTokenCreation failed: %v", err) - } + if cresp.Metadata.Token == "" { + t.Errorf("PerformLoginTokenCreation Token: got %q, want non-empty", cresp.Metadata.Token) + } + if cresp.Metadata.Expiration.Before(time.Now()) { + t.Errorf("PerformLoginTokenCreation Expiration: got %v, want non-expired", cresp.Metadata.Expiration) + } - if cresp.Metadata.Token == "" { - t.Errorf("PerformLoginTokenCreation Token: got %q, want non-empty", cresp.Metadata.Token) - } - if cresp.Metadata.Expiration.Before(time.Now()) { - t.Errorf("PerformLoginTokenCreation Expiration: got %v, want non-expired", cresp.Metadata.Expiration) - } + t.Log("Querying the login token like /login with m.login.token would...") - t.Log("Querying the login token like /login with m.login.token would...") + qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token} + var qresp api.QueryLoginTokenResponse + if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil { + t.Fatalf("QueryLoginToken failed: %v", err) + } - qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token} - var qresp api.QueryLoginTokenResponse - if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil { - t.Fatalf("QueryLoginToken failed: %v", err) - } + if qresp.Data == nil { + t.Errorf("QueryLoginToken Data: got %v, want non-nil", qresp.Data) + } else if want := "@auser:example.com"; qresp.Data.UserID != want { + t.Errorf("QueryLoginToken UserID: got %q, want %q", qresp.Data.UserID, want) + } - if qresp.Data == nil { - t.Errorf("QueryLoginToken Data: got %v, want non-nil", qresp.Data) - } else if want := "@auser:example.com"; qresp.Data.UserID != want { - t.Errorf("QueryLoginToken UserID: got %q, want %q", qresp.Data.UserID, want) - } + t.Log("Deleting the login token like /login with m.login.token would...") - t.Log("Deleting the login token like /login with m.login.token would...") - - dreq := api.PerformLoginTokenDeletionRequest{Token: cresp.Metadata.Token} - var dresp api.PerformLoginTokenDeletionResponse - if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil { - t.Fatalf("PerformLoginTokenDeletion failed: %v", err) - } + dreq := api.PerformLoginTokenDeletionRequest{Token: cresp.Metadata.Token} + var dresp api.PerformLoginTokenDeletionResponse + if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil { + t.Fatalf("PerformLoginTokenDeletion failed: %v", err) + } + }) }) t.Run("expiredTokenIsNotReturned", func(t *testing.T) { - userAPI, _ := MustMakeInternalAPI(t, apiTestOpts{loginTokenLifetime: -1 * time.Second}) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{loginTokenLifetime: -1 * time.Second}, dbType) + defer close() - creq := api.PerformLoginTokenCreationRequest{ - Data: api.LoginTokenData{UserID: "@auser:example.com"}, - } - var cresp api.PerformLoginTokenCreationResponse - if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil { - t.Fatalf("PerformLoginTokenCreation failed: %v", err) - } + creq := api.PerformLoginTokenCreationRequest{ + Data: api.LoginTokenData{UserID: "@auser:example.com"}, + } + var cresp api.PerformLoginTokenCreationResponse + if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil { + t.Fatalf("PerformLoginTokenCreation failed: %v", err) + } - qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token} - var qresp api.QueryLoginTokenResponse - if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil { - t.Fatalf("QueryLoginToken failed: %v", err) - } + qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token} + var qresp api.QueryLoginTokenResponse + if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil { + t.Fatalf("QueryLoginToken failed: %v", err) + } - if qresp.Data != nil { - t.Errorf("QueryLoginToken Data: got %v, want nil", qresp.Data) - } + if qresp.Data != nil { + t.Errorf("QueryLoginToken Data: got %v, want nil", qresp.Data) + } + }) }) t.Run("deleteWorks", func(t *testing.T) { - userAPI, _ := MustMakeInternalAPI(t, apiTestOpts{}) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + defer close() - creq := api.PerformLoginTokenCreationRequest{ - Data: api.LoginTokenData{UserID: "@auser:example.com"}, - } - var cresp api.PerformLoginTokenCreationResponse - if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil { - t.Fatalf("PerformLoginTokenCreation failed: %v", err) - } + creq := api.PerformLoginTokenCreationRequest{ + Data: api.LoginTokenData{UserID: "@auser:example.com"}, + } + var cresp api.PerformLoginTokenCreationResponse + if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil { + t.Fatalf("PerformLoginTokenCreation failed: %v", err) + } - dreq := api.PerformLoginTokenDeletionRequest{Token: cresp.Metadata.Token} - var dresp api.PerformLoginTokenDeletionResponse - if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil { - t.Fatalf("PerformLoginTokenDeletion failed: %v", err) - } + dreq := api.PerformLoginTokenDeletionRequest{Token: cresp.Metadata.Token} + var dresp api.PerformLoginTokenDeletionResponse + if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil { + t.Fatalf("PerformLoginTokenDeletion failed: %v", err) + } - qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token} - var qresp api.QueryLoginTokenResponse - if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil { - t.Fatalf("QueryLoginToken failed: %v", err) - } + qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token} + var qresp api.QueryLoginTokenResponse + if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil { + t.Fatalf("QueryLoginToken failed: %v", err) + } - if qresp.Data != nil { - t.Errorf("QueryLoginToken Data: got %v, want nil", qresp.Data) - } + if qresp.Data != nil { + t.Errorf("QueryLoginToken Data: got %v, want nil", qresp.Data) + } + }) }) t.Run("deleteUnknownIsNoOp", func(t *testing.T) { - userAPI, _ := MustMakeInternalAPI(t, apiTestOpts{}) - - dreq := api.PerformLoginTokenDeletionRequest{Token: "non-existent token"} - var dresp api.PerformLoginTokenDeletionResponse - if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil { - t.Fatalf("PerformLoginTokenDeletion failed: %v", err) - } + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + defer close() + dreq := api.PerformLoginTokenDeletionRequest{Token: "non-existent token"} + var dresp api.PerformLoginTokenDeletionResponse + if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil { + t.Fatalf("PerformLoginTokenDeletion failed: %v", err) + } + }) }) } diff --git a/userapi/util/phonehomestats.go b/userapi/util/phonehomestats.go new file mode 100644 index 000000000..e24daba6b --- /dev/null +++ b/userapi/util/phonehomestats.go @@ -0,0 +1,158 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "bytes" + "context" + "encoding/json" + "math" + "net/http" + "runtime" + "syscall" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +type phoneHomeStats struct { + prevData timestampToRUUsage + stats map[string]interface{} + serverName gomatrixserverlib.ServerName + startTime time.Time + cfg *config.Dendrite + db storage.Statistics + isMonolith bool + client *gomatrixserverlib.Client +} + +type timestampToRUUsage struct { + timestamp int64 + usage syscall.Rusage +} + +func StartPhoneHomeCollector(startTime time.Time, cfg *config.Dendrite, statsDB storage.Statistics) { + + p := phoneHomeStats{ + startTime: startTime, + serverName: cfg.Global.ServerName, + cfg: cfg, + db: statsDB, + isMonolith: cfg.IsMonolith, + client: gomatrixserverlib.NewClient( + gomatrixserverlib.WithTimeout(time.Second * 30), + ), + } + + // start initial run after 5min + time.AfterFunc(time.Minute*5, p.collect) + + // run every 3 hours + ticker := time.NewTicker(time.Hour * 3) + for range ticker.C { + p.collect() + } +} + +func (p *phoneHomeStats) collect() { + p.stats = make(map[string]interface{}) + // general information + p.stats["homeserver"] = p.serverName + p.stats["monolith"] = p.isMonolith + p.stats["version"] = internal.VersionString() + p.stats["timestamp"] = time.Now().Unix() + p.stats["go_version"] = runtime.Version() + p.stats["go_arch"] = runtime.GOARCH + p.stats["go_os"] = runtime.GOOS + p.stats["num_cpu"] = runtime.NumCPU() + p.stats["num_go_routine"] = runtime.NumGoroutine() + p.stats["uptime_seconds"] = math.Floor(time.Since(p.startTime).Seconds()) + + ctx, cancel := context.WithTimeout(context.TODO(), time.Minute*1) + defer cancel() + + // cpu and memory usage information + err := getMemoryStats(p) + if err != nil { + logrus.WithError(err).Warn("unable to get memory/cpu stats, using defaults") + } + + // configuration information + p.stats["federation_disabled"] = p.cfg.Global.DisableFederation + p.stats["nats_embedded"] = true + p.stats["nats_in_memory"] = p.cfg.Global.JetStream.InMemory + if len(p.cfg.Global.JetStream.Addresses) > 0 { + p.stats["nats_embedded"] = false + p.stats["nats_in_memory"] = false // probably + } + if len(p.cfg.Logging) > 0 { + p.stats["log_level"] = p.cfg.Logging[0].Level + } else { + p.stats["log_level"] = "info" + } + + // message and room stats + // TODO: Find a solution to actually set these values + p.stats["total_room_count"] = 0 + p.stats["daily_messages"] = 0 + p.stats["daily_sent_messages"] = 0 + p.stats["daily_e2ee_messages"] = 0 + p.stats["daily_sent_e2ee_messages"] = 0 + + // user stats and DB engine + userStats, db, err := p.db.UserStatistics(ctx) + if err != nil { + logrus.WithError(err).Warn("unable to query userstats, using default values") + } + p.stats["database_engine"] = db.Engine + p.stats["database_server_version"] = db.Version + p.stats["total_users"] = userStats.AllUsers + p.stats["total_nonbridged_users"] = userStats.NonBridgedUsers + p.stats["daily_active_users"] = userStats.DailyUsers + p.stats["monthly_active_users"] = userStats.MonthlyUsers + for t, c := range userStats.RegisteredUsersByType { + p.stats["daily_user_type_"+t] = c + } + for t, c := range userStats.R30Users { + p.stats["r30_users_"+t] = c + } + for t, c := range userStats.R30UsersV2 { + p.stats["r30v2_users_"+t] = c + } + + output := bytes.Buffer{} + if err = json.NewEncoder(&output).Encode(p.stats); err != nil { + logrus.WithError(err).Error("unable to encode anonymous stats") + return + } + + logrus.Infof("Reporting stats to %s: %s", p.cfg.Global.ReportStats.Endpoint, output.String()) + + request, err := http.NewRequestWithContext(ctx, http.MethodPost, p.cfg.Global.ReportStats.Endpoint, &output) + if err != nil { + logrus.WithError(err).Error("unable to create anonymous stats request") + return + } + request.Header.Set("User-Agent", "Dendrite/"+internal.VersionString()) + + if _, err = p.client.DoHTTPRequest(ctx, request); err != nil { + logrus.WithError(err).Error("unable to send anonymous stats") + return + } +} diff --git a/userapi/util/stats.go b/userapi/util/stats.go new file mode 100644 index 000000000..22ef12aad --- /dev/null +++ b/userapi/util/stats.go @@ -0,0 +1,47 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !wasm && !windows +// +build !wasm,!windows + +package util + +import ( + "syscall" + "time" + + "github.com/sirupsen/logrus" +) + +func getMemoryStats(p *phoneHomeStats) error { + oldUsage := p.prevData + newUsage := syscall.Rusage{} + if err := syscall.Getrusage(syscall.RUSAGE_SELF, &newUsage); err != nil { + logrus.WithError(err).Error("unable to get usage") + return err + } + newData := timestampToRUUsage{timestamp: time.Now().Unix(), usage: newUsage} + p.prevData = newData + + usedCPUTime := (newUsage.Utime.Sec + newUsage.Stime.Sec) - (oldUsage.usage.Utime.Sec + oldUsage.usage.Stime.Sec) + + if usedCPUTime == 0 || newData.timestamp == oldUsage.timestamp { + p.stats["cpu_average"] = 0 + } else { + // conversion to int64 required for GOARCH=386 + p.stats["cpu_average"] = int64(usedCPUTime) / (newData.timestamp - oldUsage.timestamp) * 100 + } + p.stats["memory_rss"] = newUsage.Maxrss + return nil +} diff --git a/userapi/util/stats_wasm.go b/userapi/util/stats_wasm.go new file mode 100644 index 000000000..a182e4e6e --- /dev/null +++ b/userapi/util/stats_wasm.go @@ -0,0 +1,20 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +// stub, since WASM doesn't support syscall.Getrusage +func getMemoryStats(p *phoneHomeStats) error { + return nil +} diff --git a/userapi/util/stats_windows.go b/userapi/util/stats_windows.go new file mode 100644 index 000000000..0b3f8d013 --- /dev/null +++ b/userapi/util/stats_windows.go @@ -0,0 +1,29 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !wasm +// +build !wasm + +package util + +import ( + "runtime" +) + +func getMemoryStats(p *phoneHomeStats) error { + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + p.stats["memory_rss"] = memStats.Alloc + return nil +}