diff --git a/appservice/api/query.go b/appservice/api/query.go index 6fcb20890..29e374aca 100644 --- a/appservice/api/query.go +++ b/appservice/api/query.go @@ -22,10 +22,9 @@ import ( "database/sql" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" - - "github.com/matrix-org/dendrite/internal" ) // RoomAliasExistsRequest is a request to an application service @@ -110,7 +109,7 @@ func RetrieveUserProfile( // If no user exists, return if !userResp.UserIDExists { - return nil, internal.ErrProfileNoExists + return nil, eventutil.ErrProfileNoExists } // Try to query the user from the local database again diff --git a/appservice/appservice.go b/appservice/appservice.go index b7adf755b..728690414 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -16,7 +16,6 @@ package appservice import ( "context" - "errors" "net/http" "sync" "time" @@ -29,12 +28,10 @@ import ( "github.com/matrix-org/dendrite/appservice/storage" "github.com/matrix-org/dendrite/appservice/types" "github.com/matrix-org/dendrite/appservice/workers" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/basecomponent" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/setup" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/sirupsen/logrus" ) @@ -46,9 +43,8 @@ func AddInternalRoutes(router *mux.Router, queryAPI appserviceAPI.AppServiceQuer // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - base *basecomponent.BaseDendrite, - accountsDB accounts.Database, - deviceDB devices.Database, + base *setup.BaseDendrite, + userAPI userapi.UserInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI, ) appserviceAPI.AppServiceQueryAPI { // Create a connection to the appservice postgres DB @@ -70,7 +66,7 @@ func NewInternalAPI( workerStates[i] = ws // Create bot account for this AS if it doesn't already exist - if err = generateAppServiceAccount(accountsDB, deviceDB, appservice); err != nil { + if err = generateAppServiceAccount(userAPI, appservice); err != nil { logrus.WithFields(logrus.Fields{ "appservice": appservice.ID, }).WithError(err).Panicf("failed to generate bot account for appservice") @@ -86,12 +82,16 @@ func NewInternalAPI( Cfg: base.Cfg, } - consumer := consumers.NewOutputRoomEventConsumer( - base.Cfg, base.KafkaConsumer, accountsDB, appserviceDB, - rsAPI, workerStates, - ) - if err := consumer.Start(); err != nil { - logrus.WithError(err).Panicf("failed to start appservice roomserver consumer") + // Only consume if we actually have ASes to track, else we'll just chew cycles needlessly. + // We can't add ASes at runtime so this is safe to do. + if len(workerStates) > 0 { + consumer := consumers.NewOutputRoomEventConsumer( + base.Cfg, base.KafkaConsumer, appserviceDB, + rsAPI, workerStates, + ) + if err := consumer.Start(); err != nil { + logrus.WithError(err).Panicf("failed to start appservice roomserver consumer") + } } // Create application service transaction workers @@ -105,22 +105,25 @@ func NewInternalAPI( // `sender_localpart` field of each application service if it doesn't // exist already func generateAppServiceAccount( - accountsDB accounts.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, as config.ApplicationService, ) error { - ctx := context.Background() - - // Create an account for the application service - _, err := accountsDB.CreateAccount(ctx, as.SenderLocalpart, "", as.ID) + var accRes userapi.PerformAccountCreationResponse + err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{ + AccountType: userapi.AccountTypeUser, + Localpart: as.SenderLocalpart, + AppServiceID: as.ID, + OnConflict: userapi.ConflictUpdate, + }, &accRes) if err != nil { - if errors.Is(err, internal.ErrUserExists) { // This account already exists - return nil - } return err } - - // Create a dummy device with a dummy token for the application service - _, err = deviceDB.CreateDevice(ctx, as.SenderLocalpart, nil, as.ASToken, &as.SenderLocalpart) + var devRes userapi.PerformDeviceCreationResponse + err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{ + Localpart: as.SenderLocalpart, + AccessToken: as.ASToken, + DeviceID: &as.SenderLocalpart, + DeviceDisplayName: &as.SenderLocalpart, + }, &devRes) return err } diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index bb4df7906..4c0156b2c 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -20,7 +20,6 @@ import ( "github.com/matrix-org/dendrite/appservice/storage" "github.com/matrix-org/dendrite/appservice/types" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" @@ -33,7 +32,6 @@ import ( // OutputRoomEventConsumer consumes events that originated in the room server. type OutputRoomEventConsumer struct { roomServerConsumer *internal.ContinualConsumer - db accounts.Database asDB storage.Database rsAPI api.RoomserverInternalAPI serverName string @@ -45,7 +43,6 @@ type OutputRoomEventConsumer struct { func NewOutputRoomEventConsumer( cfg *config.Dendrite, kafkaConsumer sarama.Consumer, - store accounts.Database, appserviceDB storage.Database, rsAPI api.RoomserverInternalAPI, workerStates []types.ApplicationServiceWorkerState, @@ -53,11 +50,10 @@ func NewOutputRoomEventConsumer( consumer := internal.ContinualConsumer{ Topic: string(cfg.Kafka.Topics.OutputRoomEvent), Consumer: kafkaConsumer, - PartitionStore: store, + PartitionStore: appserviceDB, } s := &OutputRoomEventConsumer{ roomServerConsumer: &consumer, - db: store, asDB: appserviceDB, rsAPI: rsAPI, serverName: string(cfg.Matrix.ServerName), @@ -91,60 +87,13 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { return nil } - ev := output.NewRoomEvent.Event - log.WithFields(log.Fields{ - "event_id": ev.EventID(), - "room_id": ev.RoomID(), - "type": ev.Type(), - }).Info("appservice received an event from roomserver") - - missingEvents, err := s.lookupMissingStateEvents(output.NewRoomEvent.AddsStateEventIDs, ev) - if err != nil { - return err - } - events := append(missingEvents, ev) + events := []gomatrixserverlib.HeaderedEvent{output.NewRoomEvent.Event} + events = append(events, output.NewRoomEvent.AddStateEvents...) // Send event to any relevant application services return s.filterRoomserverEvents(context.TODO(), events) } -// lookupMissingStateEvents looks up the state events that are added by a new event, -// and returns any not already present. -func (s *OutputRoomEventConsumer) lookupMissingStateEvents( - addsStateEventIDs []string, event gomatrixserverlib.HeaderedEvent, -) ([]gomatrixserverlib.HeaderedEvent, error) { - // Fast path if there aren't any new state events. - if len(addsStateEventIDs) == 0 { - return []gomatrixserverlib.HeaderedEvent{}, nil - } - - // Fast path if the only state event added is the event itself. - if len(addsStateEventIDs) == 1 && addsStateEventIDs[0] == event.EventID() { - return []gomatrixserverlib.HeaderedEvent{}, nil - } - - result := []gomatrixserverlib.HeaderedEvent{} - missing := []string{} - for _, id := range addsStateEventIDs { - if id != event.EventID() { - // If the event isn't the current one, add it to the list of events - // to retrieve from the roomserver - missing = append(missing, id) - } - } - - // Request the missing events from the roomserver - eventReq := api.QueryEventsByIDRequest{EventIDs: missing} - var eventResp api.QueryEventsByIDResponse - if err := s.rsAPI.QueryEventsByID(context.TODO(), &eventReq, &eventResp); err != nil { - return nil, err - } - - result = append(result, eventResp.Events...) - - return result, nil -} - // filterRoomserverEvents takes in events and decides whether any of them need // to be passed on to an external application service. It does this by checking // each namespace of each registered application service, and if there is a diff --git a/appservice/inthttp/client.go b/appservice/inthttp/client.go index acbd1211b..7e3cb208f 100644 --- a/appservice/inthttp/client.go +++ b/appservice/inthttp/client.go @@ -6,7 +6,7 @@ import ( "net/http" "github.com/matrix-org/dendrite/appservice/api" - internalHTTP "github.com/matrix-org/dendrite/internal/http" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/opentracing/opentracing-go" ) @@ -46,7 +46,7 @@ func (h *httpAppServiceQueryAPI) RoomAliasExists( defer span.Finish() apiURL := h.appserviceURL + AppServiceRoomAliasExistsPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } // UserIDExists implements AppServiceQueryAPI @@ -59,5 +59,5 @@ func (h *httpAppServiceQueryAPI) UserIDExists( defer span.Finish() apiURL := h.appserviceURL + AppServiceUserIDExistsPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } diff --git a/appservice/inthttp/server.go b/appservice/inthttp/server.go index 1900635a8..009b7b5db 100644 --- a/appservice/inthttp/server.go +++ b/appservice/inthttp/server.go @@ -6,7 +6,7 @@ import ( "github.com/gorilla/mux" "github.com/matrix-org/dendrite/appservice/api" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/util" ) @@ -14,7 +14,7 @@ import ( func AddRoutes(a api.AppServiceQueryAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( AppServiceRoomAliasExistsPath, - internal.MakeInternalAPI("appserviceRoomAliasExists", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("appserviceRoomAliasExists", func(req *http.Request) util.JSONResponse { var request api.RoomAliasExistsRequest var response api.RoomAliasExistsResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -28,7 +28,7 @@ func AddRoutes(a api.AppServiceQueryAPI, internalAPIMux *mux.Router) { ) internalAPIMux.Handle( AppServiceUserIDExistsPath, - internal.MakeInternalAPI("appserviceUserIDExists", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("appserviceUserIDExists", func(req *http.Request) util.JSONResponse { var request api.UserIDExistsRequest var response api.UserIDExistsResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { diff --git a/appservice/storage/interface.go b/appservice/storage/interface.go index 25d35af6c..735e2f90a 100644 --- a/appservice/storage/interface.go +++ b/appservice/storage/interface.go @@ -17,10 +17,12 @@ package storage import ( "context" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/gomatrixserverlib" ) type Database interface { + internal.PartitionStorer StoreEvent(ctx context.Context, appServiceID string, event *gomatrixserverlib.HeaderedEvent) error GetEventsWithAppServiceID(ctx context.Context, appServiceID string, limit int) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) CountEventsWithAppServiceID(ctx context.Context, appServiceID string) (int, error) diff --git a/appservice/storage/postgres/storage.go b/appservice/storage/postgres/storage.go index aeead6ed3..03f331d64 100644 --- a/appservice/storage/postgres/storage.go +++ b/appservice/storage/postgres/storage.go @@ -21,20 +21,20 @@ import ( // Import postgres database driver _ "github.com/lib/pq" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" ) // Database stores events intended to be later sent to application services type Database struct { + sqlutil.PartitionOffsetStatements events eventsStatements txnID txnStatements db *sql.DB } // NewDatabase opens a new database -func NewDatabase(dataSourceName string, dbProperties internal.DbProperties) (*Database, error) { +func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties) (*Database, error) { var result Database var err error if result.db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { @@ -43,6 +43,9 @@ func NewDatabase(dataSourceName string, dbProperties internal.DbProperties) (*Da if err = result.prepare(); err != nil { return nil, err } + if err = result.PartitionOffsetStatements.Prepare(result.db, "appservice"); err != nil { + return nil, err + } return &result, nil } diff --git a/appservice/storage/sqlite3/storage.go b/appservice/storage/sqlite3/storage.go index 2238b3ff9..cb55c8d94 100644 --- a/appservice/storage/sqlite3/storage.go +++ b/appservice/storage/sqlite3/storage.go @@ -20,7 +20,6 @@ import ( "database/sql" // Import SQLite database driver - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" _ "github.com/mattn/go-sqlite3" @@ -28,6 +27,7 @@ import ( // Database stores events intended to be later sent to application services type Database struct { + sqlutil.PartitionOffsetStatements events eventsStatements txnID txnStatements db *sql.DB @@ -41,12 +41,15 @@ func NewDatabase(dataSourceName string) (*Database, error) { if err != nil { return nil, err } - if result.db, err = sqlutil.Open(internal.SQLiteDriverName(), cs, nil); err != nil { + if result.db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil); err != nil { return nil, err } if err = result.prepare(); err != nil { return nil, err } + if err = result.PartitionOffsetStatements.Prepare(result.db, "appservice"); err != nil { + return nil, err + } return &result, nil } diff --git a/appservice/storage/storage.go b/appservice/storage/storage.go index 51ba1de53..c848d15d7 100644 --- a/appservice/storage/storage.go +++ b/appservice/storage/storage.go @@ -21,12 +21,12 @@ import ( "github.com/matrix-org/dendrite/appservice/storage/postgres" "github.com/matrix-org/dendrite/appservice/storage/sqlite3" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" ) // NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // and sets DB connection parameters -func NewDatabase(dataSourceName string, dbProperties internal.DbProperties) (Database, error) { +func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { return postgres.NewDatabase(dataSourceName, dbProperties) diff --git a/appservice/storage/storage_wasm.go b/appservice/storage/storage_wasm.go index a6144b435..1d6c4b4a9 100644 --- a/appservice/storage/storage_wasm.go +++ b/appservice/storage/storage_wasm.go @@ -19,12 +19,12 @@ import ( "net/url" "github.com/matrix-org/dendrite/appservice/storage/sqlite3" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" ) func NewDatabase( dataSourceName string, - dbProperties internal.DbProperties, // nolint:unparam + dbProperties sqlutil.DbProperties, // nolint:unparam ) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { diff --git a/are-we-synapse-yet.list b/are-we-synapse-yet.list index c088c8b5e..f59f80675 100644 --- a/are-we-synapse-yet.list +++ b/are-we-synapse-yet.list @@ -97,8 +97,8 @@ rst PUT power_levels should not explode if the old power levels were empty rst Both GET and PUT work rct POST /rooms/:room_id/receipt can create receipts red POST /rooms/:room_id/read_markers can create read marker -med POST /media/v1/upload can create an upload -med GET /media/v1/download can fetch the value again +med POST /media/r0/upload can create an upload +med GET /media/r0/download can fetch the value again cap GET /capabilities is present and well formed for registered user cap GET /r0/capabilities is not public reg Register with a recaptcha diff --git a/are-we-synapse-yet.py b/are-we-synapse-yet.py index 5d5128479..30979a129 100755 --- a/are-we-synapse-yet.py +++ b/are-we-synapse-yet.py @@ -33,6 +33,7 @@ import sys test_mappings = { "nsp": "Non-Spec API", + "unk": "Unknown API (no group specified)", "f": "Federation", # flag to mark test involves federation "federation_apis": { @@ -214,7 +215,8 @@ def main(results_tap_path, verbose): # } }, "nonspec": { - "nsp": {} + "nsp": {}, + "unk": {} }, } with open(results_tap_path, "r") as f: @@ -225,7 +227,7 @@ def main(results_tap_path, verbose): name = test_result["name"] group_id = test_name_to_group_id.get(name) if not group_id: - raise Exception("The test '%s' doesn't have a group" % (name,)) + summary["nonspec"]["unk"][name] = test_result["ok"] if group_id == "nsp": summary["nonspec"]["nsp"][name] = test_result["ok"] elif group_id in test_mappings["federation_apis"]: diff --git a/clientapi/auth/auth.go b/clientapi/auth/auth.go index 3482e5018..b4c39ae38 100644 --- a/clientapi/auth/auth.go +++ b/clientapi/auth/auth.go @@ -18,17 +18,13 @@ package auth import ( "context" "crypto/rand" - "database/sql" "encoding/base64" "fmt" "net/http" "strings" - "github.com/matrix-org/dendrite/appservice/types" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/clientapi/userutil" - "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" ) @@ -39,21 +35,13 @@ var tokenByteLength = 32 // DeviceDatabase represents a device database. type DeviceDatabase interface { // Look up the device matching the given access token. - GetDeviceByAccessToken(ctx context.Context, token string) (*authtypes.Device, error) + GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) } // AccountDatabase represents an account database. type AccountDatabase interface { // Look up the account matching the given localpart. - GetAccountByLocalpart(ctx context.Context, localpart string) (*authtypes.Account, error) -} - -// Data contains information required to authenticate a request. -type Data struct { - AccountDB AccountDatabase - DeviceDB DeviceDatabase - // AppServices is the list of all registered AS - AppServices []config.ApplicationService + GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) } // VerifyUserFromRequest authenticates the HTTP request, @@ -62,8 +50,8 @@ type Data struct { // Note: For an AS user, AS dummy device is returned. // On failure returns an JSON error response which can be sent to the client. func VerifyUserFromRequest( - req *http.Request, data Data, -) (*authtypes.Device, *util.JSONResponse) { + req *http.Request, userAPI api.UserInternalAPI, +) (*api.Device, *util.JSONResponse) { // Try to find the Application Service user token, err := ExtractAccessToken(req) if err != nil { @@ -72,105 +60,31 @@ func VerifyUserFromRequest( JSON: jsonerror.MissingToken(err.Error()), } } - - // Search for app service with given access_token - var appService *config.ApplicationService - for _, as := range data.AppServices { - if as.ASToken == token { - appService = &as - break - } + var res api.QueryAccessTokenResponse + err = userAPI.QueryAccessToken(req.Context(), &api.QueryAccessTokenRequest{ + AccessToken: token, + AppServiceUserID: req.URL.Query().Get("user_id"), + }, &res) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccessToken failed") + jsonErr := jsonerror.InternalServerError() + return nil, &jsonErr } - - if appService != nil { - // Create a dummy device for AS user - dev := authtypes.Device{ - // Use AS dummy device ID - ID: types.AppServiceDeviceID, - // AS dummy device has AS's token. - AccessToken: token, - } - - userID := req.URL.Query().Get("user_id") - localpart, err := userutil.ParseUsernameParam(userID, nil) - if err != nil { - return nil, &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername(err.Error()), - } - } - - if localpart != "" { // AS is masquerading as another user - // Verify that the user is registered - account, err := data.AccountDB.GetAccountByLocalpart(req.Context(), localpart) - // Verify that account exists & appServiceID matches - if err == nil && account.AppServiceID == appService.ID { - // Set the userID of dummy device - dev.UserID = userID - return &dev, nil - } - + if res.Err != nil { + if forbidden, ok := res.Err.(*api.ErrorForbidden); ok { return nil, &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Application service has not registered this user"), + JSON: jsonerror.Forbidden(forbidden.Message), } } - - // AS is not masquerading as any user, so use AS's sender_localpart - dev.UserID = appService.SenderLocalpart - return &dev, nil } - - // Try to find local user from device database - dev, devErr := verifyAccessToken(req, data.DeviceDB) - if devErr == nil { - return dev, verifyUserParameters(req) - } - - return nil, &util.JSONResponse{ - Code: http.StatusUnauthorized, - JSON: jsonerror.UnknownToken("Unrecognized access token"), // nolint: misspell - } -} - -// verifyUserParameters ensures that a request coming from a regular user is not -// using any query parameters reserved for an application service -func verifyUserParameters(req *http.Request) *util.JSONResponse { - if req.URL.Query().Get("ts") != "" { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("parameter 'ts' not allowed without valid parameter 'access_token'"), - } - } - return nil -} - -// verifyAccessToken verifies that an access token was supplied in the given HTTP request -// and returns the device it corresponds to. Returns resErr (an error response which can be -// sent to the client) if the token is invalid or there was a problem querying the database. -func verifyAccessToken(req *http.Request, deviceDB DeviceDatabase) (device *authtypes.Device, resErr *util.JSONResponse) { - token, err := ExtractAccessToken(req) - if err != nil { - resErr = &util.JSONResponse{ + if res.Device == nil { + return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.MissingToken(err.Error()), - } - return - } - device, err = deviceDB.GetDeviceByAccessToken(req.Context(), token) - if err != nil { - if err == sql.ErrNoRows { - resErr = &util.JSONResponse{ - Code: http.StatusUnauthorized, - JSON: jsonerror.UnknownToken("Unknown token"), - } - } else { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDeviceByAccessToken failed") - jsonErr := jsonerror.InternalServerError() - resErr = &jsonErr + JSON: jsonerror.UnknownToken("Unknown token"), } } - return + return res.Device, nil } // GenerateAccessToken creates a new access token. Returns an error if failed to generate diff --git a/clientapi/auth/authtypes/device.go b/clientapi/auth/authtypes/device.go deleted file mode 100644 index 299eff036..000000000 --- a/clientapi/auth/authtypes/device.go +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package authtypes - -// Device represents a client's device (mobile, web, etc) -type Device struct { - ID string - UserID string - // The access_token granted to this device. - // This uniquely identifies the device from all other devices and clients. - AccessToken string - // The unique ID of the session identified by the access token. - // Can be used as a secure substitution in places where data needs to be - // associated with access tokens. - SessionID int64 - // TODO: display name, last used timestamp, keys, etc - DisplayName string -} diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index 2780f367c..174eb1bf1 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -18,8 +18,6 @@ import ( "github.com/Shopify/sarama" "github.com/gorilla/mux" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/consumers" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/routing" @@ -28,6 +26,9 @@ import ( "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/transactions" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) @@ -41,12 +42,12 @@ func AddPublicRoutes( deviceDB devices.Database, accountsDB accounts.Database, federation *gomatrixserverlib.FederationClient, - keyRing *gomatrixserverlib.KeyRing, rsAPI roomserverAPI.RoomserverInternalAPI, eduInputAPI eduServerAPI.EDUServerInputAPI, asAPI appserviceAPI.AppServiceQueryAPI, transactionsCache *transactions.Cache, fsAPI federationSenderAPI.FederationSenderInternalAPI, + userAPI userapi.UserInternalAPI, ) { syncProducer := &producers.SyncAPIProducer{ Producer: producer, @@ -62,7 +63,7 @@ func AddPublicRoutes( routing.Setup( router, cfg, eduInputAPI, rsAPI, asAPI, - accountsDB, deviceDB, federation, *keyRing, + accountsDB, deviceDB, userAPI, federation, syncProducer, transactionsCache, fsAPI, ) } diff --git a/clientapi/consumers/roomserver.go b/clientapi/consumers/roomserver.go index bd8ac1dcd..beeda042b 100644 --- a/clientapi/consumers/roomserver.go +++ b/clientapi/consumers/roomserver.go @@ -18,10 +18,10 @@ import ( "context" "encoding/json" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" "github.com/Shopify/sarama" @@ -84,63 +84,9 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { return nil } - ev := output.NewRoomEvent.Event - log.WithFields(log.Fields{ - "event_id": ev.EventID(), - "room_id": ev.RoomID(), - "type": ev.Type(), - }).Info("received event from roomserver") - - events, err := s.lookupStateEvents(output.NewRoomEvent.AddsStateEventIDs, ev.Event) - if err != nil { - return err - } - - return s.db.UpdateMemberships(context.TODO(), events, output.NewRoomEvent.RemovesStateEventIDs) -} - -// lookupStateEvents looks up the state events that are added by a new event. -func (s *OutputRoomEventConsumer) lookupStateEvents( - addsStateEventIDs []string, event gomatrixserverlib.Event, -) ([]gomatrixserverlib.Event, error) { - // Fast path if there aren't any new state events. - if len(addsStateEventIDs) == 0 { - // If the event is a membership update (e.g. for a profile update), it won't - // show up in AddsStateEventIDs, so we need to add it manually - if event.Type() == "m.room.member" { - return []gomatrixserverlib.Event{event}, nil - } - return nil, nil - } - - // Fast path if the only state event added is the event itself. - if len(addsStateEventIDs) == 1 && addsStateEventIDs[0] == event.EventID() { - return []gomatrixserverlib.Event{event}, nil - } - - result := []gomatrixserverlib.Event{} - missing := []string{} - for _, id := range addsStateEventIDs { - // Append the current event in the results if its ID is in the events list - if id == event.EventID() { - result = append(result, event) - } else { - // If the event isn't the current one, add it to the list of events - // to retrieve from the roomserver - missing = append(missing, id) - } - } - - // Request the missing events from the roomserver - eventReq := api.QueryEventsByIDRequest{EventIDs: missing} - var eventResp api.QueryEventsByIDResponse - if err := s.rsAPI.QueryEventsByID(context.TODO(), &eventReq, &eventResp); err != nil { - return nil, err - } - - for _, headeredEvent := range eventResp.Events { - result = append(result, headeredEvent.Event) - } - - return result, nil + return s.db.UpdateMemberships( + context.TODO(), + gomatrixserverlib.UnwrapEventHeaders(output.NewRoomEvent.AddsState()), + output.NewRoomEvent.RemovesStateEventIDs, + ) } diff --git a/clientapi/producers/syncapi.go b/clientapi/producers/syncapi.go index 244a61dc2..6ab8eef28 100644 --- a/clientapi/producers/syncapi.go +++ b/clientapi/producers/syncapi.go @@ -17,9 +17,9 @@ package producers import ( "encoding/json" - "github.com/matrix-org/dendrite/internal" - "github.com/Shopify/sarama" + "github.com/matrix-org/dendrite/internal/eventutil" + log "github.com/sirupsen/logrus" ) // SyncAPIProducer produces events for the sync API server to consume @@ -32,7 +32,7 @@ type SyncAPIProducer struct { func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string) error { var m sarama.ProducerMessage - data := internal.AccountData{ + data := eventutil.AccountData{ RoomID: roomID, Type: dataType, } @@ -44,6 +44,11 @@ func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string m.Topic = string(p.Topic) m.Key = sarama.StringEncoder(userID) m.Value = sarama.ByteEncoder(value) + log.WithFields(log.Fields{ + "user_id": userID, + "room_id": roomID, + "data_type": dataType, + }).Infof("Producing to topic '%s'", p.Topic) _, _, err = p.Producer.SendMessage(&m) return err diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go index a5d53c326..68e0dc5da 100644 --- a/clientapi/routing/account_data.go +++ b/clientapi/routing/account_data.go @@ -19,10 +19,10 @@ import ( "io/ioutil" "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -30,7 +30,7 @@ import ( // GetAccountData implements GET /user/{userId}/[rooms/{roomid}/]account_data/{type} func GetAccountData( - req *http.Request, accountDB accounts.Database, device *authtypes.Device, + req *http.Request, accountDB accounts.Database, device *api.Device, userID string, roomID string, dataType string, ) util.JSONResponse { if userID != device.UserID { @@ -51,7 +51,7 @@ func GetAccountData( ); err == nil { return util.JSONResponse{ Code: http.StatusOK, - JSON: data, + JSON: data.Content, } } @@ -63,7 +63,7 @@ func GetAccountData( // SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type} func SaveAccountData( - req *http.Request, accountDB accounts.Database, device *authtypes.Device, + req *http.Request, accountDB accounts.Database, device *api.Device, userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer, ) util.JSONResponse { if userID != device.UserID { diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index 89f49d356..be7124828 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -24,14 +24,14 @@ import ( appserviceAPI "github.com/matrix-org/dendrite/appservice/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverVersion "github.com/matrix-org/dendrite/roomserver/version" + "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/threepid" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" @@ -98,7 +98,7 @@ func (r createRoomRequest) Validate() *util.JSONResponse { // Validate creation_content fields defined in the spec by marshalling the // creation_content map into bytes and then unmarshalling the bytes into - // internal.CreateContent. + // eventutil.CreateContent. creationContentBytes, err := json.Marshal(r.CreationContent) if err != nil { @@ -135,7 +135,7 @@ type fledglingEvent struct { // CreateRoom implements /createRoom func CreateRoom( - req *http.Request, device *authtypes.Device, + req *http.Request, device *api.Device, cfg *config.Dendrite, accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, @@ -149,7 +149,7 @@ func CreateRoom( // createRoom implements /createRoom // nolint: gocyclo func createRoom( - req *http.Request, device *authtypes.Device, + req *http.Request, device *api.Device, cfg *config.Dendrite, roomID string, accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, @@ -279,25 +279,25 @@ func createRoom( eventsToMake := []fledglingEvent{ {"m.room.create", "", r.CreationContent}, {"m.room.member", userID, membershipContent}, - {"m.room.power_levels", "", internal.InitialPowerLevelsContent(userID)}, + {"m.room.power_levels", "", eventutil.InitialPowerLevelsContent(userID)}, {"m.room.join_rules", "", gomatrixserverlib.JoinRuleContent{JoinRule: joinRules}}, - {"m.room.history_visibility", "", internal.HistoryVisibilityContent{HistoryVisibility: historyVisibility}}, + {"m.room.history_visibility", "", eventutil.HistoryVisibilityContent{HistoryVisibility: historyVisibility}}, } if roomAlias != "" { // TODO: bit of a chicken and egg problem here as the alias doesn't exist and cannot until we have made the room. // This means we might fail creating the alias but say the canonical alias is something that doesn't exist. // m.room.aliases is handled when we call roomserver.SetRoomAlias - eventsToMake = append(eventsToMake, fledglingEvent{"m.room.canonical_alias", "", internal.CanonicalAlias{Alias: roomAlias}}) + eventsToMake = append(eventsToMake, fledglingEvent{"m.room.canonical_alias", "", eventutil.CanonicalAlias{Alias: roomAlias}}) } if r.GuestCanJoin { - eventsToMake = append(eventsToMake, fledglingEvent{"m.room.guest_access", "", internal.GuestAccessContent{GuestAccess: "can_join"}}) + eventsToMake = append(eventsToMake, fledglingEvent{"m.room.guest_access", "", eventutil.GuestAccessContent{GuestAccess: "can_join"}}) } eventsToMake = append(eventsToMake, r.InitialState...) if r.Name != "" { - eventsToMake = append(eventsToMake, fledglingEvent{"m.room.name", "", internal.NameContent{Name: r.Name}}) + eventsToMake = append(eventsToMake, fledglingEvent{"m.room.name", "", eventutil.NameContent{Name: r.Name}}) } if r.Topic != "" { - eventsToMake = append(eventsToMake, fledglingEvent{"m.room.topic", "", internal.TopicContent{Topic: r.Topic}}) + eventsToMake = append(eventsToMake, fledglingEvent{"m.room.topic", "", eventutil.TopicContent{Topic: r.Topic}}) } // TODO: invite events // TODO: 3pid invite events diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go index fab0d68d0..7b8c951a1 100644 --- a/clientapi/routing/device.go +++ b/clientapi/routing/device.go @@ -19,9 +19,9 @@ import ( "encoding/json" "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -47,7 +47,7 @@ type devicesDeleteJSON struct { // GetDeviceByID handles /devices/{deviceID} func GetDeviceByID( - req *http.Request, deviceDB devices.Database, device *authtypes.Device, + req *http.Request, deviceDB devices.Database, device *api.Device, deviceID string, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) @@ -78,7 +78,7 @@ func GetDeviceByID( // GetDevicesByLocalpart handles /devices func GetDevicesByLocalpart( - req *http.Request, deviceDB devices.Database, device *authtypes.Device, + req *http.Request, deviceDB devices.Database, device *api.Device, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { @@ -110,7 +110,7 @@ func GetDevicesByLocalpart( // UpdateDeviceByID handles PUT on /devices/{deviceID} func UpdateDeviceByID( - req *http.Request, deviceDB devices.Database, device *authtypes.Device, + req *http.Request, deviceDB devices.Database, device *api.Device, deviceID string, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) @@ -160,7 +160,7 @@ func UpdateDeviceByID( // DeleteDeviceById handles DELETE requests to /devices/{deviceId} func DeleteDeviceById( - req *http.Request, deviceDB devices.Database, device *authtypes.Device, + req *http.Request, deviceDB devices.Database, device *api.Device, deviceID string, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) @@ -185,7 +185,7 @@ func DeleteDeviceById( // DeleteDevices handles POST requests to /delete_devices func DeleteDevices( - req *http.Request, deviceDB devices.Database, device *authtypes.Device, + req *http.Request, deviceDB devices.Database, device *api.Device, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index 3d4b5f5bc..0dc4d5605 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -18,12 +18,12 @@ import ( "fmt" "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/config" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -112,7 +112,7 @@ func DirectoryRoom( // TODO: Check if the user has the power level to set an alias func SetLocalAlias( req *http.Request, - device *authtypes.Device, + device *api.Device, alias string, cfg *config.Dendrite, aliasAPI roomserverAPI.RoomserverInternalAPI, @@ -188,7 +188,7 @@ func SetLocalAlias( // RemoveLocalAlias implements DELETE /directory/room/{roomAlias} func RemoveLocalAlias( req *http.Request, - device *authtypes.Device, + device *api.Device, alias string, aliasAPI roomserverAPI.RoomserverInternalAPI, ) util.JSONResponse { diff --git a/clientapi/routing/filter.go b/clientapi/routing/filter.go index 505e09279..6520e6e40 100644 --- a/clientapi/routing/filter.go +++ b/clientapi/routing/filter.go @@ -17,17 +17,17 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) // GetFilter implements GET /_matrix/client/r0/user/{userId}/filter/{filterId} func GetFilter( - req *http.Request, device *authtypes.Device, accountDB accounts.Database, userID string, filterID string, + req *http.Request, device *api.Device, accountDB accounts.Database, userID string, filterID string, ) util.JSONResponse { if userID != device.UserID { return util.JSONResponse{ @@ -64,7 +64,7 @@ type filterResponse struct { //PutFilter implements POST /_matrix/client/r0/user/{userId}/filter func PutFilter( - req *http.Request, device *authtypes.Device, accountDB accounts.Database, userID string, + req *http.Request, device *api.Device, accountDB accounts.Database, userID string, ) util.JSONResponse { if userID != device.UserID { return util.JSONResponse{ diff --git a/clientapi/routing/getevent.go b/clientapi/routing/getevent.go index 3ca9d1b62..2a51db730 100644 --- a/clientapi/routing/getevent.go +++ b/clientapi/routing/getevent.go @@ -17,22 +17,21 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) type getEventRequest struct { req *http.Request - device *authtypes.Device + device *userapi.Device roomID string eventID string cfg *config.Dendrite federation *gomatrixserverlib.FederationClient - keyRing gomatrixserverlib.KeyRing requestedEvent gomatrixserverlib.Event } @@ -40,13 +39,12 @@ type getEventRequest struct { // https://matrix.org/docs/spec/client_server/r0.4.0.html#get-matrix-client-r0-rooms-roomid-event-eventid func GetEvent( req *http.Request, - device *authtypes.Device, + device *userapi.Device, roomID string, eventID string, cfg *config.Dendrite, rsAPI api.RoomserverInternalAPI, federation *gomatrixserverlib.FederationClient, - keyRing gomatrixserverlib.KeyRing, ) util.JSONResponse { eventsReq := api.QueryEventsByIDRequest{ EventIDs: []string{eventID}, @@ -75,7 +73,6 @@ func GetEvent( eventID: eventID, cfg: cfg, federation: federation, - keyRing: keyRing, requestedEvent: requestedEvent, } diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index a3d676532..3871e4d4e 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -17,18 +17,18 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) func JoinRoomByIDOrAlias( req *http.Request, - device *authtypes.Device, + device *api.Device, rsAPI roomserverAPI.RoomserverInternalAPI, accountDB accounts.Database, roomIDOrAlias string, @@ -43,9 +43,7 @@ func JoinRoomByIDOrAlias( // If content was provided in the request then incude that // in the request. It'll get used as a part of the membership // event content. - if err := httputil.UnmarshalJSONRequest(req, &joinReq.Content); err != nil { - return *err - } + _ = httputil.UnmarshalJSONRequest(req, &joinReq.Content) // Work out our localpart for the client profile request. localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) diff --git a/clientapi/routing/leaveroom.go b/clientapi/routing/leaveroom.go index bd7696181..38cef118e 100644 --- a/clientapi/routing/leaveroom.go +++ b/clientapi/routing/leaveroom.go @@ -17,15 +17,15 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/jsonerror" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" ) func LeaveRoomByID( req *http.Request, - device *authtypes.Device, + device *api.Device, rsAPI roomserverAPI.RoomserverInternalAPI, roomID string, ) util.JSONResponse { diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 5b357bd8d..860aa6d38 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -22,13 +22,13 @@ import ( "context" "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -49,9 +49,8 @@ type loginIdentifier struct { type loginWithPasswordRequest struct { Identifier loginIdentifier `json:"identifier"` - Medium string `json:"medium"` // third-party only - Password string `json:"password"` // m.login.password only - Token string `json:"token"` // m.login.token only + User string `json:"user"` // deprecated in favour of identifier + Password string `json:"password"` // Both DeviceID and InitialDisplayName can be omitted, or empty strings ("") // Thus a pointer is needed to differentiate between the two InitialDisplayName *string `json:"initial_device_display_name"` @@ -88,9 +87,10 @@ func Login( JSON: passwordLogin(), } } else if req.Method == http.MethodPost { - var temp interface{} - var acc *authtypes.Account - resErr := httputil.UnmarshalJSONRequest(req, &temp) + var r passwordRequest + var acc *api.Account + var errJSON *util.JSONResponse + resErr := httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { return *resErr } @@ -113,43 +113,23 @@ func Login( JSON: jsonerror.BadJSON("'user' must be supplied."), } } - - util.GetLogger(req.Context()).WithField("user", r.Identifier.User).Info("Processing login request") - - localpart, err := userutil.ParseUsernameParam(r.Identifier.User, &cfg.Matrix.ServerName) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername(err.Error()), - } - } - - acc, err = accountDB.GetAccountByPassword(req.Context(), localpart, r.Password) - if err != nil { - // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows - // but that would leak the existence of the user. - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("username or password was incorrect, or the account does not exist"), - } - } - - case "m.login.token": - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("Token login is not supported"), - } - - default: - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("login identifier '" + r.Identifier.Type + "' not supported"), - } + } + acc, errJSON = r.processUsernamePasswordLoginRequest(req, accountDB, cfg, r.Identifier.User) + if errJSON != nil { + return *errJSON } default: - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.Unknown(fmt.Sprintf("Login type %q not supported", r.Type)), + // TODO: The below behaviour is deprecated but without it Riot iOS won't log in + if r.User != "" { + acc, errJSON = r.processUsernamePasswordLoginRequest(req, accountDB, cfg, r.User) + if errJSON != nil { + return *errJSON + } + } else { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("login identifier '" + r.Identifier.Type + "' not supported"), + } } } @@ -188,11 +168,40 @@ func getDevice( ctx context.Context, r loginWithPasswordRequest, deviceDB devices.Database, - acc *authtypes.Account, + acc *api.Account, token string, -) (dev *authtypes.Device, err error) { +) (dev *api.Device, err error) { dev, err = deviceDB.CreateDevice( ctx, acc.Localpart, r.DeviceID, token, r.InitialDisplayName, ) return } + +func (r *passwordRequest) processUsernamePasswordLoginRequest( + req *http.Request, accountDB accounts.Database, + cfg *config.Dendrite, username string, +) (acc *api.Account, errJSON *util.JSONResponse) { + util.GetLogger(req.Context()).WithField("user", username).Info("Processing login request") + + localpart, err := userutil.ParseUsernameParam(username, &cfg.Matrix.ServerName) + if err != nil { + errJSON = &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(err.Error()), + } + return + } + + acc, err = accountDB.GetAccountByPassword(req.Context(), localpart, r.Password) + if err != nil { + // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows + // but that would leak the existence of the user. + errJSON = &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("username or password was incorrect, or the account does not exist"), + } + return + } + + return +} diff --git a/clientapi/routing/logout.go b/clientapi/routing/logout.go index 26b7f117e..3ce47169e 100644 --- a/clientapi/routing/logout.go +++ b/clientapi/routing/logout.go @@ -17,16 +17,16 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) // Logout handles POST /logout func Logout( - req *http.Request, deviceDB devices.Database, device *authtypes.Device, + req *http.Request, deviceDB devices.Database, device *api.Device, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { @@ -47,7 +47,7 @@ func Logout( // LogoutAll handles POST /logout/all func LogoutAll( - req *http.Request, deviceDB devices.Database, device *authtypes.Device, + req *http.Request, deviceDB devices.Database, device *api.Device, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 0b2c2bc5e..0d4b0d88d 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -22,14 +22,15 @@ import ( appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/threepid" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -42,7 +43,7 @@ var errMissingUserID = errors.New("'user_id' must be supplied") // TODO: Can we improve the cyclo count here? Separate code paths for invites? // nolint:gocyclo func SendMembership( - req *http.Request, accountDB accounts.Database, device *authtypes.Device, + req *http.Request, accountDB accounts.Database, device *userapi.Device, roomID string, membership string, cfg *config.Dendrite, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { @@ -95,7 +96,7 @@ func SendMembership( Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(err.Error()), } - } else if err == internal.ErrRoomNoExists { + } else if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound(err.Error()), @@ -149,7 +150,7 @@ func SendMembership( func buildMembershipEvent( ctx context.Context, body threepid.MembershipRequest, accountDB accounts.Database, - device *authtypes.Device, + device *userapi.Device, membership, roomID string, isDirect bool, cfg *config.Dendrite, evTime time.Time, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, @@ -188,7 +189,7 @@ func buildMembershipEvent( return nil, err } - return internal.BuildEvent(ctx, &builder, cfg, evTime, rsAPI, nil) + return eventutil.BuildEvent(ctx, &builder, cfg, evTime, rsAPI, nil) } // loadProfile lookups the profile of a given user from the database and returns @@ -223,7 +224,7 @@ func loadProfile( // In the latter case, if there was an issue retrieving the user ID from the request body, // returns a JSONResponse with a corresponding error code and message. func getMembershipStateKey( - body threepid.MembershipRequest, device *authtypes.Device, membership string, + body threepid.MembershipRequest, device *userapi.Device, membership string, ) (stateKey string, reason string, err error) { if membership == gomatrixserverlib.Ban || membership == "unban" || membership == "kick" || membership == gomatrixserverlib.Invite { // If we're in this case, the state key is contained in the request body, @@ -245,7 +246,7 @@ func getMembershipStateKey( func checkAndProcessThreepid( req *http.Request, - device *authtypes.Device, + device *userapi.Device, body *threepid.MembershipRequest, cfg *config.Dendrite, rsAPI roomserverAPI.RoomserverInternalAPI, @@ -268,7 +269,7 @@ func checkAndProcessThreepid( Code: http.StatusBadRequest, JSON: jsonerror.NotTrusted(body.IDServer), } - } else if err == internal.ErrRoomNoExists { + } else if err == eventutil.ErrRoomNoExists { return inviteStored, &util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound(err.Error()), diff --git a/clientapi/routing/memberships.go b/clientapi/routing/memberships.go index 095a85c08..1c9800b66 100644 --- a/clientapi/routing/memberships.go +++ b/clientapi/routing/memberships.go @@ -15,14 +15,15 @@ package routing import ( + "encoding/json" "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -35,9 +36,19 @@ type getJoinedRoomsResponse struct { JoinedRooms []string `json:"joined_rooms"` } +// https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-rooms-roomid-joined-members +type getJoinedMembersResponse struct { + Joined map[string]joinedMember `json:"joined"` +} + +type joinedMember struct { + DisplayName string `json:"display_name"` + AvatarURL string `json:"avatar_url"` +} + // GetMemberships implements GET /rooms/{roomId}/members func GetMemberships( - req *http.Request, device *authtypes.Device, roomID string, joinedOnly bool, + req *http.Request, device *userapi.Device, roomID string, joinedOnly bool, _ *config.Dendrite, rsAPI api.RoomserverInternalAPI, ) util.JSONResponse { @@ -59,6 +70,22 @@ func GetMemberships( } } + if joinedOnly { + var res getJoinedMembersResponse + res.Joined = make(map[string]joinedMember) + for _, ev := range queryRes.JoinEvents { + var content joinedMember + if err := json.Unmarshal(ev.Content, &content); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("failed to unmarshal event content") + return jsonerror.InternalServerError() + } + res.Joined[ev.Sender] = content + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: res, + } + } return util.JSONResponse{ Code: http.StatusOK, JSON: getMembershipResponse{queryRes.JoinEvents}, @@ -67,7 +94,7 @@ func GetMemberships( func GetJoinedRooms( req *http.Request, - device *authtypes.Device, + device *userapi.Device, accountsDB accounts.Database, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 642a72887..7c2cd19bc 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -21,12 +21,13 @@ import ( appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrix" @@ -42,7 +43,7 @@ func GetProfile( ) util.JSONResponse { profile, err := getProfile(req.Context(), accountDB, cfg, userID, asAPI, federation) if err != nil { - if err == internal.ErrProfileNoExists { + if err == eventutil.ErrProfileNoExists { return util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound("The user does not exist or does not have a profile"), @@ -55,7 +56,7 @@ func GetProfile( return util.JSONResponse{ Code: http.StatusOK, - JSON: internal.ProfileResponse{ + JSON: eventutil.ProfileResponse{ AvatarURL: profile.AvatarURL, DisplayName: profile.DisplayName, }, @@ -70,7 +71,7 @@ func GetAvatarURL( ) util.JSONResponse { profile, err := getProfile(req.Context(), accountDB, cfg, userID, asAPI, federation) if err != nil { - if err == internal.ErrProfileNoExists { + if err == eventutil.ErrProfileNoExists { return util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound("The user does not exist or does not have a profile"), @@ -83,7 +84,7 @@ func GetAvatarURL( return util.JSONResponse{ Code: http.StatusOK, - JSON: internal.AvatarURL{ + JSON: eventutil.AvatarURL{ AvatarURL: profile.AvatarURL, }, } @@ -92,7 +93,7 @@ func GetAvatarURL( // SetAvatarURL implements PUT /profile/{userID}/avatar_url // nolint:gocyclo func SetAvatarURL( - req *http.Request, accountDB accounts.Database, device *authtypes.Device, + req *http.Request, accountDB accounts.Database, device *userapi.Device, userID string, cfg *config.Dendrite, rsAPI api.RoomserverInternalAPI, ) util.JSONResponse { if userID != device.UserID { @@ -102,7 +103,7 @@ func SetAvatarURL( } } - var r internal.AvatarURL + var r eventutil.AvatarURL if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil { return *resErr } @@ -184,7 +185,7 @@ func GetDisplayName( ) util.JSONResponse { profile, err := getProfile(req.Context(), accountDB, cfg, userID, asAPI, federation) if err != nil { - if err == internal.ErrProfileNoExists { + if err == eventutil.ErrProfileNoExists { return util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound("The user does not exist or does not have a profile"), @@ -197,7 +198,7 @@ func GetDisplayName( return util.JSONResponse{ Code: http.StatusOK, - JSON: internal.DisplayName{ + JSON: eventutil.DisplayName{ DisplayName: profile.DisplayName, }, } @@ -206,7 +207,7 @@ func GetDisplayName( // SetDisplayName implements PUT /profile/{userID}/displayname // nolint:gocyclo func SetDisplayName( - req *http.Request, accountDB accounts.Database, device *authtypes.Device, + req *http.Request, accountDB accounts.Database, device *userapi.Device, userID string, cfg *config.Dendrite, rsAPI api.RoomserverInternalAPI, ) util.JSONResponse { if userID != device.UserID { @@ -216,7 +217,7 @@ func SetDisplayName( } } - var r internal.DisplayName + var r eventutil.DisplayName if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil { return *resErr } @@ -293,7 +294,7 @@ func SetDisplayName( // getProfile gets the full profile of a user by querying the database or a // remote homeserver. // Returns an error when something goes wrong or specifically -// internal.ErrProfileNoExists when the profile doesn't exist. +// eventutil.ErrProfileNoExists when the profile doesn't exist. func getProfile( ctx context.Context, accountDB accounts.Database, cfg *config.Dendrite, userID string, @@ -310,7 +311,7 @@ func getProfile( if fedErr != nil { if x, ok := fedErr.(gomatrix.HTTPError); ok { if x.Code == http.StatusNotFound { - return nil, internal.ErrProfileNoExists + return nil, eventutil.ErrProfileNoExists } } @@ -365,7 +366,7 @@ func buildMembershipEvents( return nil, err } - event, err := internal.BuildEvent(ctx, &builder, cfg, evTime, rsAPI, nil) + event, err := eventutil.BuildEvent(ctx, &builder, cfg, evTime, rsAPI, nil) if err != nil { return nil, err } diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index ba6c57566..69ebdfd70 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -33,15 +33,15 @@ import ( "time" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" - "github.com/matrix-org/dendrite/internal" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/tokens" "github.com/matrix-org/util" @@ -136,7 +136,7 @@ type registerRequest struct { DeviceID *string `json:"device_id"` // Prevent this user from logging in - InhibitLogin internal.WeakBoolean `json:"inhibit_login"` + InhibitLogin eventutil.WeakBoolean `json:"inhibit_login"` // Application Services place Type in the root of their registration // request, whereas clients place it in the authDict struct. @@ -440,8 +440,8 @@ func validateApplicationService( // http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register func Register( req *http.Request, + userAPI userapi.UserInternalAPI, accountDB accounts.Database, - deviceDB devices.Database, cfg *config.Dendrite, ) util.JSONResponse { var r registerRequest @@ -450,7 +450,7 @@ func Register( return *resErr } if req.URL.Query().Get("kind") == "guest" { - return handleGuestRegistration(req, r, cfg, accountDB, deviceDB) + return handleGuestRegistration(req, r, cfg, userAPI) } // Retrieve or generate the sessionID @@ -506,21 +506,19 @@ func Register( "session_id": r.Auth.Session, }).Info("Processing registration request") - resp := handleRegistrationFlow(req, r, sessionID, cfg, accountDB, deviceDB) - j, _ := json.MarshalIndent(resp, "", " ") - fmt.Println("ERROR!") - fmt.Println(string(j)) - return resp + return handleRegistrationFlow(req, r, sessionID, cfg, userAPI) } func handleGuestRegistration( req *http.Request, r registerRequest, cfg *config.Dendrite, - accountDB accounts.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, ) util.JSONResponse { - acc, err := accountDB.CreateGuestAccount(req.Context()) + var res userapi.PerformAccountCreationResponse + err := userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{ + AccountType: userapi.AccountTypeGuest, + }, &res) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -529,8 +527,8 @@ func handleGuestRegistration( } token, err := tokens.GenerateLoginToken(tokens.TokenOptions{ ServerPrivateKey: cfg.Matrix.PrivateKey.Seed(), - ServerName: string(acc.ServerName), - UserID: acc.UserID, + ServerName: string(res.Account.ServerName), + UserID: res.Account.UserID, }) if err != nil { @@ -540,7 +538,12 @@ func handleGuestRegistration( } } //we don't allow guests to specify their own device_id - dev, err := deviceDB.CreateDevice(req.Context(), acc.Localpart, nil, token, r.InitialDisplayName) + var devRes userapi.PerformDeviceCreationResponse + err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{ + Localpart: res.Account.Localpart, + DeviceDisplayName: r.InitialDisplayName, + AccessToken: token, + }, &devRes) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -550,10 +553,10 @@ func handleGuestRegistration( return util.JSONResponse{ Code: http.StatusOK, JSON: registerResponse{ - UserID: dev.UserID, - AccessToken: dev.AccessToken, - HomeServer: acc.ServerName, - DeviceID: dev.ID, + UserID: devRes.Device.UserID, + AccessToken: devRes.Device.AccessToken, + HomeServer: res.Account.ServerName, + DeviceID: devRes.Device.ID, }, } } @@ -566,8 +569,7 @@ func handleRegistrationFlow( r registerRequest, sessionID string, cfg *config.Dendrite, - accountDB accounts.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, ) util.JSONResponse { // TODO: Shared secret registration (create new user scripts) // TODO: Enable registration config flag @@ -618,7 +620,7 @@ func handleRegistrationFlow( // by whether the request contains an access token. if err == nil { return handleApplicationServiceRegistration( - accessToken, err, req, r, cfg, accountDB, deviceDB, + accessToken, err, req, r, cfg, userAPI, ) } @@ -629,7 +631,7 @@ func handleRegistrationFlow( // don't need a condition on that call since the registration is clearly // stated as being AS-related. return handleApplicationServiceRegistration( - accessToken, err, req, r, cfg, accountDB, deviceDB, + accessToken, err, req, r, cfg, userAPI, ) case authtypes.LoginTypeDummy: @@ -648,7 +650,7 @@ func handleRegistrationFlow( // A response with current registration flow and remaining available methods // will be returned if a flow has not been successfully completed yet return checkAndCompleteFlow(sessions.GetCompletedStages(sessionID), - req, r, sessionID, cfg, accountDB, deviceDB) + req, r, sessionID, cfg, userAPI) } // handleApplicationServiceRegistration handles the registration of an @@ -665,8 +667,7 @@ func handleApplicationServiceRegistration( req *http.Request, r registerRequest, cfg *config.Dendrite, - accountDB accounts.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, ) util.JSONResponse { // Check if we previously had issues extracting the access token from the // request. @@ -690,7 +691,7 @@ func handleApplicationServiceRegistration( // Don't need to worry about appending to registration stages as // application service registration is entirely separate. return completeRegistration( - req.Context(), accountDB, deviceDB, r.Username, "", appserviceID, + req.Context(), userAPI, r.Username, "", appserviceID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, ) } @@ -704,13 +705,12 @@ func checkAndCompleteFlow( r registerRequest, sessionID string, cfg *config.Dendrite, - accountDB accounts.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, ) util.JSONResponse { if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { // This flow was completed, registration can continue return completeRegistration( - req.Context(), accountDB, deviceDB, r.Username, r.Password, "", + req.Context(), userAPI, r.Username, r.Password, "", r.InhibitLogin, r.InitialDisplayName, r.DeviceID, ) } @@ -727,8 +727,7 @@ func checkAndCompleteFlow( // LegacyRegister process register requests from the legacy v1 API func LegacyRegister( req *http.Request, - accountDB accounts.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, cfg *config.Dendrite, ) util.JSONResponse { var r legacyRegisterRequest @@ -763,10 +762,10 @@ func LegacyRegister( return util.MessageResponse(http.StatusForbidden, "HMAC incorrect") } - return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil, nil) + return completeRegistration(req.Context(), userAPI, r.Username, r.Password, "", false, nil, nil) case authtypes.LoginTypeDummy: // there is nothing to do - return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil, nil) + return completeRegistration(req.Context(), userAPI, r.Username, r.Password, "", false, nil, nil) default: return util.JSONResponse{ Code: http.StatusNotImplemented, @@ -812,10 +811,9 @@ func parseAndValidateLegacyLogin(req *http.Request, r *legacyRegisterRequest) *u // not all func completeRegistration( ctx context.Context, - accountDB accounts.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, username, password, appserviceID string, - inhibitLogin internal.WeakBoolean, + inhibitLogin eventutil.WeakBoolean, displayName, deviceID *string, ) util.JSONResponse { if username == "" { @@ -832,9 +830,16 @@ func completeRegistration( } } - acc, err := accountDB.CreateAccount(ctx, username, password, appserviceID) + var accRes userapi.PerformAccountCreationResponse + err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ + AppServiceID: appserviceID, + Localpart: username, + Password: password, + AccountType: userapi.AccountTypeUser, + OnConflict: userapi.ConflictAbort, + }, &accRes) if err != nil { - if errors.Is(err, internal.ErrUserExists) { // user already exists + if _, ok := err.(*userapi.ErrorConflict); ok { // user already exists return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.UserInUse("Desired user ID is already taken."), @@ -855,8 +860,8 @@ func completeRegistration( return util.JSONResponse{ Code: http.StatusOK, JSON: registerResponse{ - UserID: userutil.MakeUserID(username, acc.ServerName), - HomeServer: acc.ServerName, + UserID: userutil.MakeUserID(username, accRes.Account.ServerName), + HomeServer: accRes.Account.ServerName, }, } } @@ -869,7 +874,13 @@ func completeRegistration( } } - dev, err := deviceDB.CreateDevice(ctx, username, deviceID, token, displayName) + var devRes userapi.PerformDeviceCreationResponse + err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ + Localpart: username, + AccessToken: token, + DeviceDisplayName: displayName, + DeviceID: deviceID, + }, &devRes) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -880,10 +891,10 @@ func completeRegistration( return util.JSONResponse{ Code: http.StatusOK, JSON: registerResponse{ - UserID: dev.UserID, - AccessToken: dev.AccessToken, - HomeServer: acc.ServerName, - DeviceID: dev.ID, + UserID: devRes.Device.UserID, + AccessToken: devRes.Device.AccessToken, + HomeServer: accRes.Account.ServerName, + DeviceID: devRes.Device.ID, }, } } diff --git a/clientapi/routing/room_tagging.go b/clientapi/routing/room_tagging.go index 5c68668d0..b1cfcca86 100644 --- a/clientapi/routing/room_tagging.go +++ b/clientapi/routing/room_tagging.go @@ -20,11 +20,11 @@ import ( "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -41,7 +41,7 @@ func newTag() gomatrix.TagContent { func GetTags( req *http.Request, accountDB accounts.Database, - device *authtypes.Device, + device *api.Device, userID string, roomID string, syncProducer *producers.SyncAPIProducer, @@ -79,7 +79,7 @@ func GetTags( func PutTag( req *http.Request, accountDB accounts.Database, - device *authtypes.Device, + device *api.Device, userID string, roomID string, tag string, @@ -139,7 +139,7 @@ func PutTag( func DeleteTag( req *http.Request, accountDB accounts.Database, - device *authtypes.Device, + device *api.Device, userID string, roomID string, tag string, diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 024707759..41c7fb18e 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -21,18 +21,17 @@ import ( "github.com/gorilla/mux" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" - "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" eduServerAPI "github.com/matrix-org/dendrite/eduserver/api" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/transactions" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -54,15 +53,15 @@ func Setup( asAPI appserviceAPI.AppServiceQueryAPI, accountDB accounts.Database, deviceDB devices.Database, + userAPI api.UserInternalAPI, federation *gomatrixserverlib.FederationClient, - keyRing gomatrixserverlib.KeyRing, syncProducer *producers.SyncAPIProducer, transactionsCache *transactions.Cache, federationSender federationSenderAPI.FederationSenderInternalAPI, ) { publicAPIMux.Handle("/client/versions", - internal.MakeExternalAPI("versions", func(req *http.Request) util.JSONResponse { + httputil.MakeExternalAPI("versions", func(req *http.Request) util.JSONResponse { return util.JSONResponse{ Code: http.StatusOK, JSON: struct { @@ -81,20 +80,14 @@ func Setup( v1mux := publicAPIMux.PathPrefix(pathPrefixV1).Subrouter() unstableMux := publicAPIMux.PathPrefix(pathPrefixUnstable).Subrouter() - authData := auth.Data{ - AccountDB: accountDB, - DeviceDB: deviceDB, - AppServices: cfg.Derived.ApplicationServices, - } - r0mux.Handle("/createRoom", - internal.MakeAuthAPI("createRoom", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { return CreateRoom(req, device, cfg, accountDB, rsAPI, asAPI) }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/join/{roomIDOrAlias}", - internal.MakeAuthAPI(gomatrixserverlib.Join, authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -104,13 +97,13 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/joined_rooms", - internal.MakeAuthAPI("joined_rooms", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { return GetJoinedRooms(req, device, accountDB) }), ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/leave", - internal.MakeAuthAPI("membership", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -120,8 +113,8 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/{membership:(?:join|kick|ban|unban|invite)}", - internal.MakeAuthAPI("membership", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -129,8 +122,8 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/send/{eventType}", - internal.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -138,8 +131,8 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}", - internal.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -149,42 +142,44 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/event/{eventID}", - internal.MakeAuthAPI("rooms_get_event", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("rooms_get_event", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return GetEvent(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, federation, keyRing) + return GetEvent(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, federation) }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/state", internal.MakeAuthAPI("room_state", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + r0mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } return OnIncomingStateRequest(req.Context(), rsAPI, vars["roomID"]) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/state/{type}", internal.MakeAuthAPI("room_state", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + r0mux.Handle("/rooms/{roomID}/state/{type}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return OnIncomingStateTypeRequest(req.Context(), rsAPI, vars["roomID"], vars["type"], "") + eventFormat := req.URL.Query().Get("format") == "event" + return OnIncomingStateTypeRequest(req.Context(), rsAPI, vars["roomID"], vars["type"], "", eventFormat) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", internal.MakeAuthAPI("room_state", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + r0mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - return OnIncomingStateTypeRequest(req.Context(), rsAPI, vars["roomID"], vars["type"], vars["stateKey"]) + eventFormat := req.URL.Query().Get("format") == "event" + return OnIncomingStateTypeRequest(req.Context(), rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat) })).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}", - internal.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -199,8 +194,8 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}", - internal.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -209,21 +204,21 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/register", internal.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { - return Register(req, accountDB, deviceDB, cfg) + r0mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { + return Register(req, userAPI, accountDB, cfg) })).Methods(http.MethodPost, http.MethodOptions) - v1mux.Handle("/register", internal.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { - return LegacyRegister(req, accountDB, deviceDB, cfg) + v1mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { + return LegacyRegister(req, userAPI, cfg) })).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/register/available", internal.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { + r0mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { return RegisterAvailable(req, cfg, accountDB) })).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/directory/room/{roomAlias}", - internal.MakeExternalAPI("directory_room", func(req *http.Request) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeExternalAPI("directory_room", func(req *http.Request) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -232,8 +227,8 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/directory/room/{roomAlias}", - internal.MakeAuthAPI("directory_room", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -242,8 +237,8 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/directory/room/{roomAlias}", - internal.MakeAuthAPI("directory_room", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -252,20 +247,20 @@ func Setup( ).Methods(http.MethodDelete, http.MethodOptions) r0mux.Handle("/logout", - internal.MakeAuthAPI("logout", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { return Logout(req, deviceDB, device) }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/logout/all", - internal.MakeAuthAPI("logout", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { return LogoutAll(req, deviceDB, device) }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/typing/{userID}", - internal.MakeAuthAPI("rooms_typing", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -274,8 +269,8 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/sendToDevice/{eventType}/{txnID}", - internal.MakeAuthAPI("send_to_device", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("send_to_device", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -288,8 +283,8 @@ func Setup( // rather than r0. It's an exact duplicate of the above handler. // TODO: Remove this if/when sytest is fixed! unstableMux.Handle("/sendToDevice/{eventType}/{txnID}", - internal.MakeAuthAPI("send_to_device", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("send_to_device", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -299,7 +294,7 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/account/whoami", - internal.MakeAuthAPI("whoami", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + httputil.MakeAuthAPI("whoami", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { return Whoami(req, device) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -307,20 +302,20 @@ func Setup( // Stub endpoints required by Riot r0mux.Handle("/login", - internal.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { + httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { return Login(req, accountDB, deviceDB, cfg) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) r0mux.Handle("/auth/{authType}/fallback/web", - internal.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { + httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { vars := mux.Vars(req) return AuthFallback(w, req, vars["authType"], cfg) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) r0mux.Handle("/pushrules/", - internal.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse { + httputil.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse { // TODO: Implement push rules API res := json.RawMessage(`{ "global": { @@ -339,8 +334,8 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/user/{userId}/filter", - internal.MakeAuthAPI("put_filter", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("put_filter", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -349,8 +344,8 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/user/{userId}/filter/{filterId}", - internal.MakeAuthAPI("get_filter", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("get_filter", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -361,8 +356,8 @@ func Setup( // Riot user settings r0mux.Handle("/profile/{userID}", - internal.MakeExternalAPI("profile", func(req *http.Request) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeExternalAPI("profile", func(req *http.Request) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -371,8 +366,8 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/profile/{userID}/avatar_url", - internal.MakeExternalAPI("profile_avatar_url", func(req *http.Request) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeExternalAPI("profile_avatar_url", func(req *http.Request) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -381,8 +376,8 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/profile/{userID}/avatar_url", - internal.MakeAuthAPI("profile_avatar_url", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("profile_avatar_url", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -393,8 +388,8 @@ func Setup( // PUT requests, so we need to allow this method r0mux.Handle("/profile/{userID}/displayname", - internal.MakeExternalAPI("profile_displayname", func(req *http.Request) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeExternalAPI("profile_displayname", func(req *http.Request) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -403,8 +398,8 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/profile/{userID}/displayname", - internal.MakeAuthAPI("profile_displayname", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("profile_displayname", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -415,32 +410,32 @@ func Setup( // PUT requests, so we need to allow this method r0mux.Handle("/account/3pid", - internal.MakeAuthAPI("account_3pid", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { return GetAssociated3PIDs(req, accountDB, device) }), ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/account/3pid", - internal.MakeAuthAPI("account_3pid", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { return CheckAndSave3PIDAssociation(req, accountDB, device, cfg) }), ).Methods(http.MethodPost, http.MethodOptions) unstableMux.Handle("/account/3pid/delete", - internal.MakeAuthAPI("account_3pid", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { return Forget3PID(req, accountDB) }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken", - internal.MakeExternalAPI("account_3pid_request_token", func(req *http.Request) util.JSONResponse { + httputil.MakeExternalAPI("account_3pid_request_token", func(req *http.Request) util.JSONResponse { return RequestEmailToken(req, accountDB, cfg) }), ).Methods(http.MethodPost, http.MethodOptions) // Riot logs get flooded unless this is handled r0mux.Handle("/presence/{userID}/status", - internal.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse { + httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse { // TODO: Set presence (probably the responsibility of a presence server not clientapi) return util.JSONResponse{ Code: http.StatusOK, @@ -450,13 +445,13 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/voip/turnServer", - internal.MakeAuthAPI("turn_server", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + httputil.MakeAuthAPI("turn_server", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { return RequestTurnServer(req, device, cfg) }), ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/thirdparty/protocols", - internal.MakeExternalAPI("thirdparty_protocols", func(req *http.Request) util.JSONResponse { + httputil.MakeExternalAPI("thirdparty_protocols", func(req *http.Request) util.JSONResponse { // TODO: Return the third party protcols return util.JSONResponse{ Code: http.StatusOK, @@ -466,7 +461,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/initialSync", - internal.MakeExternalAPI("rooms_initial_sync", func(req *http.Request) util.JSONResponse { + httputil.MakeExternalAPI("rooms_initial_sync", func(req *http.Request) util.JSONResponse { // TODO: Allow people to peek into rooms. return util.JSONResponse{ Code: http.StatusForbidden, @@ -476,8 +471,8 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/user/{userID}/account_data/{type}", - internal.MakeAuthAPI("user_account_data", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -486,8 +481,8 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", - internal.MakeAuthAPI("user_account_data", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -496,8 +491,8 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/user/{userID}/account_data/{type}", - internal.MakeAuthAPI("user_account_data", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -506,8 +501,8 @@ func Setup( ).Methods(http.MethodGet) r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", - internal.MakeAuthAPI("user_account_data", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -516,8 +511,8 @@ func Setup( ).Methods(http.MethodGet) r0mux.Handle("/rooms/{roomID}/members", - internal.MakeAuthAPI("rooms_members", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -526,8 +521,8 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/joined_members", - internal.MakeAuthAPI("rooms_members", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -536,21 +531,21 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/read_markers", - internal.MakeExternalAPI("rooms_read_markers", func(req *http.Request) util.JSONResponse { + httputil.MakeExternalAPI("rooms_read_markers", func(req *http.Request) util.JSONResponse { // TODO: return the read_markers. return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}} }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/devices", - internal.MakeAuthAPI("get_devices", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { return GetDevicesByLocalpart(req, deviceDB, device) }), ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/devices/{deviceID}", - internal.MakeAuthAPI("get_device", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("get_device", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -559,8 +554,8 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/devices/{deviceID}", - internal.MakeAuthAPI("device_data", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("device_data", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -569,8 +564,8 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/devices/{deviceID}", - internal.MakeAuthAPI("delete_device", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("delete_device", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -579,14 +574,14 @@ func Setup( ).Methods(http.MethodDelete, http.MethodOptions) r0mux.Handle("/delete_devices", - internal.MakeAuthAPI("delete_devices", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + httputil.MakeAuthAPI("delete_devices", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { return DeleteDevices(req, deviceDB, device) }), ).Methods(http.MethodPost, http.MethodOptions) // Stub implementations for sytest r0mux.Handle("/events", - internal.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse { + httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse { return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ "chunk": []interface{}{}, "start": "", @@ -596,7 +591,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/initialSync", - internal.MakeExternalAPI("initial_sync", func(req *http.Request) util.JSONResponse { + httputil.MakeExternalAPI("initial_sync", func(req *http.Request) util.JSONResponse { return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ "end": "", }} @@ -604,8 +599,8 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/user/{userId}/rooms/{roomId}/tags", - internal.MakeAuthAPI("get_tags", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("get_tags", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -614,8 +609,8 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", - internal.MakeAuthAPI("put_tag", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("put_tag", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -624,8 +619,8 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", - internal.MakeAuthAPI("delete_tag", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("delete_tag", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -634,7 +629,7 @@ func Setup( ).Methods(http.MethodDelete, http.MethodOptions) r0mux.Handle("/capabilities", - internal.MakeAuthAPI("capabilities", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + httputil.MakeAuthAPI("capabilities", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { return GetCapabilities(req, rsAPI) }), ).Methods(http.MethodGet) diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 5d5507e85..d8936f750 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -17,13 +17,13 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/transactions" "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -41,7 +41,7 @@ type sendEventResponse struct { // /rooms/{roomID}/state/{eventType}/{stateKey} func SendEvent( req *http.Request, - device *authtypes.Device, + device *userapi.Device, roomID, eventType string, txnID, stateKey *string, cfg *config.Dendrite, rsAPI api.RoomserverInternalAPI, @@ -110,7 +110,7 @@ func SendEvent( func generateSendEvent( req *http.Request, - device *authtypes.Device, + device *userapi.Device, roomID, eventType string, stateKey *string, cfg *config.Dendrite, rsAPI api.RoomserverInternalAPI, @@ -146,8 +146,8 @@ func generateSendEvent( } var queryRes api.QueryLatestEventsAndStateResponse - e, err := internal.BuildEvent(req.Context(), &builder, cfg, evTime, rsAPI, &queryRes) - if err == internal.ErrRoomNoExists { + e, err := eventutil.BuildEvent(req.Context(), &builder, cfg, evTime, rsAPI, &queryRes) + if err == eventutil.ErrRoomNoExists { return nil, &util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound("Room does not exist"), @@ -158,7 +158,7 @@ func generateSendEvent( JSON: jsonerror.BadJSON(e.Error()), } } else if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("internal.BuildEvent failed") + util.GetLogger(req.Context()).WithError(err).Error("eventutil.BuildEvent failed") resErr := jsonerror.InternalServerError() return nil, &resErr } diff --git a/clientapi/routing/sendtodevice.go b/clientapi/routing/sendtodevice.go index dc0a6572b..768e8e0e7 100644 --- a/clientapi/routing/sendtodevice.go +++ b/clientapi/routing/sendtodevice.go @@ -16,18 +16,18 @@ import ( "encoding/json" "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/internal/transactions" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" ) // SendToDevice handles PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId} // sends the device events to the EDU Server func SendToDevice( - req *http.Request, device *authtypes.Device, + req *http.Request, device *userapi.Device, eduAPI api.EDUServerInputAPI, txnCache *transactions.Cache, eventType string, txnID *string, diff --git a/clientapi/routing/sendtyping.go b/clientapi/routing/sendtyping.go index 2eae16582..9b6a0b39b 100644 --- a/clientapi/routing/sendtyping.go +++ b/clientapi/routing/sendtyping.go @@ -16,12 +16,12 @@ import ( "database/sql" "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/eduserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/util" ) @@ -33,7 +33,7 @@ type typingContentJSON struct { // SendTyping handles PUT /rooms/{roomID}/typing/{userID} // sends the typing events to client API typingProducer func SendTyping( - req *http.Request, device *authtypes.Device, roomID string, + req *http.Request, device *userapi.Device, roomID string, userID string, accountDB accounts.Database, eduAPI api.EDUServerInputAPI, ) util.JSONResponse { diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index e3e5bdb6d..2ec7a33f3 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -98,7 +98,8 @@ func OnIncomingStateRequest(ctx context.Context, rsAPI api.RoomserverInternalAPI // /rooms/{roomID}/state/{type}/{statekey} request. It will look in current // state to see if there is an event with that type and state key, if there // is then (by default) we return the content, otherwise a 404. -func OnIncomingStateTypeRequest(ctx context.Context, rsAPI api.RoomserverInternalAPI, roomID string, evType, stateKey string) util.JSONResponse { +// If eventFormat=true, sends the whole event else just the content. +func OnIncomingStateTypeRequest(ctx context.Context, rsAPI api.RoomserverInternalAPI, roomID, evType, stateKey string, eventFormat bool) util.JSONResponse { // TODO(#287): Auth request and handle the case where the user has left (where // we should return the state at the poin they left) util.GetLogger(ctx).WithFields(log.Fields{ @@ -134,8 +135,15 @@ func OnIncomingStateTypeRequest(ctx context.Context, rsAPI api.RoomserverInterna ClientEvent: gomatrixserverlib.HeaderedToClientEvent(stateRes.StateEvents[0], gomatrixserverlib.FormatAll), } + var res interface{} + if eventFormat { + res = stateEvent + } else { + res = stateEvent.Content + } + return util.JSONResponse{ Code: http.StatusOK, - JSON: stateEvent.Content, + JSON: res, } } diff --git a/clientapi/routing/threepid.go b/clientapi/routing/threepid.go index 49a9f3047..e7aaadf54 100644 --- a/clientapi/routing/threepid.go +++ b/clientapi/routing/threepid.go @@ -18,11 +18,12 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/threepid" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -84,7 +85,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf // CheckAndSave3PIDAssociation implements POST /account/3pid func CheckAndSave3PIDAssociation( - req *http.Request, accountDB accounts.Database, device *authtypes.Device, + req *http.Request, accountDB accounts.Database, device *api.Device, cfg *config.Dendrite, ) util.JSONResponse { var body threepid.EmailAssociationCheckRequest @@ -148,7 +149,7 @@ func CheckAndSave3PIDAssociation( // GetAssociated3PIDs implements GET /account/3pid func GetAssociated3PIDs( - req *http.Request, accountDB accounts.Database, device *authtypes.Device, + req *http.Request, accountDB accounts.Database, device *api.Device, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { diff --git a/clientapi/routing/voip.go b/clientapi/routing/voip.go index 212d9b0a4..046e87811 100644 --- a/clientapi/routing/voip.go +++ b/clientapi/routing/voip.go @@ -22,16 +22,16 @@ import ( "net/http" "time" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrix" "github.com/matrix-org/util" ) // RequestTurnServer implements: // GET /voip/turnServer -func RequestTurnServer(req *http.Request, device *authtypes.Device, cfg *config.Dendrite) util.JSONResponse { +func RequestTurnServer(req *http.Request, device *api.Device, cfg *config.Dendrite) util.JSONResponse { turnConfig := cfg.TURN // TODO Guest Support diff --git a/clientapi/routing/whoami.go b/clientapi/routing/whoami.go index 840bcb5f2..26280f6cc 100644 --- a/clientapi/routing/whoami.go +++ b/clientapi/routing/whoami.go @@ -15,7 +15,7 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" ) @@ -26,7 +26,7 @@ type whoamiResponse struct { // Whoami implements `/account/whoami` which enables client to query their account user id. // https://matrix.org/docs/spec/client_server/r0.3.0.html#get-matrix-client-r0-account-whoami -func Whoami(req *http.Request, device *authtypes.Device) util.JSONResponse { +func Whoami(req *http.Request, device *api.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusOK, JSON: whoamiResponse{UserID: device.UserID}, diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index b0df0dd4d..c308cb1f4 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -25,10 +25,11 @@ import ( "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" ) @@ -85,7 +86,7 @@ var ( // can be emitted. func CheckAndProcessInvite( ctx context.Context, - device *authtypes.Device, body *MembershipRequest, cfg *config.Dendrite, + device *userapi.Device, body *MembershipRequest, cfg *config.Dendrite, rsAPI api.RoomserverInternalAPI, db accounts.Database, membership string, roomID string, evTime time.Time, @@ -136,7 +137,7 @@ func CheckAndProcessInvite( // Returns an error if a check or a request failed. func queryIDServer( ctx context.Context, - db accounts.Database, cfg *config.Dendrite, device *authtypes.Device, + db accounts.Database, cfg *config.Dendrite, device *userapi.Device, body *MembershipRequest, roomID string, ) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) { if err = isTrusted(body.IDServer, cfg); err != nil { @@ -205,7 +206,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe // Returns an error if the request failed to send or if the response couldn't be parsed. func queryIDServerStoreInvite( ctx context.Context, - db accounts.Database, cfg *config.Dendrite, device *authtypes.Device, + db accounts.Database, cfg *config.Dendrite, device *userapi.Device, body *MembershipRequest, roomID string, ) (*idServerStoreInviteResponse, error) { // Retrieve the sender's profile to get their display name @@ -329,7 +330,7 @@ func checkIDServerSignatures( func emit3PIDInviteEvent( ctx context.Context, body *MembershipRequest, res *idServerStoreInviteResponse, - device *authtypes.Device, roomID string, cfg *config.Dendrite, + device *userapi.Device, roomID string, cfg *config.Dendrite, rsAPI api.RoomserverInternalAPI, evTime time.Time, ) error { @@ -353,7 +354,7 @@ func emit3PIDInviteEvent( } queryRes := api.QueryLatestEventsAndStateResponse{} - event, err := internal.BuildEvent(ctx, builder, cfg, evTime, rsAPI, &queryRes) + event, err := eventutil.BuildEvent(ctx, builder, cfg, evTime, rsAPI, &queryRes) if err != nil { return err } diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index 9c1e45932..ff022ec3c 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -20,8 +20,8 @@ import ( "fmt" "os" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/cmd/create-room-events/main.go b/cmd/create-room-events/main.go index ebce953ce..afe974643 100644 --- a/cmd/create-room-events/main.go +++ b/cmd/create-room-events/main.go @@ -47,6 +47,7 @@ var ( userID = flag.String("user-id", "@userid:$SERVER_NAME", "The user ID to use as the event sender") messageCount = flag.Int("message-count", 10, "The number of m.room.messsage events to generate") format = flag.String("Format", "InputRoomEvent", "The output format to use for the messages: InputRoomEvent or Event") + ver = flag.String("version", string(gomatrixserverlib.RoomVersionV1), "Room version to generate events as") ) // By default we use a private key of 0. @@ -109,7 +110,7 @@ func buildAndOutput() gomatrixserverlib.EventReference { event, err := b.Build( now, name, key, privateKey, - gomatrixserverlib.RoomVersionV1, + gomatrixserverlib.RoomVersion(*ver), ) if err != nil { panic(err) @@ -127,7 +128,7 @@ func writeEvent(event gomatrixserverlib.Event) { if *format == "InputRoomEvent" { var ire api.InputRoomEvent ire.Kind = api.KindNew - ire.Event = event.Headered(gomatrixserverlib.RoomVersionV1) + ire.Event = event.Headered(gomatrixserverlib.RoomVersion(*ver)) authEventIDs := []string{} for _, ref := range b.AuthEvents.([]gomatrixserverlib.EventReference) { authEventIDs = append(authEventIDs, ref.EventID) diff --git a/cmd/dendrite-appservice-server/main.go b/cmd/dendrite-appservice-server/main.go index fe00a117d..6719d0471 100644 --- a/cmd/dendrite-appservice-server/main.go +++ b/cmd/dendrite-appservice-server/main.go @@ -16,19 +16,18 @@ package main import ( "github.com/matrix-org/dendrite/appservice" - "github.com/matrix-org/dendrite/internal/basecomponent" + "github.com/matrix-org/dendrite/internal/setup" ) func main() { - cfg := basecomponent.ParseFlags(false) - base := basecomponent.NewBaseDendrite(cfg, "AppServiceAPI", true) + cfg := setup.ParseFlags(false) + base := setup.NewBaseDendrite(cfg, "AppServiceAPI", true) defer base.Close() // nolint: errcheck - accountDB := base.CreateAccountsDB() - deviceDB := base.CreateDeviceDB() + userAPI := base.UserAPIClient() rsAPI := base.RoomserverHTTPClient() - intAPI := appservice.NewInternalAPI(base, accountDB, deviceDB, rsAPI) + intAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) appservice.AddInternalRoutes(base.InternalAPIMux, intAPI) base.SetupAndServeHTTP(string(base.Cfg.Bind.AppServiceAPI), string(base.Cfg.Listen.AppServiceAPI)) diff --git a/cmd/dendrite-client-api-server/main.go b/cmd/dendrite-client-api-server/main.go index 9cf57ef68..fe5f30a0e 100644 --- a/cmd/dendrite-client-api-server/main.go +++ b/cmd/dendrite-client-api-server/main.go @@ -16,31 +16,29 @@ package main import ( "github.com/matrix-org/dendrite/clientapi" - "github.com/matrix-org/dendrite/internal/basecomponent" + "github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/internal/transactions" ) func main() { - cfg := basecomponent.ParseFlags(false) + cfg := setup.ParseFlags(false) - base := basecomponent.NewBaseDendrite(cfg, "ClientAPI", true) + base := setup.NewBaseDendrite(cfg, "ClientAPI", true) defer base.Close() // nolint: errcheck accountDB := base.CreateAccountsDB() deviceDB := base.CreateDeviceDB() federation := base.CreateFederationClient() - serverKeyAPI := base.ServerKeyAPIClient() - keyRing := serverKeyAPI.KeyRing() - asQuery := base.AppserviceHTTPClient() rsAPI := base.RoomserverHTTPClient() fsAPI := base.FederationSenderHTTPClient() eduInputAPI := base.EDUServerClient() + userAPI := base.UserAPIClient() clientapi.AddPublicRoutes( - base.PublicAPIMux, base.Cfg, base.KafkaConsumer, base.KafkaProducer, deviceDB, accountDB, federation, keyRing, - rsAPI, eduInputAPI, asQuery, transactions.New(), fsAPI, + base.PublicAPIMux, base.Cfg, base.KafkaConsumer, base.KafkaProducer, deviceDB, accountDB, federation, + rsAPI, eduInputAPI, asQuery, transactions.New(), fsAPI, userAPI, ) base.SetupAndServeHTTP(string(base.Cfg.Bind.ClientAPI), string(base.Cfg.Listen.ClientAPI)) diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index b7cbcdc9c..356ab5a7f 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -32,11 +32,12 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-libp2p/storage" "github.com/matrix-org/dendrite/eduserver" "github.com/matrix-org/dendrite/federationsender" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/serverkeyapi" + "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/eduserver/cache" @@ -79,6 +80,17 @@ func createFederationClient( ) } +func createClient( + base *P2PDendrite, +) *gomatrixserverlib.Client { + tr := &http.Transport{} + tr.RegisterProtocol( + "matrix", + p2phttp.NewTransport(base.LibP2P, p2phttp.ProtocolOption("/matrix")), + ) + return gomatrixserverlib.NewClientWithTransport(tr) +} + func main() { instanceName := flag.String("name", "dendrite-p2p", "the name of this P2P demo instance") instancePort := flag.Int("port", 8080, "the port that the client API will listen on") @@ -101,6 +113,7 @@ func main() { } cfg := config.Dendrite{} + cfg.SetDefaults() cfg.Matrix.ServerName = "p2p" cfg.Matrix.PrivateKey = privKey cfg.Matrix.KeyID = gomatrixserverlib.KeyID(fmt.Sprintf("ed25519:%s", *instanceName)) @@ -128,6 +141,7 @@ func main() { accountDB := base.Base.CreateAccountsDB() deviceDB := base.Base.CreateDeviceDB() federation := createFederationClient(base) + userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Matrix.ServerName, nil) serverKeyAPI := serverkeyapi.NewInternalAPI( base.Base.Cfg, federation, base.Base.Caches, @@ -141,9 +155,9 @@ func main() { &base.Base, keyRing, federation, ) eduInputAPI := eduserver.NewInternalAPI( - &base.Base, cache.New(), deviceDB, + &base.Base, cache.New(), userAPI, ) - asAPI := appservice.NewInternalAPI(&base.Base, accountDB, deviceDB, rsAPI) + asAPI := appservice.NewInternalAPI(&base.Base, userAPI, rsAPI) fsAPI := federationsender.NewInternalAPI( &base.Base, federation, rsAPI, keyRing, ) @@ -157,6 +171,7 @@ func main() { Config: base.Base.Cfg, AccountDB: accountDB, DeviceDB: deviceDB, + Client: createClient(base), FedClient: federation, KeyRing: keyRing, KafkaConsumer: base.Base.KafkaConsumer, @@ -167,13 +182,14 @@ func main() { FederationSenderAPI: fsAPI, RoomserverAPI: rsAPI, ServerKeyAPI: serverKeyAPI, + UserAPI: userAPI, PublicRoomsDB: publicRoomsDB, } monolith.AddAllPublicRoutes(base.Base.PublicAPIMux) - internal.SetupHTTPAPI( - http.DefaultServeMux, + httputil.SetupHTTPAPI( + base.Base.BaseMux, base.Base.PublicAPIMux, base.Base.InternalAPIMux, &cfg, @@ -184,7 +200,7 @@ func main() { go func() { httpBindAddr := fmt.Sprintf(":%d", *instancePort) logrus.Info("Listening on ", httpBindAddr) - logrus.Fatal(http.ListenAndServe(httpBindAddr, nil)) + logrus.Fatal(http.ListenAndServe(httpBindAddr, base.Base.BaseMux)) }() // Expose the matrix APIs also via libp2p if base.LibP2P != nil { @@ -197,7 +213,7 @@ func main() { defer func() { logrus.Fatal(listener.Close()) }() - logrus.Fatal(http.Serve(listener, nil)) + logrus.Fatal(http.Serve(listener, base.Base.BaseMux)) }() } diff --git a/cmd/dendrite-demo-libp2p/p2pdendrite.go b/cmd/dendrite-demo-libp2p/p2pdendrite.go index b7c5c66b0..4270143f5 100644 --- a/cmd/dendrite-demo-libp2p/p2pdendrite.go +++ b/cmd/dendrite-demo-libp2p/p2pdendrite.go @@ -22,7 +22,7 @@ import ( pstore "github.com/libp2p/go-libp2p-core/peerstore" record "github.com/libp2p/go-libp2p-record" - "github.com/matrix-org/dendrite/internal/basecomponent" + "github.com/matrix-org/dendrite/internal/setup" "github.com/libp2p/go-libp2p" circuit "github.com/libp2p/go-libp2p-circuit" @@ -39,7 +39,7 @@ import ( // P2PDendrite is a Peer-to-Peer variant of BaseDendrite. type P2PDendrite struct { - Base basecomponent.BaseDendrite + Base setup.BaseDendrite // Store our libp2p object so that we can make outgoing connections from it // later @@ -54,7 +54,7 @@ type P2PDendrite struct { // The componentName is used for logging purposes, and should be a friendly name // of the component running, e.g. SyncAPI. func NewP2PDendrite(cfg *config.Dendrite, componentName string) *P2PDendrite { - baseDendrite := basecomponent.NewBaseDendrite(cfg, componentName, false) + baseDendrite := setup.NewBaseDendrite(cfg, componentName, false) ctx, cancel := context.WithCancel(context.Background()) diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 11f6d9970..2076aa135 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -30,12 +30,12 @@ import ( "github.com/matrix-org/dendrite/eduserver" "github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/federationsender" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/basecomponent" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/publicroomsapi/storage" "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" @@ -91,7 +91,7 @@ func main() { panic(err) } - base := basecomponent.NewBaseDendrite(cfg, "Monolith", false) + base := setup.NewBaseDendrite(cfg, "Monolith", false) defer base.Close() // nolint: errcheck accountDB := base.CreateAccountsDB() @@ -101,16 +101,18 @@ func main() { serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() + userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Matrix.ServerName, nil) + rsComponent := roomserver.NewInternalAPI( base, keyRing, federation, ) rsAPI := rsComponent eduInputAPI := eduserver.NewInternalAPI( - base, cache.New(), deviceDB, + base, cache.New(), userAPI, ) - asAPI := appservice.NewInternalAPI(base, accountDB, deviceDB, rsAPI) + asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) fsAPI := federationsender.NewInternalAPI( base, federation, rsAPI, keyRing, @@ -129,6 +131,7 @@ func main() { Config: base.Cfg, AccountDB: accountDB, DeviceDB: deviceDB, + Client: createClient(ygg), FedClient: federation, KeyRing: keyRing, KafkaConsumer: base.KafkaConsumer, @@ -138,20 +141,34 @@ func main() { EDUInternalAPI: eduInputAPI, FederationSenderAPI: fsAPI, RoomserverAPI: rsAPI, + UserAPI: userAPI, //ServerKeyAPI: serverKeyAPI, PublicRoomsDB: publicRoomsDB, } monolith.AddAllPublicRoutes(base.PublicAPIMux) - internal.SetupHTTPAPI( - http.DefaultServeMux, + httputil.SetupHTTPAPI( + base.BaseMux, base.PublicAPIMux, base.InternalAPIMux, cfg, base.UseHTTPAPIs, ) + // Build both ends of a HTTP multiplex. + httpServer := &http.Server{ + Addr: ":0", + TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){}, + ReadTimeout: 15 * time.Second, + WriteTimeout: 45 * time.Second, + IdleTimeout: 60 * time.Second, + BaseContext: func(_ net.Listener) context.Context { + return context.Background() + }, + Handler: base.BaseMux, + } + go func() { logrus.Info("Listening on ", ygg.DerivedServerName()) logrus.Fatal(httpServer.Serve(ygg)) @@ -159,7 +176,7 @@ func main() { go func() { httpBindAddr := fmt.Sprintf(":%d", *instancePort) logrus.Info("Listening on ", httpBindAddr) - logrus.Fatal(http.ListenAndServe(httpBindAddr, nil)) + logrus.Fatal(http.ListenAndServe(httpBindAddr, base.BaseMux)) }() select {} diff --git a/cmd/dendrite-demo-yggdrasil/yggconn/node.go b/cmd/dendrite-demo-yggdrasil/yggconn/node.go index b8b70f985..b43c41def 100644 --- a/cmd/dendrite-demo-yggdrasil/yggconn/node.go +++ b/cmd/dendrite-demo-yggdrasil/yggconn/node.go @@ -15,13 +15,16 @@ package yggconn import ( + "context" "crypto/ed25519" "encoding/hex" "encoding/json" "fmt" "io/ioutil" "log" + "net" "os" + "strings" "sync" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/convert" @@ -48,6 +51,21 @@ type Node struct { incoming chan *yamux.Stream } +func (n *Node) Dialer(_, address string) (net.Conn, error) { + tokens := strings.Split(address, ":") + raw, err := hex.DecodeString(tokens[0]) + if err != nil { + return nil, fmt.Errorf("hex.DecodeString: %w", err) + } + converted := convert.Ed25519PublicKeyToCurve25519(ed25519.PublicKey(raw)) + convhex := hex.EncodeToString(converted) + return n.Dial("curve25519", convhex) +} + +func (n *Node) DialerContext(ctx context.Context, network, address string) (net.Conn, error) { + return n.Dialer(network, address) +} + // nolint:gocyclo func Setup(instanceName, instancePeer, storageDirectory string) (*Node, error) { n := &Node{ diff --git a/cmd/dendrite-edu-server/main.go b/cmd/dendrite-edu-server/main.go index a86f66c89..6704ebd09 100644 --- a/cmd/dendrite-edu-server/main.go +++ b/cmd/dendrite-edu-server/main.go @@ -17,21 +17,20 @@ import ( "github.com/matrix-org/dendrite/eduserver" "github.com/matrix-org/dendrite/eduserver/cache" - "github.com/matrix-org/dendrite/internal/basecomponent" + "github.com/matrix-org/dendrite/internal/setup" "github.com/sirupsen/logrus" ) func main() { - cfg := basecomponent.ParseFlags(false) - base := basecomponent.NewBaseDendrite(cfg, "EDUServerAPI", true) + cfg := setup.ParseFlags(false) + base := setup.NewBaseDendrite(cfg, "EDUServerAPI", true) defer func() { if err := base.Close(); err != nil { logrus.WithError(err).Warn("BaseDendrite close failed") } }() - deviceDB := base.CreateDeviceDB() - intAPI := eduserver.NewInternalAPI(base, cache.New(), deviceDB) + intAPI := eduserver.NewInternalAPI(base, cache.New(), base.UserAPIClient()) eduserver.AddInternalRoutes(base.InternalAPIMux, intAPI) base.SetupAndServeHTTP(string(base.Cfg.Bind.EDUServer), string(base.Cfg.Listen.EDUServer)) diff --git a/cmd/dendrite-federation-api-server/main.go b/cmd/dendrite-federation-api-server/main.go index fcdc000c4..e3bf5edc8 100644 --- a/cmd/dendrite-federation-api-server/main.go +++ b/cmd/dendrite-federation-api-server/main.go @@ -16,26 +16,24 @@ package main import ( "github.com/matrix-org/dendrite/federationapi" - "github.com/matrix-org/dendrite/internal/basecomponent" + "github.com/matrix-org/dendrite/internal/setup" ) func main() { - cfg := basecomponent.ParseFlags(false) - base := basecomponent.NewBaseDendrite(cfg, "FederationAPI", true) + cfg := setup.ParseFlags(false) + base := setup.NewBaseDendrite(cfg, "FederationAPI", true) defer base.Close() // nolint: errcheck - accountDB := base.CreateAccountsDB() - deviceDB := base.CreateDeviceDB() + userAPI := base.UserAPIClient() federation := base.CreateFederationClient() serverKeyAPI := base.ServerKeyAPIClient() keyRing := serverKeyAPI.KeyRing() fsAPI := base.FederationSenderHTTPClient() rsAPI := base.RoomserverHTTPClient() - asAPI := base.AppserviceHTTPClient() federationapi.AddPublicRoutes( - base.PublicAPIMux, base.Cfg, accountDB, deviceDB, federation, keyRing, - rsAPI, asAPI, fsAPI, base.EDUServerClient(), + base.PublicAPIMux, base.Cfg, userAPI, federation, keyRing, + rsAPI, fsAPI, base.EDUServerClient(), ) base.SetupAndServeHTTP(string(base.Cfg.Bind.FederationAPI), string(base.Cfg.Listen.FederationAPI)) diff --git a/cmd/dendrite-federation-sender-server/main.go b/cmd/dendrite-federation-sender-server/main.go index 7b60ad055..20bc1070f 100644 --- a/cmd/dendrite-federation-sender-server/main.go +++ b/cmd/dendrite-federation-sender-server/main.go @@ -16,12 +16,12 @@ package main import ( "github.com/matrix-org/dendrite/federationsender" - "github.com/matrix-org/dendrite/internal/basecomponent" + "github.com/matrix-org/dendrite/internal/setup" ) func main() { - cfg := basecomponent.ParseFlags(false) - base := basecomponent.NewBaseDendrite(cfg, "FederationSender", true) + cfg := setup.ParseFlags(false) + base := setup.NewBaseDendrite(cfg, "FederationSender", true) defer base.Close() // nolint: errcheck federation := base.CreateFederationClient() diff --git a/cmd/dendrite-key-server/main.go b/cmd/dendrite-key-server/main.go index 5dccc3004..b557cbd9e 100644 --- a/cmd/dendrite-key-server/main.go +++ b/cmd/dendrite-key-server/main.go @@ -15,19 +15,18 @@ package main import ( - "github.com/matrix-org/dendrite/internal/basecomponent" + "github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/keyserver" ) func main() { - cfg := basecomponent.ParseFlags(false) - base := basecomponent.NewBaseDendrite(cfg, "KeyServer", true) + cfg := setup.ParseFlags(false) + base := setup.NewBaseDendrite(cfg, "KeyServer", true) defer base.Close() // nolint: errcheck - accountDB := base.CreateAccountsDB() - deviceDB := base.CreateDeviceDB() + userAPI := base.UserAPIClient() - keyserver.AddPublicRoutes(base.PublicAPIMux, base.Cfg, deviceDB, accountDB) + keyserver.AddPublicRoutes(base.PublicAPIMux, base.Cfg, userAPI) base.SetupAndServeHTTP(string(base.Cfg.Bind.KeyServer), string(base.Cfg.Listen.KeyServer)) diff --git a/cmd/dendrite-media-api-server/main.go b/cmd/dendrite-media-api-server/main.go index 5c65fad05..1582a33a8 100644 --- a/cmd/dendrite-media-api-server/main.go +++ b/cmd/dendrite-media-api-server/main.go @@ -15,18 +15,20 @@ package main import ( - "github.com/matrix-org/dendrite/internal/basecomponent" + "github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/mediaapi" + "github.com/matrix-org/gomatrixserverlib" ) func main() { - cfg := basecomponent.ParseFlags(false) - base := basecomponent.NewBaseDendrite(cfg, "MediaAPI", true) + cfg := setup.ParseFlags(false) + base := setup.NewBaseDendrite(cfg, "MediaAPI", true) defer base.Close() // nolint: errcheck - deviceDB := base.CreateDeviceDB() + userAPI := base.UserAPIClient() + client := gomatrixserverlib.NewClient() - mediaapi.AddPublicRoutes(base.PublicAPIMux, base.Cfg, deviceDB) + mediaapi.AddPublicRoutes(base.PublicAPIMux, base.Cfg, userAPI, client) base.SetupAndServeHTTP(string(base.Cfg.Bind.MediaAPI), string(base.Cfg.Listen.MediaAPI)) diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index 195a1ac50..339bbe699 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -17,18 +17,21 @@ package main import ( "flag" "net/http" + "os" "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/eduserver" "github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/federationsender" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/basecomponent" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/publicroomsapi/storage" "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/serverkeyapi" + "github.com/matrix-org/dendrite/userapi" + "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) @@ -39,10 +42,11 @@ var ( certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS") keyFile = flag.String("tls-key", "", "The PEM private key to use for TLS") enableHTTPAPIs = flag.Bool("api", false, "Use HTTP APIs instead of short-circuiting (warning: exposes API endpoints!)") + traceInternal = os.Getenv("DENDRITE_TRACE_INTERNAL") == "1" ) func main() { - cfg := basecomponent.ParseFlags(true) + cfg := setup.ParseFlags(true) if *enableHTTPAPIs { // If the HTTP APIs are enabled then we need to update the Listen // statements in the configuration so that we know where to find @@ -56,7 +60,7 @@ func main() { cfg.Listen.ServerKeyAPI = addr } - base := basecomponent.NewBaseDendrite(cfg, "Monolith", *enableHTTPAPIs) + base := setup.NewBaseDendrite(cfg, "Monolith", *enableHTTPAPIs) defer base.Close() // nolint: errcheck accountDB := base.CreateAccountsDB() @@ -71,25 +75,32 @@ func main() { serverKeyAPI = base.ServerKeyAPIClient() } keyRing := serverKeyAPI.KeyRing() + userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Matrix.ServerName, cfg.Derived.ApplicationServices) - rsComponent := roomserver.NewInternalAPI( + rsImpl := roomserver.NewInternalAPI( base, keyRing, federation, ) - rsAPI := rsComponent + // call functions directly on the impl unless running in HTTP mode + rsAPI := rsImpl if base.UseHTTPAPIs { - roomserver.AddInternalRoutes(base.InternalAPIMux, rsAPI) + roomserver.AddInternalRoutes(base.InternalAPIMux, rsImpl) rsAPI = base.RoomserverHTTPClient() } + if traceInternal { + rsAPI = &api.RoomserverInternalAPITrace{ + Impl: rsAPI, + } + } eduInputAPI := eduserver.NewInternalAPI( - base, cache.New(), deviceDB, + base, cache.New(), userAPI, ) if base.UseHTTPAPIs { eduserver.AddInternalRoutes(base.InternalAPIMux, eduInputAPI) eduInputAPI = base.EDUServerClient() } - asAPI := appservice.NewInternalAPI(base, accountDB, deviceDB, rsAPI) + asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) if base.UseHTTPAPIs { appservice.AddInternalRoutes(base.InternalAPIMux, asAPI) asAPI = base.AppserviceHTTPClient() @@ -102,7 +113,9 @@ func main() { federationsender.AddInternalRoutes(base.InternalAPIMux, fsAPI) fsAPI = base.FederationSenderHTTPClient() } - rsComponent.SetFederationSenderAPI(fsAPI) + // The underlying roomserver implementation needs to be able to call the fedsender. + // This is different to rsAPI which can be the http client which doesn't need this dependency + rsImpl.SetFederationSenderAPI(fsAPI) publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI), base.Cfg.DbProperties(), cfg.Matrix.ServerName) if err != nil { @@ -113,6 +126,7 @@ func main() { Config: base.Cfg, AccountDB: accountDB, DeviceDB: deviceDB, + Client: gomatrixserverlib.NewClient(), FedClient: federation, KeyRing: keyRing, KafkaConsumer: base.KafkaConsumer, @@ -123,13 +137,14 @@ func main() { FederationSenderAPI: fsAPI, RoomserverAPI: rsAPI, ServerKeyAPI: serverKeyAPI, + UserAPI: userAPI, PublicRoomsDB: publicRoomsDB, } monolith.AddAllPublicRoutes(base.PublicAPIMux) - internal.SetupHTTPAPI( - http.DefaultServeMux, + httputil.SetupHTTPAPI( + base.BaseMux, base.PublicAPIMux, base.InternalAPIMux, cfg, @@ -140,7 +155,8 @@ func main() { go func() { serv := http.Server{ Addr: *httpBindAddr, - WriteTimeout: basecomponent.HTTPServerTimeout, + WriteTimeout: setup.HTTPServerTimeout, + Handler: base.BaseMux, } logrus.Info("Listening on ", serv.Addr) @@ -151,7 +167,8 @@ func main() { go func() { serv := http.Server{ Addr: *httpsBindAddr, - WriteTimeout: basecomponent.HTTPServerTimeout, + WriteTimeout: setup.HTTPServerTimeout, + Handler: base.BaseMux, } logrus.Info("Listening on ", serv.Addr) diff --git a/cmd/dendrite-public-rooms-api-server/main.go b/cmd/dendrite-public-rooms-api-server/main.go index ff7d7958f..23866b757 100644 --- a/cmd/dendrite-public-rooms-api-server/main.go +++ b/cmd/dendrite-public-rooms-api-server/main.go @@ -15,18 +15,18 @@ package main import ( - "github.com/matrix-org/dendrite/internal/basecomponent" + "github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/publicroomsapi" "github.com/matrix-org/dendrite/publicroomsapi/storage" "github.com/sirupsen/logrus" ) func main() { - cfg := basecomponent.ParseFlags(false) - base := basecomponent.NewBaseDendrite(cfg, "PublicRoomsAPI", true) + cfg := setup.ParseFlags(false) + base := setup.NewBaseDendrite(cfg, "PublicRoomsAPI", true) defer base.Close() // nolint: errcheck - deviceDB := base.CreateDeviceDB() + userAPI := base.UserAPIClient() rsAPI := base.RoomserverHTTPClient() @@ -34,7 +34,7 @@ func main() { if err != nil { logrus.WithError(err).Panicf("failed to connect to public rooms db") } - publicroomsapi.AddPublicRoutes(base.PublicAPIMux, base.Cfg, base.KafkaConsumer, deviceDB, publicRoomsDB, rsAPI, nil, nil) + publicroomsapi.AddPublicRoutes(base.PublicAPIMux, base.Cfg, base.KafkaConsumer, userAPI, publicRoomsDB, rsAPI, nil, nil) base.SetupAndServeHTTP(string(base.Cfg.Bind.PublicRoomsAPI), string(base.Cfg.Listen.PublicRoomsAPI)) diff --git a/cmd/dendrite-room-server/main.go b/cmd/dendrite-room-server/main.go index a2f1e9415..627a68677 100644 --- a/cmd/dendrite-room-server/main.go +++ b/cmd/dendrite-room-server/main.go @@ -15,13 +15,13 @@ package main import ( - "github.com/matrix-org/dendrite/internal/basecomponent" + "github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/roomserver" ) func main() { - cfg := basecomponent.ParseFlags(false) - base := basecomponent.NewBaseDendrite(cfg, "RoomServerAPI", true) + cfg := setup.ParseFlags(false) + base := setup.NewBaseDendrite(cfg, "RoomServerAPI", true) defer base.Close() // nolint: errcheck federation := base.CreateFederationClient() diff --git a/cmd/dendrite-server-key-api-server/main.go b/cmd/dendrite-server-key-api-server/main.go index b4bfcbffb..9ffaeee31 100644 --- a/cmd/dendrite-server-key-api-server/main.go +++ b/cmd/dendrite-server-key-api-server/main.go @@ -15,13 +15,13 @@ package main import ( - "github.com/matrix-org/dendrite/internal/basecomponent" + "github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/serverkeyapi" ) func main() { - cfg := basecomponent.ParseFlags(false) - base := basecomponent.NewBaseDendrite(cfg, "ServerKeyAPI", true) + cfg := setup.ParseFlags(false) + base := setup.NewBaseDendrite(cfg, "ServerKeyAPI", true) defer base.Close() // nolint: errcheck federation := base.CreateFederationClient() diff --git a/cmd/dendrite-sync-api-server/main.go b/cmd/dendrite-sync-api-server/main.go index a5302f74e..d67395fb3 100644 --- a/cmd/dendrite-sync-api-server/main.go +++ b/cmd/dendrite-sync-api-server/main.go @@ -15,22 +15,21 @@ package main import ( - "github.com/matrix-org/dendrite/internal/basecomponent" + "github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/syncapi" ) func main() { - cfg := basecomponent.ParseFlags(false) - base := basecomponent.NewBaseDendrite(cfg, "SyncAPI", true) + cfg := setup.ParseFlags(false) + base := setup.NewBaseDendrite(cfg, "SyncAPI", true) defer base.Close() // nolint: errcheck - deviceDB := base.CreateDeviceDB() - accountDB := base.CreateAccountsDB() + userAPI := base.UserAPIClient() federation := base.CreateFederationClient() rsAPI := base.RoomserverHTTPClient() - syncapi.AddPublicRoutes(base.PublicAPIMux, base.KafkaConsumer, deviceDB, accountDB, rsAPI, federation, cfg) + syncapi.AddPublicRoutes(base.PublicAPIMux, base.KafkaConsumer, userAPI, rsAPI, federation, cfg) base.SetupAndServeHTTP(string(base.Cfg.Bind.SyncAPI), string(base.Cfg.Listen.SyncAPI)) diff --git a/cmd/dendritejs/jsServer.go b/cmd/dendritejs/jsServer.go index 31f122645..074d20cba 100644 --- a/cmd/dendritejs/jsServer.go +++ b/cmd/dendritejs/jsServer.go @@ -28,7 +28,7 @@ import ( // JSServer exposes an HTTP-like server interface which allows JS to 'send' requests to it. type JSServer struct { // The router which will service requests - Mux *http.ServeMux + Mux http.Handler } // OnRequestFromJS is the function that JS will invoke when there is a new request. diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go index 0512ee5c9..883b0fad0 100644 --- a/cmd/dendritejs/main.go +++ b/cmd/dendritejs/main.go @@ -19,19 +19,18 @@ package main import ( "crypto/ed25519" "fmt" - "net/http" "syscall/js" "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/eduserver" "github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/federationsender" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/basecomponent" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/publicroomsapi/storage" "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/userapi" go_http_js_libp2p "github.com/matrix-org/go-http-js-libp2p" "github.com/matrix-org/gomatrixserverlib" @@ -146,6 +145,11 @@ func createFederationClient(cfg *config.Dendrite, node *go_http_js_libp2p.P2pLoc return fed } +func createClient(node *go_http_js_libp2p.P2pLocalNode) *gomatrixserverlib.Client { + tr := go_http_js_libp2p.NewP2pTransport(node) + return gomatrixserverlib.NewClientWithTransport(tr) +} + func createP2PNode(privKey ed25519.PrivateKey) (serverName string, node *go_http_js_libp2p.P2pLocalNode) { hosted := "/dns4/rendezvous.matrix.org/tcp/8443/wss/p2p-websocket-star/" node = go_http_js_libp2p.NewP2pLocalNode("org.matrix.p2p.experiment", privKey.Seed(), []string{hosted}, "p2p") @@ -184,12 +188,13 @@ func main() { if err := cfg.Derive(); err != nil { logrus.Fatalf("Failed to derive values from config: %s", err) } - base := basecomponent.NewBaseDendrite(cfg, "Monolith", false) + base := setup.NewBaseDendrite(cfg, "Monolith", false) defer base.Close() // nolint: errcheck accountDB := base.CreateAccountsDB() deviceDB := base.CreateDeviceDB() federation := createFederationClient(cfg, node) + userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Matrix.ServerName, nil) fetcher := &libp2pKeyFetcher{} keyRing := gomatrixserverlib.KeyRing{ @@ -200,9 +205,9 @@ func main() { } rsAPI := roomserver.NewInternalAPI(base, keyRing, federation) - eduInputAPI := eduserver.NewInternalAPI(base, cache.New(), deviceDB) + eduInputAPI := eduserver.NewInternalAPI(base, cache.New(), userAPI) asQuery := appservice.NewInternalAPI( - base, accountDB, deviceDB, rsAPI, + base, userAPI, rsAPI, ) fedSenderAPI := federationsender.NewInternalAPI(base, federation, rsAPI, &keyRing) rsAPI.SetFederationSenderAPI(fedSenderAPI) @@ -217,6 +222,7 @@ func main() { Config: base.Cfg, AccountDB: accountDB, DeviceDB: deviceDB, + Client: createClient(node), FedClient: federation, KeyRing: &keyRing, KafkaConsumer: base.KafkaConsumer, @@ -226,6 +232,7 @@ func main() { EDUInternalAPI: eduInputAPI, FederationSenderAPI: fedSenderAPI, RoomserverAPI: rsAPI, + UserAPI: userAPI, //ServerKeyAPI: serverKeyAPI, PublicRoomsDB: publicRoomsDB, @@ -233,8 +240,8 @@ func main() { } monolith.AddAllPublicRoutes(base.PublicAPIMux) - internal.SetupHTTPAPI( - http.DefaultServeMux, + httputil.SetupHTTPAPI( + base.BaseMux, base.PublicAPIMux, base.InternalAPIMux, cfg, @@ -246,7 +253,7 @@ func main() { go func() { logrus.Info("Listening on libp2p-js host ID ", node.Id) s := JSServer{ - Mux: http.DefaultServeMux, + Mux: base.BaseMux, } s.ListenAndServe("p2p") }() @@ -256,7 +263,7 @@ func main() { go func() { logrus.Info("Listening for service-worker fetch traffic") s := JSServer{ - Mux: http.DefaultServeMux, + Mux: base.BaseMux, } s.ListenAndServe("fetch") }() diff --git a/cmd/roomserver-integration-tests/main.go b/cmd/roomserver-integration-tests/main.go index 43aca0789..3860ca1f7 100644 --- a/cmd/roomserver-integration-tests/main.go +++ b/cmd/roomserver-integration-tests/main.go @@ -255,7 +255,7 @@ func testRoomserver(input []string, wantOutput []string, checkQueries func(api.R panic(err) } - cache, err := caching.NewInMemoryLRUCache() + cache, err := caching.NewInMemoryLRUCache(false) if err != nil { panic(err) } diff --git a/dendrite-config.yaml b/dendrite-config.yaml index a5b295977..52793cda4 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -140,6 +140,7 @@ listen: edu_server: "localhost:7778" key_server: "localhost:7779" server_key_api: "localhost:7780" + user_api: "localhost:7781" # The configuration for tracing the dendrite components. tracing: diff --git a/docs/WIRING-Current.md b/docs/WIRING-Current.md new file mode 100644 index 000000000..ec539d4e9 --- /dev/null +++ b/docs/WIRING-Current.md @@ -0,0 +1,73 @@ +This document details how various components communicate with each other. There are two kinds of components: + - Public-facing: exposes CS/SS API endpoints and need to be routed to via client-api-proxy or equivalent. + - Internal-only: exposes internal APIs and produces Kafka events. + +## Internal HTTP APIs + +Not everything can be done using Kafka logs. For example, requesting the latest events in a room is much better suited to +a request/response model like HTTP or RPC. Therefore, components can expose "internal APIs" which sit outside of Kafka logs. +Note in Monolith mode these are actually direct function calls and are not serialised HTTP requests. + +``` + Tier 1 Sync PublicRooms FederationAPI ClientAPI MediaAPI +Public Facing | .-----1------` | | | | | | | | | + 2 | .-------3-----------------` | | | `--------|-|-|-|--11--------------------. + | | | .--------4----------------------------------` | | | | + | | | | .---5-----------` | | | | | | + | | | | | .---6----------------------------` | | | + | | | | | | | .-----7----------` | | + | | | | | | 8 | | 10 | + | | | | | | | | `---9----. | | + V V V V V V V V V V V + Tier 2 Roomserver EDUServer FedSender AppService KeyServer ServerKeyAPI +Internal only | `------------------------12----------^ ^ + `------------------------------------------------------------13----------` + + Client ---> Server +``` +- 1 (PublicRooms -> Roomserver): Calculating current auth for changing visibility +- 2 (Sync -> Roomserver): When making backfill requests +- 3 (FedAPI -> Roomserver): Calculating (prev/auth events) and sending new events, processing backfill/state/state_ids requests +- 4 (ClientAPI -> Roomserver): Calculating (prev/auth events) and sending new events, processing /state requests +- 5 (FedAPI -> EDUServer): Sending typing/send-to-device events +- 6 (ClientAPI -> EDUServer): Sending typing/send-to-device events +- 7 (ClientAPI -> FedSender): Handling directory lookups +- 8 (FedAPI -> FedSender): Resetting backoffs when receiving traffic from a server. Querying joined hosts when handling alias lookup requests +- 9 (FedAPI -> AppService): Working out if the client is an appservice user +- 10 (ClientAPI -> AppService): Working out if the client is an appservice user +- 11 (FedAPI -> ServerKeyAPI): Verifying incoming event signatures +- 12 (FedSender -> ServerKeyAPI): Verifying event signatures of responses (e.g from send_join) +- 13 (Roomserver -> ServerKeyAPI): Verifying event signatures of backfilled events + +In addition to this, all public facing components (Tier 1) talk to the `UserAPI` to verify access tokens and extract profile information where needed. + +## Kafka logs + +``` + .----1--------------------------------------------. + V | + Tier 1 Sync PublicRooms FederationAPI ClientAPI MediaAPI +Public Facing ^ ^ ^ ^ + | | | | + 2 | | | + | `-3------------. | + | | | | + | | | | + | .------4------` | | + | | .--------5-----|------------------------------` + | | | | + Tier 2 Roomserver EDUServer FedSender AppService KeyServer ServerKeyAPI +Internal only | | ^ ^ + | `-----6----------` | + `--------------------7--------` + + +Producer ----> Consumer +``` +- 1 (ClientAPI -> Sync): For tracking account data +- 2 (Roomserver -> Sync): For all data to send to clients +- 3 (EDUServer -> Sync): For typing/send-to-device data to send to clients +- 4 (Roomserver -> PublicRooms): For tracking the current room name/topic/joined count/etc. +- 5 (Roomserver -> ClientAPI): For tracking memberships for profile updates. +- 6 (EDUServer -> FedSender): For sending EDUs over federation +- 7 (Roomserver -> FedSender): For sending PDUs over federation, for tracking joined hosts. diff --git a/eduserver/eduserver.go b/eduserver/eduserver.go index c14be3f67..2e6ba0c85 100644 --- a/eduserver/eduserver.go +++ b/eduserver/eduserver.go @@ -18,12 +18,12 @@ package eduserver import ( "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/eduserver/input" "github.com/matrix-org/dendrite/eduserver/inthttp" - "github.com/matrix-org/dendrite/internal/basecomponent" + "github.com/matrix-org/dendrite/internal/setup" + userapi "github.com/matrix-org/dendrite/userapi/api" ) // AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions @@ -35,13 +35,13 @@ func AddInternalRoutes(internalMux *mux.Router, inputAPI api.EDUServerInputAPI) // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - base *basecomponent.BaseDendrite, + base *setup.BaseDendrite, eduCache *cache.EDUCache, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, ) api.EDUServerInputAPI { return &input.EDUServerInputAPI{ Cache: eduCache, - DeviceDB: deviceDB, + UserAPI: userAPI, Producer: base.KafkaProducer, OutputTypingEventTopic: string(base.Cfg.Kafka.Topics.OutputTypingEvent), OutputSendToDeviceEventTopic: string(base.Cfg.Kafka.Topics.OutputSendToDeviceEvent), diff --git a/eduserver/input/input.go b/eduserver/input/input.go index 0bbf5b844..e3d2c55e3 100644 --- a/eduserver/input/input.go +++ b/eduserver/input/input.go @@ -22,9 +22,9 @@ import ( "time" "github.com/Shopify/sarama" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/cache" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) @@ -39,8 +39,8 @@ type EDUServerInputAPI struct { OutputSendToDeviceEventTopic string // kafka producer Producer sarama.SyncProducer - // device database - DeviceDB devices.Database + // Internal user query API + UserAPI userapi.UserInternalAPI // our server name ServerName gomatrixserverlib.ServerName } @@ -97,6 +97,11 @@ func (t *EDUServerInputAPI) sendTypingEvent(ite *api.InputTypingEvent) error { if err != nil { return err } + logrus.WithFields(logrus.Fields{ + "room_id": ite.RoomID, + "user_id": ite.UserID, + "typing": ite.Typing, + }).Infof("Producing to topic '%s'", t.OutputTypingEventTopic) m := &sarama.ProducerMessage{ Topic: string(t.OutputTypingEventTopic), @@ -110,7 +115,7 @@ func (t *EDUServerInputAPI) sendTypingEvent(ite *api.InputTypingEvent) error { func (t *EDUServerInputAPI) sendToDeviceEvent(ise *api.InputSendToDeviceEvent) error { devices := []string{} - localpart, domain, err := gomatrixserverlib.SplitID('@', ise.UserID) + _, domain, err := gomatrixserverlib.SplitID('@', ise.UserID) if err != nil { return err } @@ -121,17 +126,25 @@ func (t *EDUServerInputAPI) sendToDeviceEvent(ise *api.InputSendToDeviceEvent) e // wildcard as we don't know about the remote devices, so instead we leave it // as-is, so that the federation sender can send it on with the wildcard intact. if domain == t.ServerName && ise.DeviceID == "*" { - devs, err := t.DeviceDB.GetDevicesByLocalpart(context.TODO(), localpart) + var res userapi.QueryDevicesResponse + err = t.UserAPI.QueryDevices(context.TODO(), &userapi.QueryDevicesRequest{ + UserID: ise.UserID, + }, &res) if err != nil { return err } - for _, dev := range devs { + for _, dev := range res.Devices { devices = append(devices, dev.ID) } } else { devices = append(devices, ise.DeviceID) } + logrus.WithFields(logrus.Fields{ + "user_id": ise.UserID, + "num_devices": len(devices), + "type": ise.Type, + }).Infof("Producing to topic '%s'", t.OutputSendToDeviceEventTopic) for _, device := range devices { ote := &api.OutputSendToDeviceEvent{ UserID: ise.UserID, @@ -139,12 +152,6 @@ func (t *EDUServerInputAPI) sendToDeviceEvent(ise *api.InputSendToDeviceEvent) e SendToDeviceEvent: ise.SendToDeviceEvent, } - logrus.WithFields(logrus.Fields{ - "user_id": ise.UserID, - "device_id": ise.DeviceID, - "event_type": ise.Type, - }).Info("handling send-to-device message") - eventJSON, err := json.Marshal(ote) if err != nil { logrus.WithError(err).Error("sendToDevice failed json.Marshal") diff --git a/eduserver/inthttp/client.go b/eduserver/inthttp/client.go index 149e4fb04..7d0bc1603 100644 --- a/eduserver/inthttp/client.go +++ b/eduserver/inthttp/client.go @@ -6,7 +6,7 @@ import ( "net/http" "github.com/matrix-org/dendrite/eduserver/api" - internalHTTP "github.com/matrix-org/dendrite/internal/http" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/opentracing/opentracing-go" ) @@ -39,7 +39,7 @@ func (h *httpEDUServerInputAPI) InputTypingEvent( defer span.Finish() apiURL := h.eduServerURL + EDUServerInputTypingEventPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } // InputSendToDeviceEvent implements EDUServerInputAPI @@ -52,5 +52,5 @@ func (h *httpEDUServerInputAPI) InputSendToDeviceEvent( defer span.Finish() apiURL := h.eduServerURL + EDUServerInputSendToDeviceEventPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } diff --git a/eduserver/inthttp/server.go b/eduserver/inthttp/server.go index 6c7d21a4e..e374513a3 100644 --- a/eduserver/inthttp/server.go +++ b/eduserver/inthttp/server.go @@ -6,14 +6,14 @@ import ( "github.com/gorilla/mux" "github.com/matrix-org/dendrite/eduserver/api" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/util" ) // AddRoutes adds the EDUServerInputAPI handlers to the http.ServeMux. func AddRoutes(t api.EDUServerInputAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle(EDUServerInputTypingEventPath, - internal.MakeInternalAPI("inputTypingEvents", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("inputTypingEvents", func(req *http.Request) util.JSONResponse { var request api.InputTypingEventRequest var response api.InputTypingEventResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -26,7 +26,7 @@ func AddRoutes(t api.EDUServerInputAPI, internalAPIMux *mux.Router) { }), ) internalAPIMux.Handle(EDUServerInputSendToDeviceEventPath, - internal.MakeInternalAPI("inputSendToDeviceEvents", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("inputSendToDeviceEvents", func(req *http.Request) util.JSONResponse { var request api.InputSendToDeviceEventRequest var response api.InputSendToDeviceEventResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 9299b5016..c0c000434 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -16,13 +16,11 @@ package federationapi import ( "github.com/gorilla/mux" - appserviceAPI "github.com/matrix-org/dendrite/appservice/api" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/config" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/federationapi/routing" "github.com/matrix-org/gomatrixserverlib" @@ -32,19 +30,17 @@ import ( func AddPublicRoutes( router *mux.Router, cfg *config.Dendrite, - accountsDB accounts.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, federation *gomatrixserverlib.FederationClient, - keyRing *gomatrixserverlib.KeyRing, + keyRing gomatrixserverlib.JSONVerifier, rsAPI roomserverAPI.RoomserverInternalAPI, - asAPI appserviceAPI.AppServiceQueryAPI, federationSenderAPI federationSenderAPI.FederationSenderInternalAPI, eduAPI eduserverAPI.EDUServerInputAPI, ) { routing.Setup( - router, cfg, rsAPI, asAPI, - eduAPI, federationSenderAPI, *keyRing, - federation, accountsDB, deviceDB, + router, cfg, rsAPI, + eduAPI, federationSenderAPI, keyRing, + federation, userAPI, ) } diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go new file mode 100644 index 000000000..cc85c61bf --- /dev/null +++ b/federationapi/federationapi_test.go @@ -0,0 +1,102 @@ +package federationapi_test + +import ( + "context" + "crypto/ed25519" + "strings" + "testing" + + "github.com/matrix-org/dendrite/federationapi" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/setup" + "github.com/matrix-org/dendrite/internal/test" + "github.com/matrix-org/gomatrix" + "github.com/matrix-org/gomatrixserverlib" +) + +// Tests that event IDs with '/' in them (escaped as %2F) are correctly passed to the right handler and don't 404. +// Relevant for v3 rooms and a cause of flakey sytests as the IDs are randomly generated. +func TestRoomsV3URLEscapeDoNot404(t *testing.T) { + _, privKey, _ := ed25519.GenerateKey(nil) + cfg := &config.Dendrite{} + cfg.Matrix.KeyID = gomatrixserverlib.KeyID("ed25519:auto") + cfg.Matrix.ServerName = gomatrixserverlib.ServerName("localhost") + cfg.Matrix.PrivateKey = privKey + cfg.Kafka.UseNaffka = true + cfg.Database.Naffka = "file::memory:" + cfg.SetDefaults() + base := setup.NewBaseDendrite(cfg, "Test", false) + keyRing := &test.NopJSONVerifier{} + fsAPI := base.FederationSenderHTTPClient() + // TODO: This is pretty fragile, as if anything calls anything on these nils this test will break. + // Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing. + federationapi.AddPublicRoutes(base.PublicAPIMux, cfg, nil, nil, keyRing, nil, fsAPI, nil) + httputil.SetupHTTPAPI( + base.BaseMux, + base.PublicAPIMux, + base.InternalAPIMux, + cfg, + base.UseHTTPAPIs, + ) + baseURL, cancel := test.ListenAndServe(t, base.BaseMux, true) + defer cancel() + serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://")) + + fedCli := gomatrixserverlib.NewFederationClient(serverName, cfg.Matrix.KeyID, cfg.Matrix.PrivateKey) + + testCases := []struct { + roomVer gomatrixserverlib.RoomVersion + eventJSON string + }{ + { + eventJSON: `{"auth_events":[["$Nzfbrhc3oaYVKzGM:localhost",{"sha256":"BCBHOgB4qxLPQkBd6th8ydFSyqjth/LF99VNjYffOQ0"}],["$EZzkD2BH1Gtm5v1D:localhost",{"sha256":"3dLUnDBs8/iC5DMw/ydKtmAqVZtzqqtHpsjsQPk7GJA"}]],"content":{"body":"Test Message"},"depth":11,"event_id":"$mGiPO3oGjQfCkIUw:localhost","hashes":{"sha256":"h+t+4DwIBC9UNyJ3jzyAQAAl4H3yQHVuHrm2S1JZizU"},"origin":"localhost","origin_server_ts":0,"prev_events":[["$tFr64vpiSHdLU0Qr:localhost",{"sha256":"+R07ZrIs4c4tjPFE+tmcYIGUfeLGFI/4e0OITb9uEcM"}]],"room_id":"!roomid:localhost","sender":"@userid:localhost","signatures":{"localhost":{"ed25519:auto":"LYFr/rW9m5/7UKBQMF5qWnG82He4VGsRESUgDmvkn5DrJRyS4TLL/7zl0Lymn3pa3q2yaTO74LQX/CRotqG1BA"}},"type":"m.room.message"}`, + roomVer: gomatrixserverlib.RoomVersionV1, + }, + // single / (handlers which do not UseEncodedPath will fail this test) + // EventID: $0SFh2WJbjBs3OT+E0yl95giDKo/3Zp52HsHUUk4uPyg + { + eventJSON: `{"auth_events":["$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg","$BcEcbZnlFLB5rxSNSZNBn6fO3jU/TKAJ79wfKyCQLiU"],"content":{"body":"Test Message"},"depth":8,"hashes":{"sha256":"dfK0MBn1RZZqCVJqWsn/MGY7QJHjQcwqF0unOonLCTU"},"origin":"localhost","origin_server_ts":0,"prev_events":["$1SwcZ1XY/Y8yKLjP4DzAOHN5WFBcDAZxb5vFDnW2ubA"],"room_id":"!roomid:localhost","sender":"@userid:localhost","signatures":{"localhost":{"ed25519:auto":"INOjuWMg+GmFkUpmzhMB0bqLNs73mSvwldY1ftYIQ/B3lD9soD2OMG3AF+wgZW/I8xqzY4DOHfbnbUeYPf67BA"}},"type":"m.room.message"}`, + roomVer: gomatrixserverlib.RoomVersionV3, + }, + // multiple / + // EventID: $OzENBCuVv/fnRAYCeQudIon/84/V5pxtEjQMTgi3emk + { + eventJSON: `{"auth_events":["$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg","$BcEcbZnlFLB5rxSNSZNBn6fO3jU/TKAJ79wfKyCQLiU"],"content":{"body":"Test Message"},"depth":2,"hashes":{"sha256":"U5+WsiJAhiEM88J8HTjuUjPImVGVzDFD3v/WS+jb2f0"},"origin":"localhost","origin_server_ts":0,"prev_events":["$BcEcbZnlFLB5rxSNSZNBn6fO3jU/TKAJ79wfKyCQLiU"],"room_id":"!roomid:localhost","sender":"@userid:localhost","signatures":{"localhost":{"ed25519:auto":"tKS469e9+wdWPEKB/LbBJWQ8vfOOdKgTWER5IwbSAH1CxmLvkCziUsgVu85zfzDSLoUi5mU5FHLiMTC6P/qICw"}},"type":"m.room.message"}`, + roomVer: gomatrixserverlib.RoomVersionV3, + }, + // two slashes (handlers which clean paths before UseEncodedPath will fail this test) + // EventID: $EmwNBlHoSOVmCZ1cM//yv/OvxB6r4OFEIGSJea7+Amk + { + eventJSON: `{"auth_events":["$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg","$BcEcbZnlFLB5rxSNSZNBn6fO3jU/TKAJ79wfKyCQLiU"],"content":{"body":"Test Message"},"depth":3917,"hashes":{"sha256":"cNAWtlHIegrji0mMA6x1rhpYCccY8W1NsWZqSpJFhjs"},"origin":"localhost","origin_server_ts":0,"prev_events":["$4GDB0bVjkWwS3G4noUZCq5oLWzpBYpwzdMcf7gj24CI"],"room_id":"!roomid:localhost","sender":"@userid:localhost","signatures":{"localhost":{"ed25519:auto":"NKym6Kcy3u9mGUr21Hjfe3h7DfDilDhN5PqztT0QZ4NTZ+8Y7owseLolQVXp+TvNjecvzdDywsXXVvGiuQiWAQ"}},"type":"m.room.message"}`, + roomVer: gomatrixserverlib.RoomVersionV3, + }, + } + + for _, tc := range testCases { + ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(tc.eventJSON), false, tc.roomVer) + if err != nil { + t.Errorf("failed to parse event: %s", err) + } + he := ev.Headered(tc.roomVer) + invReq, err := gomatrixserverlib.NewInviteV2Request(&he, nil) + if err != nil { + t.Errorf("failed to create invite v2 request: %s", err) + continue + } + _, err = fedCli.SendInviteV2(context.Background(), serverName, invReq) + if err == nil { + t.Errorf("expected an error, got none") + continue + } + gerr, ok := err.(gomatrix.HTTPError) + if !ok { + t.Errorf("failed to cast response error as gomatrix.HTTPError") + continue + } + t.Logf("Error: %+v", gerr) + if gerr.Code == 404 { + t.Errorf("invite event resulted in a 404") + } + } +} diff --git a/federationapi/routing/backfill.go b/federationapi/routing/backfill.go index 10bc62630..f906c73c9 100644 --- a/federationapi/routing/backfill.go +++ b/federationapi/routing/backfill.go @@ -37,7 +37,7 @@ func Backfill( roomID string, cfg *config.Dendrite, ) util.JSONResponse { - var res api.QueryBackfillResponse + var res api.PerformBackfillResponse var eIDs []string var limit string var exists bool @@ -68,7 +68,7 @@ func Backfill( } // Populate the request. - req := api.QueryBackfillRequest{ + req := api.PerformBackfillRequest{ RoomID: roomID, // we don't know who the successors are for these events, which won't // be a problem because we don't use that information when servicing /backfill requests, @@ -87,8 +87,8 @@ func Backfill( } // Query the roomserver. - if err = rsAPI.QueryBackfill(httpReq.Context(), &req, &res); err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("query.QueryBackfill failed") + if err = rsAPI.PerformBackfill(httpReq.Context(), &req, &res); err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Error("query.PerformBackfill failed") return jsonerror.InternalServerError() } diff --git a/federationapi/routing/devices.go b/federationapi/routing/devices.go index caf5fe592..6369c708c 100644 --- a/federationapi/routing/devices.go +++ b/federationapi/routing/devices.go @@ -15,9 +15,8 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/clientapi/userutil" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -25,17 +24,9 @@ import ( // GetUserDevices for the given user id func GetUserDevices( req *http.Request, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, userID string, ) util.JSONResponse { - localpart, err := userutil.ParseUsernameParam(userID, nil) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("Invalid user ID"), - } - } - response := gomatrixserverlib.RespUserDevices{ UserID: userID, // TODO: we should return an incrementing stream ID each time the device @@ -43,13 +34,16 @@ func GetUserDevices( StreamID: 0, } - devs, err := deviceDB.GetDevicesByLocalpart(req.Context(), localpart) + var res userapi.QueryDevicesResponse + err := userAPI.QueryDevices(req.Context(), &userapi.QueryDevicesRequest{ + UserID: userID, + }, &res) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDevicesByLocalPart failed") + util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryDevices failed") return jsonerror.InternalServerError() } - for _, dev := range devs { + for _, dev := range res.Devices { device := gomatrixserverlib.RespUserDevice{ DeviceID: dev.ID, DisplayName: dev.DisplayName, diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 908a04fcb..7d02bc1d3 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -35,7 +35,7 @@ func Invite( eventID string, cfg *config.Dendrite, rsAPI api.RoomserverInternalAPI, - keys gomatrixserverlib.KeyRing, + keys gomatrixserverlib.JSONVerifier, ) util.JSONResponse { inviteReq := gomatrixserverlib.InviteV2Request{} if err := json.Unmarshal(request.Content(), &inviteReq); err != nil { diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 593aa169f..8dcd15333 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -21,8 +21,8 @@ import ( "time" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -95,8 +95,8 @@ func MakeJoin( queryRes := api.QueryLatestEventsAndStateResponse{ RoomVersion: verRes.RoomVersion, } - event, err := internal.BuildEvent(httpReq.Context(), &builder, cfg, time.Now(), rsAPI, &queryRes) - if err == internal.ErrRoomNoExists { + event, err := eventutil.BuildEvent(httpReq.Context(), &builder, cfg, time.Now(), rsAPI, &queryRes) + if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound("Room does not exist"), @@ -107,7 +107,7 @@ func MakeJoin( JSON: jsonerror.BadJSON(e.Error()), } } else if err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("internal.BuildEvent failed") + util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") return jsonerror.InternalServerError() } @@ -143,7 +143,7 @@ func SendJoin( request *gomatrixserverlib.FederationRequest, cfg *config.Dendrite, rsAPI api.RoomserverInternalAPI, - keys gomatrixserverlib.KeyRing, + keys gomatrixserverlib.JSONVerifier, roomID, eventID string, ) util.JSONResponse { verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index f998be45a..108fc50ae 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -17,8 +17,8 @@ import ( "time" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -69,8 +69,8 @@ func MakeLeave( } var queryRes api.QueryLatestEventsAndStateResponse - event, err := internal.BuildEvent(httpReq.Context(), &builder, cfg, time.Now(), rsAPI, &queryRes) - if err == internal.ErrRoomNoExists { + event, err := eventutil.BuildEvent(httpReq.Context(), &builder, cfg, time.Now(), rsAPI, &queryRes) + if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound("Room does not exist"), @@ -81,7 +81,7 @@ func MakeLeave( JSON: jsonerror.BadJSON(e.Error()), } } else if err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("internal.BuildEvent failed") + util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") return jsonerror.InternalServerError() } @@ -113,7 +113,7 @@ func SendLeave( request *gomatrixserverlib.FederationRequest, cfg *config.Dendrite, rsAPI api.RoomserverInternalAPI, - keys gomatrixserverlib.KeyRing, + keys gomatrixserverlib.JSONVerifier, roomID, eventID string, ) util.JSONResponse { verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} diff --git a/federationapi/routing/profile.go b/federationapi/routing/profile.go index 832849845..a6180ae6d 100644 --- a/federationapi/routing/profile.go +++ b/federationapi/routing/profile.go @@ -18,11 +18,10 @@ import ( "fmt" "net/http" - appserviceAPI "github.com/matrix-org/dendrite/appservice/api" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/eventutil" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -30,9 +29,8 @@ import ( // GetProfile implements GET /_matrix/federation/v1/query/profile func GetProfile( httpReq *http.Request, - accountDB accounts.Database, + userAPI userapi.UserInternalAPI, cfg *config.Dendrite, - asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { userID, field := httpReq.FormValue("user_id"), httpReq.FormValue("field") @@ -60,9 +58,12 @@ func GetProfile( } } - profile, err := appserviceAPI.RetrieveUserProfile(httpReq.Context(), userID, asAPI, accountDB) + var profileRes userapi.QueryProfileResponse + err = userAPI.QueryProfile(httpReq.Context(), &userapi.QueryProfileRequest{ + UserID: userID, + }, &profileRes) if err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("appserviceAPI.RetrieveUserProfile failed") + util.GetLogger(httpReq.Context()).WithError(err).Error("userAPI.QueryProfile failed") return jsonerror.InternalServerError() } @@ -72,21 +73,21 @@ func GetProfile( if field != "" { switch field { case "displayname": - res = internal.DisplayName{ - DisplayName: profile.DisplayName, + res = eventutil.DisplayName{ + DisplayName: profileRes.DisplayName, } case "avatar_url": - res = internal.AvatarURL{ - AvatarURL: profile.AvatarURL, + res = eventutil.AvatarURL{ + AvatarURL: profileRes.AvatarURL, } default: code = http.StatusBadRequest res = jsonerror.InvalidArgumentValue("The request body did not contain an allowed value of argument 'field'. Allowed values are either: 'avatar_url', 'displayname'.") } } else { - res = internal.ProfileResponse{ - AvatarURL: profile.AvatarURL, - DisplayName: profile.DisplayName, + res = eventutil.ProfileResponse{ + AvatarURL: profileRes.AvatarURL, + DisplayName: profileRes.DisplayName, } } diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 754dcdfb0..645f397de 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -18,14 +18,13 @@ import ( "net/http" "github.com/gorilla/mux" - appserviceAPI "github.com/matrix-org/dendrite/appservice/api" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" + "github.com/matrix-org/dendrite/clientapi/jsonerror" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/httputil" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -48,23 +47,21 @@ func Setup( publicAPIMux *mux.Router, cfg *config.Dendrite, rsAPI roomserverAPI.RoomserverInternalAPI, - asAPI appserviceAPI.AppServiceQueryAPI, eduAPI eduserverAPI.EDUServerInputAPI, fsAPI federationSenderAPI.FederationSenderInternalAPI, - keys gomatrixserverlib.KeyRing, + keys gomatrixserverlib.JSONVerifier, federation *gomatrixserverlib.FederationClient, - accountDB accounts.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, ) { v2keysmux := publicAPIMux.PathPrefix(pathPrefixV2Keys).Subrouter() v1fedmux := publicAPIMux.PathPrefix(pathPrefixV1Federation).Subrouter() v2fedmux := publicAPIMux.PathPrefix(pathPrefixV2Federation).Subrouter() - wakeup := &internal.FederationWakeups{ + wakeup := &httputil.FederationWakeups{ FsAPI: fsAPI, } - localKeys := internal.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse { + localKeys := httputil.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse { return LocalKeys(cfg) }) @@ -76,7 +73,7 @@ func Setup( v2keysmux.Handle("/server/", localKeys).Methods(http.MethodGet) v2keysmux.Handle("/server", localKeys).Methods(http.MethodGet) - v1fedmux.Handle("/send/{txnID}", internal.MakeFedAPI( + v1fedmux.Handle("/send/{txnID}", httputil.MakeFedAPI( "federation_send", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return Send( @@ -86,7 +83,7 @@ func Setup( }, )).Methods(http.MethodPut, http.MethodOptions) - v2fedmux.Handle("/invite/{roomID}/{eventID}", internal.MakeFedAPI( + v2fedmux.Handle("/invite/{roomID}/{eventID}", httputil.MakeFedAPI( "federation_invite", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return Invite( @@ -96,13 +93,13 @@ func Setup( }, )).Methods(http.MethodPut, http.MethodOptions) - v1fedmux.Handle("/3pid/onbind", internal.MakeExternalAPI("3pid_onbind", + v1fedmux.Handle("/3pid/onbind", httputil.MakeExternalAPI("3pid_onbind", func(req *http.Request) util.JSONResponse { - return CreateInvitesFrom3PIDInvites(req, rsAPI, asAPI, cfg, federation, accountDB) + return CreateInvitesFrom3PIDInvites(req, rsAPI, cfg, federation, userAPI) }, )).Methods(http.MethodPost, http.MethodOptions) - v1fedmux.Handle("/exchange_third_party_invite/{roomID}", internal.MakeFedAPI( + v1fedmux.Handle("/exchange_third_party_invite/{roomID}", httputil.MakeFedAPI( "exchange_third_party_invite", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return ExchangeThirdPartyInvite( @@ -111,7 +108,7 @@ func Setup( }, )).Methods(http.MethodPut, http.MethodOptions) - v1fedmux.Handle("/event/{eventID}", internal.MakeFedAPI( + v1fedmux.Handle("/event/{eventID}", httputil.MakeFedAPI( "federation_get_event", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetEvent( @@ -120,7 +117,7 @@ func Setup( }, )).Methods(http.MethodGet) - v1fedmux.Handle("/state/{roomID}", internal.MakeFedAPI( + v1fedmux.Handle("/state/{roomID}", httputil.MakeFedAPI( "federation_get_state", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetState( @@ -129,7 +126,7 @@ func Setup( }, )).Methods(http.MethodGet) - v1fedmux.Handle("/state_ids/{roomID}", internal.MakeFedAPI( + v1fedmux.Handle("/state_ids/{roomID}", httputil.MakeFedAPI( "federation_get_state_ids", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetStateIDs( @@ -138,7 +135,7 @@ func Setup( }, )).Methods(http.MethodGet) - v1fedmux.Handle("/event_auth/{roomID}/{eventID}", internal.MakeFedAPI( + v1fedmux.Handle("/event_auth/{roomID}/{eventID}", httputil.MakeFedAPI( "federation_get_event_auth", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetEventAuth( @@ -147,7 +144,7 @@ func Setup( }, )).Methods(http.MethodGet) - v1fedmux.Handle("/query/directory", internal.MakeFedAPI( + v1fedmux.Handle("/query/directory", httputil.MakeFedAPI( "federation_query_room_alias", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return RoomAliasToID( @@ -156,25 +153,25 @@ func Setup( }, )).Methods(http.MethodGet) - v1fedmux.Handle("/query/profile", internal.MakeFedAPI( + v1fedmux.Handle("/query/profile", httputil.MakeFedAPI( "federation_query_profile", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetProfile( - httpReq, accountDB, cfg, asAPI, + httpReq, userAPI, cfg, ) }, )).Methods(http.MethodGet) - v1fedmux.Handle("/user/devices/{userID}", internal.MakeFedAPI( + v1fedmux.Handle("/user/devices/{userID}", httputil.MakeFedAPI( "federation_user_devices", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetUserDevices( - httpReq, deviceDB, vars["userID"], + httpReq, userAPI, vars["userID"], ) }, )).Methods(http.MethodGet) - v1fedmux.Handle("/make_join/{roomID}/{eventID}", internal.MakeFedAPI( + v1fedmux.Handle("/make_join/{roomID}/{eventID}", httputil.MakeFedAPI( "federation_make_join", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { roomID := vars["roomID"] @@ -199,7 +196,7 @@ func Setup( }, )).Methods(http.MethodGet) - v1fedmux.Handle("/send_join/{roomID}/{eventID}", internal.MakeFedAPI( + v1fedmux.Handle("/send_join/{roomID}/{eventID}", httputil.MakeFedAPI( "federation_send_join", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { roomID := vars["roomID"] @@ -207,17 +204,25 @@ func Setup( res := SendJoin( httpReq, request, cfg, rsAPI, keys, roomID, eventID, ) + // not all responses get wrapped in [code, body] + var body interface{} + body = []interface{}{ + res.Code, res.JSON, + } + jerr, ok := res.JSON.(*jsonerror.MatrixError) + if ok { + body = jerr + } + return util.JSONResponse{ Headers: res.Headers, Code: res.Code, - JSON: []interface{}{ - res.Code, res.JSON, - }, + JSON: body, } }, )).Methods(http.MethodPut) - v2fedmux.Handle("/send_join/{roomID}/{eventID}", internal.MakeFedAPI( + v2fedmux.Handle("/send_join/{roomID}/{eventID}", httputil.MakeFedAPI( "federation_send_join", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { roomID := vars["roomID"] @@ -228,7 +233,7 @@ func Setup( }, )).Methods(http.MethodPut) - v1fedmux.Handle("/make_leave/{roomID}/{eventID}", internal.MakeFedAPI( + v1fedmux.Handle("/make_leave/{roomID}/{eventID}", httputil.MakeFedAPI( "federation_make_leave", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { roomID := vars["roomID"] @@ -239,7 +244,7 @@ func Setup( }, )).Methods(http.MethodGet) - v2fedmux.Handle("/send_leave/{roomID}/{eventID}", internal.MakeFedAPI( + v2fedmux.Handle("/send_leave/{roomID}/{eventID}", httputil.MakeFedAPI( "federation_send_leave", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { roomID := vars["roomID"] @@ -250,21 +255,21 @@ func Setup( }, )).Methods(http.MethodPut) - v1fedmux.Handle("/version", internal.MakeExternalAPI( + v1fedmux.Handle("/version", httputil.MakeExternalAPI( "federation_version", func(httpReq *http.Request) util.JSONResponse { return Version() }, )).Methods(http.MethodGet) - v1fedmux.Handle("/get_missing_events/{roomID}", internal.MakeFedAPI( + v1fedmux.Handle("/get_missing_events/{roomID}", httputil.MakeFedAPI( "federation_get_missing_events", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetMissingEvents(httpReq, request, rsAPI, vars["roomID"]) }, )).Methods(http.MethodPost) - v1fedmux.Handle("/backfill/{roomID}", internal.MakeFedAPI( + v1fedmux.Handle("/backfill/{roomID}", httputil.MakeFedAPI( "federation_backfill", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return Backfill(httpReq, request, rsAPI, vars["roomID"], cfg) diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index a057750f6..cf71b8ba0 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -37,7 +37,7 @@ func Send( cfg *config.Dendrite, rsAPI api.RoomserverInternalAPI, eduAPI eduserverAPI.EDUServerInputAPI, - keys gomatrixserverlib.KeyRing, + keys gomatrixserverlib.JSONVerifier, federation *gomatrixserverlib.FederationClient, ) util.JSONResponse { t := txnReq{ diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index b81f1c003..6e6606c86 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -10,6 +10,7 @@ import ( eduAPI "github.com/matrix-org/dendrite/eduserver/api" fsAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" ) @@ -53,15 +54,6 @@ func init() { } } -type testNopJSONVerifier struct { - // this verifier verifies nothing -} - -func (t *testNopJSONVerifier) VerifyJSONs(ctx context.Context, requests []gomatrixserverlib.VerifyJSONRequest) ([]gomatrixserverlib.VerifyJSONResult, error) { - result := make([]gomatrixserverlib.VerifyJSONResult, len(requests)) - return result, nil -} - type testEDUProducer struct { // this producer keeps track of calls to InputTypingEvent invocations []eduAPI.InputTypingEventRequest @@ -128,7 +120,6 @@ func (t *testRoomserverAPI) QueryLatestEventsAndState( response *api.QueryLatestEventsAndStateResponse, ) error { r := t.queryLatestEventsAndState(request) - response.QueryLatestEventsAndStateRequest = *request response.RoomExists = r.RoomExists response.RoomVersion = testRoomVersion response.LatestEvents = r.LatestEvents @@ -144,7 +135,6 @@ func (t *testRoomserverAPI) QueryStateAfterEvents( response *api.QueryStateAfterEventsResponse, ) error { response.RoomVersion = testRoomVersion - response.QueryStateAfterEventsRequest = *request res := t.queryStateAfterEvents(request) response.PrevEventsExist = res.PrevEventsExist response.RoomExists = res.RoomExists @@ -181,15 +171,6 @@ func (t *testRoomserverAPI) QueryMembershipsForRoom( return fmt.Errorf("not implemented") } -// Query a list of invite event senders for a user in a room. -func (t *testRoomserverAPI) QueryInvitesForUser( - ctx context.Context, - request *api.QueryInvitesForUserRequest, - response *api.QueryInvitesForUserResponse, -) error { - return fmt.Errorf("not implemented") -} - // Query whether a server is allowed to see an event func (t *testRoomserverAPI) QueryServerAllowedToSeeEvent( ctx context.Context, @@ -220,10 +201,10 @@ func (t *testRoomserverAPI) QueryStateAndAuthChain( } // Query a given amount (or less) of events prior to a given set of events. -func (t *testRoomserverAPI) QueryBackfill( +func (t *testRoomserverAPI) PerformBackfill( ctx context.Context, - request *api.QueryBackfillRequest, - response *api.QueryBackfillResponse, + request *api.PerformBackfillRequest, + response *api.PerformBackfillResponse, ) error { return fmt.Errorf("not implemented") } @@ -341,7 +322,7 @@ func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederat context: context.Background(), rsAPI: rsAPI, eduAPI: &testEDUProducer{}, - keys: &testNopJSONVerifier{}, + keys: &test.NopJSONVerifier{}, federation: fedClient, haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent), newEvents: make(map[string]bool), @@ -621,7 +602,6 @@ func TestTransactionFetchMissingStateByStateIDs(t *testing.T) { } } } - res.QueryEventsByIDRequest = *req return res }, } diff --git a/federationapi/routing/state.go b/federationapi/routing/state.go index f90c494c3..04a18904b 100644 --- a/federationapi/routing/state.go +++ b/federationapi/routing/state.go @@ -107,15 +107,13 @@ func getState( return nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: nil} } - authEventIDs := getIDsFromEventRef(event.AuthEvents()) - var response api.QueryStateAndAuthChainResponse err := rsAPI.QueryStateAndAuthChain( ctx, &api.QueryStateAndAuthChainRequest{ RoomID: roomID, PrevEventIDs: []string{eventID}, - AuthEventIDs: authEventIDs, + AuthEventIDs: event.AuthEventIDs(), }, &response, ) @@ -134,15 +132,6 @@ func getState( }, nil } -func getIDsFromEventRef(events []gomatrixserverlib.EventReference) []string { - IDs := make([]string, len(events)) - for i := range events { - IDs[i] = events[i].EventID - } - - return IDs -} - func getIDsFromEvent(events []gomatrixserverlib.Event) []string { IDs := make([]string, len(events)) for i := range events { diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index 8f3193870..61788010b 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -21,13 +21,11 @@ import ( "net/http" "time" - appserviceAPI "github.com/matrix-org/dendrite/appservice/api" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" - roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -57,10 +55,10 @@ var ( // CreateInvitesFrom3PIDInvites implements POST /_matrix/federation/v1/3pid/onbind func CreateInvitesFrom3PIDInvites( - req *http.Request, rsAPI roomserverAPI.RoomserverInternalAPI, - asAPI appserviceAPI.AppServiceQueryAPI, cfg *config.Dendrite, + req *http.Request, rsAPI api.RoomserverInternalAPI, + cfg *config.Dendrite, federation *gomatrixserverlib.FederationClient, - accountDB accounts.Database, + userAPI userapi.UserInternalAPI, ) util.JSONResponse { var body invites if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { @@ -79,7 +77,7 @@ func CreateInvitesFrom3PIDInvites( } event, err := createInviteFrom3PIDInvite( - req.Context(), rsAPI, asAPI, cfg, inv, federation, accountDB, + req.Context(), rsAPI, cfg, inv, federation, userAPI, ) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("createInviteFrom3PIDInvite failed") @@ -107,7 +105,7 @@ func ExchangeThirdPartyInvite( httpReq *http.Request, request *gomatrixserverlib.FederationRequest, roomID string, - rsAPI roomserverAPI.RoomserverInternalAPI, + rsAPI api.RoomserverInternalAPI, cfg *config.Dendrite, federation *gomatrixserverlib.FederationClient, ) util.JSONResponse { @@ -197,10 +195,10 @@ func ExchangeThirdPartyInvite( // Returns an error if there was a problem building the event or fetching the // necessary data to do so. func createInviteFrom3PIDInvite( - ctx context.Context, rsAPI roomserverAPI.RoomserverInternalAPI, - asAPI appserviceAPI.AppServiceQueryAPI, cfg *config.Dendrite, + ctx context.Context, rsAPI api.RoomserverInternalAPI, + cfg *config.Dendrite, inv invite, federation *gomatrixserverlib.FederationClient, - accountDB accounts.Database, + userAPI userapi.UserInternalAPI, ) (*gomatrixserverlib.Event, error) { verReq := api.QueryRoomVersionForRoomRequest{RoomID: inv.RoomID} verRes := api.QueryRoomVersionForRoomResponse{} @@ -225,14 +223,17 @@ func createInviteFrom3PIDInvite( StateKey: &inv.MXID, } - profile, err := appserviceAPI.RetrieveUserProfile(ctx, inv.MXID, asAPI, accountDB) + var res userapi.QueryProfileResponse + err = userAPI.QueryProfile(ctx, &userapi.QueryProfileRequest{ + UserID: inv.MXID, + }, &res) if err != nil { return nil, err } content := gomatrixserverlib.MemberContent{ - AvatarURL: profile.AvatarURL, - DisplayName: profile.DisplayName, + AvatarURL: res.AvatarURL, + DisplayName: res.DisplayName, Membership: gomatrixserverlib.Invite, ThirdPartyInvite: &gomatrixserverlib.MemberThirdPartyInvite{ Signed: inv.Signed, @@ -261,7 +262,7 @@ func createInviteFrom3PIDInvite( // Returns an error if something failed during the process. func buildMembershipEvent( ctx context.Context, - builder *gomatrixserverlib.EventBuilder, rsAPI roomserverAPI.RoomserverInternalAPI, + builder *gomatrixserverlib.EventBuilder, rsAPI api.RoomserverInternalAPI, cfg *config.Dendrite, ) (*gomatrixserverlib.Event, error) { eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) @@ -274,11 +275,11 @@ func buildMembershipEvent( } // Ask the roomserver for information about this room - queryReq := roomserverAPI.QueryLatestEventsAndStateRequest{ + queryReq := api.QueryLatestEventsAndStateRequest{ RoomID: builder.RoomID, StateToFetch: eventsNeeded.Tuples(), } - var queryRes roomserverAPI.QueryLatestEventsAndStateResponse + var queryRes api.QueryLatestEventsAndStateResponse if err = rsAPI.QueryLatestEventsAndState(ctx, &queryReq, &queryRes); err != nil { return nil, err } diff --git a/federationsender/api/api.go b/federationsender/api/api.go index d8251b54d..02c762582 100644 --- a/federationsender/api/api.go +++ b/federationsender/api/api.go @@ -15,14 +15,6 @@ type FederationSenderInternalAPI interface { request *PerformDirectoryLookupRequest, response *PerformDirectoryLookupResponse, ) error - // Query the joined hosts and the membership events accounting for their participation in a room. - // Note that if a server has multiple users in the room, it will have multiple entries in the returned slice. - // See `QueryJoinedHostServerNamesInRoom` for a de-duplicated version. - QueryJoinedHostsInRoom( - ctx context.Context, - request *QueryJoinedHostsInRoomRequest, - response *QueryJoinedHostsInRoomResponse, - ) error // Query the server names of the joined hosts in a room. // Unlike QueryJoinedHostsInRoom, this function returns a de-duplicated slice // containing only the server names (without information for membership events). @@ -88,22 +80,12 @@ type PerformServersAliveRequest struct { type PerformServersAliveResponse struct { } -// QueryJoinedHostsInRoomRequest is a request to QueryJoinedHostsInRoom -type QueryJoinedHostsInRoomRequest struct { - RoomID string `json:"room_id"` -} - -// QueryJoinedHostsInRoomResponse is a response to QueryJoinedHostsInRoom -type QueryJoinedHostsInRoomResponse struct { - JoinedHosts []types.JoinedHost `json:"joined_hosts"` -} - -// QueryJoinedHostServerNamesRequest is a request to QueryJoinedHostServerNames +// QueryJoinedHostServerNamesInRoomRequest is a request to QueryJoinedHostServerNames type QueryJoinedHostServerNamesInRoomRequest struct { RoomID string `json:"room_id"` } -// QueryJoinedHostServerNamesResponse is a response to QueryJoinedHostServerNames +// QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames type QueryJoinedHostServerNamesInRoomResponse struct { ServerNames []gomatrixserverlib.ServerName `json:"server_names"` } diff --git a/federationsender/consumers/roomserver.go b/federationsender/consumers/roomserver.go index 5f8a555bd..299c7b37a 100644 --- a/federationsender/consumers/roomserver.go +++ b/federationsender/consumers/roomserver.go @@ -86,11 +86,6 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { switch output.Type { case api.OutputTypeNewRoomEvent: ev := &output.NewRoomEvent.Event - log.WithFields(log.Fields{ - "event_id": ev.EventID(), - "room_id": ev.RoomID(), - "send_as_server": output.NewRoomEvent.SendAsServer, - }).Info("received room event from roomserver") if err := s.processMessage(*output.NewRoomEvent); err != nil { // panic rather than continue with an inconsistent database @@ -131,11 +126,7 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { // processMessage updates the list of currently joined hosts in the room // and then sends the event to the hosts that were joined before the event. func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) error { - addsStateEvents, err := s.lookupStateEvents(ore.AddsStateEventIDs, ore.Event.Event) - if err != nil { - return err - } - addsJoinedHosts, err := joinedHostsFromEvents(addsStateEvents) + addsJoinedHosts, err := joinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(ore.AddsState())) if err != nil { return err } diff --git a/federationsender/federationsender.go b/federationsender/federationsender.go index 621920ef5..10ac51c8a 100644 --- a/federationsender/federationsender.go +++ b/federationsender/federationsender.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/federationsender/queue" "github.com/matrix-org/dendrite/federationsender/storage" "github.com/matrix-org/dendrite/federationsender/types" - "github.com/matrix-org/dendrite/internal/basecomponent" + "github.com/matrix-org/dendrite/internal/setup" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" @@ -38,7 +38,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.FederationSenderInternalAP // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - base *basecomponent.BaseDendrite, + base *setup.BaseDendrite, federation *gomatrixserverlib.FederationClient, rsAPI roomserverAPI.RoomserverInternalAPI, keyRing *gomatrixserverlib.KeyRing, diff --git a/federationsender/internal/query.go b/federationsender/internal/query.go index 88dd50a62..253400a2d 100644 --- a/federationsender/internal/query.go +++ b/federationsender/internal/query.go @@ -7,16 +7,6 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) -// QueryJoinedHostsInRoom implements api.FederationSenderInternalAPI -func (f *FederationSenderInternalAPI) QueryJoinedHostsInRoom( - ctx context.Context, - request *api.QueryJoinedHostsInRoomRequest, - response *api.QueryJoinedHostsInRoomResponse, -) (err error) { - response.JoinedHosts, err = f.db.GetJoinedHosts(ctx, request.RoomID) - return -} - // QueryJoinedHostServerNamesInRoom implements api.FederationSenderInternalAPI func (f *FederationSenderInternalAPI) QueryJoinedHostServerNamesInRoom( ctx context.Context, diff --git a/federationsender/inthttp/client.go b/federationsender/inthttp/client.go index 8b3a106d8..5da4b35f9 100644 --- a/federationsender/inthttp/client.go +++ b/federationsender/inthttp/client.go @@ -6,13 +6,12 @@ import ( "net/http" "github.com/matrix-org/dendrite/federationsender/api" - internalHTTP "github.com/matrix-org/dendrite/internal/http" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/opentracing/opentracing-go" ) // HTTP paths for the internal HTTP API const ( - FederationSenderQueryJoinedHostsInRoomPath = "/federationsender/queryJoinedHostsInRoom" FederationSenderQueryJoinedHostServerNamesInRoomPath = "/federationsender/queryJoinedHostServerNamesInRoom" FederationSenderPerformDirectoryLookupRequestPath = "/federationsender/performDirectoryLookup" @@ -45,7 +44,7 @@ func (h *httpFederationSenderInternalAPI) PerformLeave( defer span.Finish() apiURL := h.federationSenderURL + FederationSenderPerformLeaveRequestPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } func (h *httpFederationSenderInternalAPI) PerformServersAlive( @@ -57,7 +56,7 @@ func (h *httpFederationSenderInternalAPI) PerformServersAlive( defer span.Finish() apiURL := h.federationSenderURL + FederationSenderPerformServersAlivePath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } // QueryJoinedHostServerNamesInRoom implements FederationSenderInternalAPI @@ -70,20 +69,7 @@ func (h *httpFederationSenderInternalAPI) QueryJoinedHostServerNamesInRoom( defer span.Finish() apiURL := h.federationSenderURL + FederationSenderQueryJoinedHostServerNamesInRoomPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -// QueryJoinedHostsInRoom implements FederationSenderInternalAPI -func (h *httpFederationSenderInternalAPI) QueryJoinedHostsInRoom( - ctx context.Context, - request *api.QueryJoinedHostsInRoomRequest, - response *api.QueryJoinedHostsInRoomResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryJoinedHostsInRoom") - defer span.Finish() - - apiURL := h.federationSenderURL + FederationSenderQueryJoinedHostsInRoomPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } // Handle an instruction to make_join & send_join with a remote server. @@ -96,7 +82,7 @@ func (h *httpFederationSenderInternalAPI) PerformJoin( defer span.Finish() apiURL := h.federationSenderURL + FederationSenderPerformJoinRequestPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } // Handle an instruction to make_join & send_join with a remote server. @@ -109,5 +95,5 @@ func (h *httpFederationSenderInternalAPI) PerformDirectoryLookup( defer span.Finish() apiURL := h.federationSenderURL + FederationSenderPerformDirectoryLookupRequestPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } diff --git a/federationsender/inthttp/server.go b/federationsender/inthttp/server.go index c94a14295..babd3ae13 100644 --- a/federationsender/inthttp/server.go +++ b/federationsender/inthttp/server.go @@ -6,29 +6,15 @@ import ( "github.com/gorilla/mux" "github.com/matrix-org/dendrite/federationsender/api" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/util" ) // AddRoutes adds the FederationSenderInternalAPI handlers to the http.ServeMux. func AddRoutes(intAPI api.FederationSenderInternalAPI, internalAPIMux *mux.Router) { - internalAPIMux.Handle( - FederationSenderQueryJoinedHostsInRoomPath, - internal.MakeInternalAPI("QueryJoinedHostsInRoom", func(req *http.Request) util.JSONResponse { - var request api.QueryJoinedHostsInRoomRequest - var response api.QueryJoinedHostsInRoomResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := intAPI.QueryJoinedHostsInRoom(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) internalAPIMux.Handle( FederationSenderQueryJoinedHostServerNamesInRoomPath, - internal.MakeInternalAPI("QueryJoinedHostServerNamesInRoom", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("QueryJoinedHostServerNamesInRoom", func(req *http.Request) util.JSONResponse { var request api.QueryJoinedHostServerNamesInRoomRequest var response api.QueryJoinedHostServerNamesInRoomResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -41,7 +27,7 @@ func AddRoutes(intAPI api.FederationSenderInternalAPI, internalAPIMux *mux.Route }), ) internalAPIMux.Handle(FederationSenderPerformJoinRequestPath, - internal.MakeInternalAPI("PerformJoinRequest", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("PerformJoinRequest", func(req *http.Request) util.JSONResponse { var request api.PerformJoinRequest var response api.PerformJoinResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -54,7 +40,7 @@ func AddRoutes(intAPI api.FederationSenderInternalAPI, internalAPIMux *mux.Route }), ) internalAPIMux.Handle(FederationSenderPerformLeaveRequestPath, - internal.MakeInternalAPI("PerformLeaveRequest", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("PerformLeaveRequest", func(req *http.Request) util.JSONResponse { var request api.PerformLeaveRequest var response api.PerformLeaveResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -67,7 +53,7 @@ func AddRoutes(intAPI api.FederationSenderInternalAPI, internalAPIMux *mux.Route }), ) internalAPIMux.Handle(FederationSenderPerformDirectoryLookupRequestPath, - internal.MakeInternalAPI("PerformDirectoryLookupRequest", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("PerformDirectoryLookupRequest", func(req *http.Request) util.JSONResponse { var request api.PerformDirectoryLookupRequest var response api.PerformDirectoryLookupResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -80,7 +66,7 @@ func AddRoutes(intAPI api.FederationSenderInternalAPI, internalAPIMux *mux.Route }), ) internalAPIMux.Handle(FederationSenderPerformServersAlivePath, - internal.MakeInternalAPI("PerformServersAliveRequest", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("PerformServersAliveRequest", func(req *http.Request) util.JSONResponse { var request api.PerformServersAliveRequest var response api.PerformServersAliveResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { diff --git a/federationsender/queue/queue.go b/federationsender/queue/queue.go index 5b8fc3c56..240343559 100644 --- a/federationsender/queue/queue.go +++ b/federationsender/queue/queue.go @@ -107,6 +107,9 @@ func (oqs *OutgoingQueues) SendEvent( // Remove our own server from the list of destinations. destinations = filterAndDedupeDests(oqs.origin, destinations) + if len(destinations) == 0 { + return nil + } log.WithFields(log.Fields{ "destinations": destinations, "event": ev.EventID(), diff --git a/federationsender/storage/postgres/joined_hosts_table.go b/federationsender/storage/postgres/joined_hosts_table.go index 1b898f485..c0f9a7d5b 100644 --- a/federationsender/storage/postgres/joined_hosts_table.go +++ b/federationsender/storage/postgres/joined_hosts_table.go @@ -22,6 +22,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/federationsender/types" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" ) @@ -85,7 +86,7 @@ func (s *joinedHostsStatements) insertJoinedHosts( roomID, eventID string, serverName gomatrixserverlib.ServerName, ) error { - stmt := internal.TxStmt(txn, s.insertJoinedHostsStmt) + stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) _, err := stmt.ExecContext(ctx, roomID, eventID, serverName) return err } @@ -93,7 +94,7 @@ func (s *joinedHostsStatements) insertJoinedHosts( func (s *joinedHostsStatements) deleteJoinedHosts( ctx context.Context, txn *sql.Tx, eventIDs []string, ) error { - stmt := internal.TxStmt(txn, s.deleteJoinedHostsStmt) + stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt) _, err := stmt.ExecContext(ctx, pq.StringArray(eventIDs)) return err } @@ -101,7 +102,7 @@ func (s *joinedHostsStatements) deleteJoinedHosts( func (s *joinedHostsStatements) selectJoinedHostsWithTx( ctx context.Context, txn *sql.Tx, roomID string, ) ([]types.JoinedHost, error) { - stmt := internal.TxStmt(txn, s.selectJoinedHostsStmt) + stmt := sqlutil.TxStmt(txn, s.selectJoinedHostsStmt) return joinedHostsFromStmt(ctx, stmt, roomID) } diff --git a/federationsender/storage/postgres/room_table.go b/federationsender/storage/postgres/room_table.go index c945475ba..e5266c635 100644 --- a/federationsender/storage/postgres/room_table.go +++ b/federationsender/storage/postgres/room_table.go @@ -19,7 +19,7 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" ) const roomSchema = ` @@ -71,7 +71,7 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { func (s *roomStatements) insertRoom( ctx context.Context, txn *sql.Tx, roomID string, ) error { - _, err := internal.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID) + _, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID) return err } @@ -82,7 +82,7 @@ func (s *roomStatements) selectRoomForUpdate( ctx context.Context, txn *sql.Tx, roomID string, ) (string, error) { var lastEventID string - stmt := internal.TxStmt(txn, s.selectRoomForUpdateStmt) + stmt := sqlutil.TxStmt(txn, s.selectRoomForUpdateStmt) err := stmt.QueryRowContext(ctx, roomID).Scan(&lastEventID) if err != nil { return "", err @@ -95,7 +95,7 @@ func (s *roomStatements) selectRoomForUpdate( func (s *roomStatements) updateRoom( ctx context.Context, txn *sql.Tx, roomID, lastEventID string, ) error { - stmt := internal.TxStmt(txn, s.updateRoomStmt) + stmt := sqlutil.TxStmt(txn, s.updateRoomStmt) _, err := stmt.ExecContext(ctx, roomID, lastEventID) return err } diff --git a/federationsender/storage/postgres/storage.go b/federationsender/storage/postgres/storage.go index a4733a20c..8fd4c11a3 100644 --- a/federationsender/storage/postgres/storage.go +++ b/federationsender/storage/postgres/storage.go @@ -20,7 +20,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/federationsender/types" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" ) @@ -28,12 +27,12 @@ import ( type Database struct { joinedHostsStatements roomStatements - internal.PartitionOffsetStatements + sqlutil.PartitionOffsetStatements db *sql.DB } // NewDatabase opens a new database -func NewDatabase(dataSourceName string, dbProperties internal.DbProperties) (*Database, error) { +func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties) (*Database, error) { var result Database var err error if result.db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { @@ -70,7 +69,7 @@ func (d *Database) UpdateRoom( addHosts []types.JoinedHost, removeHosts []string, ) (joinedHosts []types.JoinedHost, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { err = d.insertRoom(ctx, txn, roomID) if err != nil { return err diff --git a/federationsender/storage/sqlite3/joined_hosts_table.go b/federationsender/storage/sqlite3/joined_hosts_table.go index 795d33208..d9824658c 100644 --- a/federationsender/storage/sqlite3/joined_hosts_table.go +++ b/federationsender/storage/sqlite3/joined_hosts_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/federationsender/types" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" ) @@ -84,7 +85,7 @@ func (s *joinedHostsStatements) insertJoinedHosts( roomID, eventID string, serverName gomatrixserverlib.ServerName, ) error { - stmt := internal.TxStmt(txn, s.insertJoinedHostsStmt) + stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) _, err := stmt.ExecContext(ctx, roomID, eventID, serverName) return err } @@ -93,7 +94,7 @@ func (s *joinedHostsStatements) deleteJoinedHosts( ctx context.Context, txn *sql.Tx, eventIDs []string, ) error { for _, eventID := range eventIDs { - stmt := internal.TxStmt(txn, s.deleteJoinedHostsStmt) + stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt) if _, err := stmt.ExecContext(ctx, eventID); err != nil { return err } @@ -104,7 +105,7 @@ func (s *joinedHostsStatements) deleteJoinedHosts( func (s *joinedHostsStatements) selectJoinedHostsWithTx( ctx context.Context, txn *sql.Tx, roomID string, ) ([]types.JoinedHost, error) { - stmt := internal.TxStmt(txn, s.selectJoinedHostsStmt) + stmt := sqlutil.TxStmt(txn, s.selectJoinedHostsStmt) return joinedHostsFromStmt(ctx, stmt, roomID) } diff --git a/federationsender/storage/sqlite3/room_table.go b/federationsender/storage/sqlite3/room_table.go index 4bf9ab81d..ca0c4d0b6 100644 --- a/federationsender/storage/sqlite3/room_table.go +++ b/federationsender/storage/sqlite3/room_table.go @@ -19,7 +19,7 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" ) const roomSchema = ` @@ -71,7 +71,7 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { func (s *roomStatements) insertRoom( ctx context.Context, txn *sql.Tx, roomID string, ) error { - _, err := internal.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID) + _, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID) return err } @@ -82,7 +82,7 @@ func (s *roomStatements) selectRoomForUpdate( ctx context.Context, txn *sql.Tx, roomID string, ) (string, error) { var lastEventID string - stmt := internal.TxStmt(txn, s.selectRoomForUpdateStmt) + stmt := sqlutil.TxStmt(txn, s.selectRoomForUpdateStmt) err := stmt.QueryRowContext(ctx, roomID).Scan(&lastEventID) if err != nil { return "", err @@ -95,7 +95,7 @@ func (s *roomStatements) selectRoomForUpdate( func (s *roomStatements) updateRoom( ctx context.Context, txn *sql.Tx, roomID, lastEventID string, ) error { - stmt := internal.TxStmt(txn, s.updateRoomStmt) + stmt := sqlutil.TxStmt(txn, s.updateRoomStmt) _, err := stmt.ExecContext(ctx, roomID, lastEventID) return err } diff --git a/federationsender/storage/sqlite3/storage.go b/federationsender/storage/sqlite3/storage.go index 8699fc19a..ac303f646 100644 --- a/federationsender/storage/sqlite3/storage.go +++ b/federationsender/storage/sqlite3/storage.go @@ -22,7 +22,6 @@ import ( _ "github.com/mattn/go-sqlite3" "github.com/matrix-org/dendrite/federationsender/types" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" ) @@ -30,7 +29,7 @@ import ( type Database struct { joinedHostsStatements roomStatements - internal.PartitionOffsetStatements + sqlutil.PartitionOffsetStatements db *sql.DB } @@ -42,7 +41,7 @@ func NewDatabase(dataSourceName string) (*Database, error) { if err != nil { return nil, err } - if result.db, err = sqlutil.Open(internal.SQLiteDriverName(), cs, nil); err != nil { + if result.db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil); err != nil { return nil, err } if err = result.prepare(); err != nil { @@ -76,7 +75,7 @@ func (d *Database) UpdateRoom( addHosts []types.JoinedHost, removeHosts []string, ) (joinedHosts []types.JoinedHost, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { err = d.insertRoom(ctx, txn, roomID) if err != nil { return err diff --git a/federationsender/storage/storage.go b/federationsender/storage/storage.go index df8433eba..d37360056 100644 --- a/federationsender/storage/storage.go +++ b/federationsender/storage/storage.go @@ -21,11 +21,11 @@ import ( "github.com/matrix-org/dendrite/federationsender/storage/postgres" "github.com/matrix-org/dendrite/federationsender/storage/sqlite3" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" ) // NewDatabase opens a new database -func NewDatabase(dataSourceName string, dbProperties internal.DbProperties) (Database, error) { +func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { return postgres.NewDatabase(dataSourceName, dbProperties) diff --git a/federationsender/storage/storage_wasm.go b/federationsender/storage/storage_wasm.go index f593fd448..e5c8f293b 100644 --- a/federationsender/storage/storage_wasm.go +++ b/federationsender/storage/storage_wasm.go @@ -19,13 +19,13 @@ import ( "net/url" "github.com/matrix-org/dendrite/federationsender/storage/sqlite3" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" ) // NewDatabase opens a new database func NewDatabase( dataSourceName string, - dbProperties internal.DbProperties, // nolint:unparam + dbProperties sqlutil.DbProperties, // nolint:unparam ) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { diff --git a/go.mod b/go.mod index a05bb71e2..ff73ad078 100644 --- a/go.mod +++ b/go.mod @@ -20,7 +20,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 - github.com/matrix-org/gomatrixserverlib v0.0.0-20200608125510-defe251235b1 + github.com/matrix-org/gomatrixserverlib v0.0.0-20200617141855-5539854e4abc github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 github.com/mattn/go-sqlite3 v2.0.2+incompatible diff --git a/go.sum b/go.sum index 00d25a3f6..f62b14b62 100644 --- a/go.sum +++ b/go.sum @@ -131,7 +131,6 @@ github.com/hashicorp/golang-lru v0.5.3/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uG github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= -github.com/hjson/hjson-go v3.0.1-0.20190209023717-9147687966d9+incompatible/go.mod h1:qsetwF8NlsTsOTwZTApNlTCerV+b2GjYRRcIk4JMFio= github.com/hjson/hjson-go v3.0.2-0.20200316202735-d5d0e8b0617d+incompatible/go.mod h1:qsetwF8NlsTsOTwZTApNlTCerV+b2GjYRRcIk4JMFio= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= @@ -372,8 +371,10 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 h1:Yb+Wlf github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bhrnp3Ky1qgx/fzCtCALOoGYylh2tpS9K4= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= -github.com/matrix-org/gomatrixserverlib v0.0.0-20200608125510-defe251235b1 h1:BfrvDrbjoPBvYua/3F/FmrqiZTRGrvtoMRgCVnrufMI= -github.com/matrix-org/gomatrixserverlib v0.0.0-20200608125510-defe251235b1/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200616150727-7ac22b6f8e65 h1:2CcCcBnWdDPDOqFKiGOM+mi/KDDZXSTKmvFy/0/+ZJI= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200616150727-7ac22b6f8e65/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200617141855-5539854e4abc h1:Oxidxr/H1Nh8aOFWAhTzQYLDOc9OuoJwfDjgEHpyqNE= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200617141855-5539854e4abc/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f h1:pRz4VTiRCO4zPlEMc3ESdUOcW4PXHH4Kj+YDz1XyE+Y= github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f/go.mod h1:y0oDTjZDv5SM9a2rp3bl+CU+bvTRINQsdb7YlDql5Go= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= @@ -566,8 +567,6 @@ github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c/go.mod h1:lB8K/P019DLNhe github.com/xdg/stringprep v1.0.0/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/yggdrasil-network/yggdrasil-extras v0.0.0-20200525205615-6c8a4a2e8855/go.mod h1:xQdsh08Io6nV4WRnOVTe6gI8/2iTvfLDQ0CYa5aMt+I= -github.com/yggdrasil-network/yggdrasil-go v0.3.14 h1:vWzYzCQxOruS+J5FkLfXOS0JhCJx1yI9Erj/h2wfZ/E= -github.com/yggdrasil-network/yggdrasil-go v0.3.14/go.mod h1:rkQzLzVHlFdzsEMG+bDdTI+KeWPCZq1HpXRFzwinf6M= github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200530233943-aec82d7a391b h1:ELOisSxFXCcptRs4LFub+Hz5fYUvV12wZrTps99Eb3E= github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200530233943-aec82d7a391b/go.mod h1:d+Nz6SPeG6kmeSPFL0cvfWfgwEql75fUnZiAONgvyBE= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= @@ -670,6 +669,7 @@ golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20181130052023-1c3d964395ce/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd h1:/e+gpKk9r3dJobndpTytxS2gOy6m5uvpg+ISQoEcusQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/caching/cache_serverkeys.go b/internal/caching/cache_serverkeys.go index 8c71ffbd9..4697fb4d2 100644 --- a/internal/caching/cache_serverkeys.go +++ b/internal/caching/cache_serverkeys.go @@ -15,17 +15,32 @@ const ( // ServerKeyCache contains the subset of functions needed for // a server key cache. type ServerKeyCache interface { - GetServerKey(request gomatrixserverlib.PublicKeyLookupRequest) (response gomatrixserverlib.PublicKeyLookupResult, ok bool) + // request -> timestamp is emulating gomatrixserverlib.FetchKeys: + // https://github.com/matrix-org/gomatrixserverlib/blob/f69539c86ea55d1e2cc76fd8e944e2d82d30397c/keyring.go#L95 + // The timestamp should be the timestamp of the event that is being + // verified. We will not return keys from the cache that are not valid + // at this timestamp. + GetServerKey(request gomatrixserverlib.PublicKeyLookupRequest, timestamp gomatrixserverlib.Timestamp) (response gomatrixserverlib.PublicKeyLookupResult, ok bool) + + // request -> result is emulating gomatrixserverlib.StoreKeys: + // https://github.com/matrix-org/gomatrixserverlib/blob/f69539c86ea55d1e2cc76fd8e944e2d82d30397c/keyring.go#L112 StoreServerKey(request gomatrixserverlib.PublicKeyLookupRequest, response gomatrixserverlib.PublicKeyLookupResult) } func (c Caches) GetServerKey( request gomatrixserverlib.PublicKeyLookupRequest, + timestamp gomatrixserverlib.Timestamp, ) (gomatrixserverlib.PublicKeyLookupResult, bool) { key := fmt.Sprintf("%s/%s", request.ServerName, request.KeyID) val, found := c.ServerKeys.Get(key) if found && val != nil { if keyLookupResult, ok := val.(gomatrixserverlib.PublicKeyLookupResult); ok { + if !keyLookupResult.WasValidAt(timestamp, true) { + // The key wasn't valid at the requested timestamp so don't + // return it. The caller will have to work out what to do. + c.ServerKeys.Unset(key) + return gomatrixserverlib.PublicKeyLookupResult{}, false + } return keyLookupResult, true } } diff --git a/internal/caching/caches.go b/internal/caching/caches.go index 70f380ba5..419623e27 100644 --- a/internal/caching/caches.go +++ b/internal/caching/caches.go @@ -12,4 +12,5 @@ type Caches struct { type Cache interface { Get(key string) (value interface{}, ok bool) Set(key string, value interface{}) + Unset(key string) } diff --git a/internal/caching/impl_inmemorylru.go b/internal/caching/impl_inmemorylru.go index f7901d2e5..7bb791dd8 100644 --- a/internal/caching/impl_inmemorylru.go +++ b/internal/caching/impl_inmemorylru.go @@ -8,11 +8,12 @@ import ( "github.com/prometheus/client_golang/prometheus/promauto" ) -func NewInMemoryLRUCache() (*Caches, error) { +func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { roomVersions, err := NewInMemoryLRUCachePartition( RoomVersionCacheName, RoomVersionCacheMutable, RoomVersionCacheMaxEntries, + enablePrometheus, ) if err != nil { return nil, err @@ -21,6 +22,7 @@ func NewInMemoryLRUCache() (*Caches, error) { ServerKeyCacheName, ServerKeyCacheMutable, ServerKeyCacheMaxEntries, + enablePrometheus, ) if err != nil { return nil, err @@ -38,7 +40,7 @@ type InMemoryLRUCachePartition struct { lru *lru.Cache } -func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int) (*InMemoryLRUCachePartition, error) { +func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, enablePrometheus bool) (*InMemoryLRUCachePartition, error) { var err error cache := InMemoryLRUCachePartition{ name: name, @@ -49,13 +51,15 @@ func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int) (*I if err != nil { return nil, err } - promauto.NewGaugeFunc(prometheus.GaugeOpts{ - Namespace: "dendrite", - Subsystem: "caching_in_memory_lru", - Name: name, - }, func() float64 { - return float64(cache.lru.Len()) - }) + if enablePrometheus { + promauto.NewGaugeFunc(prometheus.GaugeOpts{ + Namespace: "dendrite", + Subsystem: "caching_in_memory_lru", + Name: name, + }, func() float64 { + return float64(cache.lru.Len()) + }) + } return &cache, nil } @@ -68,6 +72,13 @@ func (c *InMemoryLRUCachePartition) Set(key string, value interface{}) { c.lru.Add(key, value) } +func (c *InMemoryLRUCachePartition) Unset(key string) { + if !c.mutable { + panic(fmt.Sprintf("invalid use of immutable cache tries to unset value of %q", key)) + } + c.lru.Remove(key) +} + func (c *InMemoryLRUCachePartition) Get(key string) (value interface{}, ok bool) { return c.lru.Get(key) } diff --git a/internal/config/config.go b/internal/config/config.go index bff4945be..2bd56ad9e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -241,6 +241,7 @@ type Dendrite struct { ServerKeyAPI Address `yaml:"server_key_api"` AppServiceAPI Address `yaml:"appservice_api"` SyncAPI Address `yaml:"sync_api"` + UserAPI Address `yaml:"user_api"` RoomServer Address `yaml:"room_server"` FederationSender Address `yaml:"federation_sender"` PublicRoomsAPI Address `yaml:"public_rooms_api"` @@ -610,6 +611,7 @@ func (config *Dendrite) checkListen(configErrs *configErrors) { checkNotEmpty(configErrs, "listen.room_server", string(config.Listen.RoomServer)) checkNotEmpty(configErrs, "listen.edu_server", string(config.Listen.EDUServer)) checkNotEmpty(configErrs, "listen.server_key_api", string(config.Listen.EDUServer)) + checkNotEmpty(configErrs, "listen.user_api", string(config.Listen.UserAPI)) } // checkLogging verifies the parameters logging.* are valid. @@ -723,6 +725,15 @@ func (config *Dendrite) RoomServerURL() string { return "http://" + string(config.Listen.RoomServer) } +// UserAPIURL returns an HTTP URL for where the userapi is listening. +func (config *Dendrite) UserAPIURL() string { + // Hard code the userapi to talk HTTP for now. + // If we support HTTPS we need to think of a practical way to do certificate validation. + // People setting up servers shouldn't need to get a certificate valid for the public + // internet for an internal API. + return "http://" + string(config.Listen.UserAPI) +} + // EDUServerURL returns an HTTP URL for where the EDU server is listening. func (config *Dendrite) EDUServerURL() string { // Hard code the EDU server to talk HTTP for now. diff --git a/internal/config/config_test.go b/internal/config/config_test.go index b72f5fad0..9a543e763 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -63,6 +63,7 @@ listen: media_api: "localhost:7774" appservice_api: "localhost:7777" edu_server: "localhost:7778" + user_api: "localhost:7779" logging: - type: "file" level: "info" diff --git a/internal/consumers.go b/internal/consumers.go index df68cbfaa..d7917f235 100644 --- a/internal/consumers.go +++ b/internal/consumers.go @@ -19,20 +19,13 @@ import ( "fmt" "github.com/Shopify/sarama" + "github.com/matrix-org/dendrite/internal/sqlutil" ) -// A PartitionOffset is the offset into a partition of the input log. -type PartitionOffset struct { - // The ID of the partition. - Partition int32 - // The offset into the partition. - Offset int64 -} - // A PartitionStorer has the storage APIs needed by the consumer. type PartitionStorer interface { // PartitionOffsets returns the offsets the consumer has reached for each partition. - PartitionOffsets(ctx context.Context, topic string) ([]PartitionOffset, error) + PartitionOffsets(ctx context.Context, topic string) ([]sqlutil.PartitionOffset, error) // SetPartitionOffset records where the consumer has reached for a partition. SetPartitionOffset(ctx context.Context, topic string, partition int32, offset int64) error } diff --git a/internal/eventcontent.go b/internal/eventutil/eventcontent.go similarity index 97% rename from internal/eventcontent.go rename to internal/eventutil/eventcontent.go index 645128369..873e20a8e 100644 --- a/internal/eventcontent.go +++ b/internal/eventutil/eventcontent.go @@ -1,4 +1,4 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package internal +package eventutil import "github.com/matrix-org/gomatrixserverlib" diff --git a/internal/events.go b/internal/eventutil/events.go similarity index 98% rename from internal/events.go rename to internal/eventutil/events.go index 89c82e03b..e6c7a4ff7 100644 --- a/internal/events.go +++ b/internal/eventutil/events.go @@ -1,4 +1,4 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package internal +package eventutil import ( "context" diff --git a/internal/types.go b/internal/eventutil/types.go similarity index 96% rename from internal/types.go rename to internal/eventutil/types.go index be2717f39..6d119ce6d 100644 --- a/internal/types.go +++ b/internal/eventutil/types.go @@ -1,4 +1,4 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package internal +package eventutil import ( "errors" diff --git a/internal/httpapis/paths.go b/internal/httpapis/paths.go deleted file mode 100644 index 8adec2dff..000000000 --- a/internal/httpapis/paths.go +++ /dev/null @@ -1,6 +0,0 @@ -package httpapis - -const ( - PublicPathPrefix = "/_matrix/" - InternalPathPrefix = "/api/" -) diff --git a/internal/http/http.go b/internal/httputil/http.go similarity index 69% rename from internal/http/http.go rename to internal/httputil/http.go index 2b1891403..9197371aa 100644 --- a/internal/http/http.go +++ b/internal/httputil/http.go @@ -1,4 +1,18 @@ -package http +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httputil import ( "bytes" @@ -9,7 +23,6 @@ import ( "net/url" "strings" - "github.com/matrix-org/dendrite/internal/httpapis" opentracing "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/ext" ) @@ -29,7 +42,7 @@ func PostJSON( return err } - parsedAPIURL.Path = httpapis.InternalPathPrefix + strings.TrimLeft(parsedAPIURL.Path, "/") + parsedAPIURL.Path = InternalPathPrefix + strings.TrimLeft(parsedAPIURL.Path, "/") apiURL = parsedAPIURL.String() req, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewReader(jsonBytes)) diff --git a/internal/httpapi.go b/internal/httputil/httpapi.go similarity index 88% rename from internal/httpapi.go rename to internal/httputil/httpapi.go index 991a98619..d371d1728 100644 --- a/internal/httpapi.go +++ b/internal/httputil/httpapi.go @@ -1,4 +1,18 @@ -package internal +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httputil import ( "context" @@ -13,10 +27,9 @@ import ( "github.com/gorilla/mux" "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" federationsenderAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/config" - "github.com/matrix-org/dendrite/internal/httpapis" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" opentracing "github.com/opentracing/opentracing-go" @@ -35,11 +48,11 @@ type BasicAuth struct { // MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request. func MakeAuthAPI( - metricsName string, data auth.Data, - f func(*http.Request, *authtypes.Device) util.JSONResponse, + metricsName string, userAPI userapi.UserInternalAPI, + f func(*http.Request, *userapi.Device) util.JSONResponse, ) http.Handler { h := func(req *http.Request) util.JSONResponse { - device, err := auth.VerifyUserFromRequest(req, data) + device, err := auth.VerifyUserFromRequest(req, userAPI) if err != nil { return *err } @@ -172,7 +185,7 @@ func MakeInternalAPI(metricsName string, f func(*http.Request) util.JSONResponse func MakeFedAPI( metricsName string, serverName gomatrixserverlib.ServerName, - keyRing gomatrixserverlib.KeyRing, + keyRing gomatrixserverlib.JSONVerifier, wakeup *FederationWakeups, f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse, ) http.Handler { @@ -220,16 +233,15 @@ func (f *FederationWakeups) Wakeup(ctx context.Context, origin gomatrixserverlib } } -// SetupHTTPAPI registers an HTTP API mux under /api and sets up a metrics -// listener. -func SetupHTTPAPI(servMux *http.ServeMux, publicApiMux *mux.Router, internalApiMux *mux.Router, cfg *config.Dendrite, enableHTTPAPIs bool) { +// SetupHTTPAPI registers an HTTP API mux under /api and sets up a metrics listener +func SetupHTTPAPI(servMux, publicApiMux, internalApiMux *mux.Router, cfg *config.Dendrite, enableHTTPAPIs bool) { if cfg.Metrics.Enabled { servMux.Handle("/metrics", WrapHandlerInBasicAuth(promhttp.Handler(), cfg.Metrics.BasicAuth)) } if enableHTTPAPIs { - servMux.Handle(httpapis.InternalPathPrefix, internalApiMux) + servMux.Handle(InternalPathPrefix, internalApiMux) } - servMux.Handle(httpapis.PublicPathPrefix, WrapHandlerInCORS(publicApiMux)) + servMux.Handle(PublicPathPrefix, WrapHandlerInCORS(publicApiMux)) } // WrapHandlerInBasicAuth adds basic auth to a handler. Only used for /metrics diff --git a/internal/httpapi_test.go b/internal/httputil/httpapi_test.go similarity index 75% rename from internal/httpapi_test.go rename to internal/httputil/httpapi_test.go index 6f159c8d6..de6ccf9b4 100644 --- a/internal/httpapi_test.go +++ b/internal/httputil/httpapi_test.go @@ -1,4 +1,18 @@ -package internal +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httputil import ( "net/http" diff --git a/internal/httputil/paths.go b/internal/httputil/paths.go new file mode 100644 index 000000000..728b5a871 --- /dev/null +++ b/internal/httputil/paths.go @@ -0,0 +1,20 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httputil + +const ( + PublicPathPrefix = "/_matrix/" + InternalPathPrefix = "/api/" +) diff --git a/internal/routing.go b/internal/httputil/routing.go similarity index 94% rename from internal/routing.go rename to internal/httputil/routing.go index 4462c70ca..0bd3655ec 100644 --- a/internal/routing.go +++ b/internal/httputil/routing.go @@ -1,4 +1,4 @@ -// Copyright 2019 The Matrix.org Foundation C.I.C. +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package internal +package httputil import ( "net/url" diff --git a/internal/basecomponent/base.go b/internal/setup/base.go similarity index 90% rename from internal/basecomponent/base.go rename to internal/setup/base.go index 3ad1e4af2..66424a609 100644 --- a/internal/basecomponent/base.go +++ b/internal/setup/base.go @@ -1,4 +1,4 @@ -// Copyright 2017 New Vector Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package basecomponent +package setup import ( "database/sql" @@ -23,14 +23,14 @@ import ( "time" "github.com/matrix-org/dendrite/internal/caching" - "github.com/matrix-org/dendrite/internal/httpapis" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/naffka" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/Shopify/sarama" "github.com/gorilla/mux" @@ -46,6 +46,8 @@ import ( rsinthttp "github.com/matrix-org/dendrite/roomserver/inthttp" serverKeyAPI "github.com/matrix-org/dendrite/serverkeyapi/api" skinthttp "github.com/matrix-org/dendrite/serverkeyapi/inthttp" + userapi "github.com/matrix-org/dendrite/userapi/api" + userapiinthttp "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/sirupsen/logrus" _ "net/http/pprof" @@ -63,6 +65,7 @@ type BaseDendrite struct { // PublicAPIMux should be used to register new public matrix api endpoints PublicAPIMux *mux.Router InternalAPIMux *mux.Router + BaseMux *mux.Router // base router which created public/internal subrouters UseHTTPAPIs bool httpClient *http.Client Cfg *config.Dendrite @@ -95,7 +98,7 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, useHTTPAPIs boo kafkaConsumer, kafkaProducer = setupKafka(cfg) } - cache, err := caching.NewInMemoryLRUCache() + cache, err := caching.NewInMemoryLRUCache(true) if err != nil { logrus.WithError(err).Warnf("Failed to create cache") } @@ -127,8 +130,9 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, useHTTPAPIs boo tracerCloser: closer, Cfg: cfg, Caches: cache, - PublicAPIMux: httpmux.PathPrefix(httpapis.PublicPathPrefix).Subrouter().UseEncodedPath(), - InternalAPIMux: httpmux.PathPrefix(httpapis.InternalPathPrefix).Subrouter().UseEncodedPath(), + BaseMux: httpmux, + PublicAPIMux: httpmux.PathPrefix(httputil.PublicPathPrefix).Subrouter().UseEncodedPath(), + InternalAPIMux: httpmux.PathPrefix(httputil.InternalPathPrefix).Subrouter().UseEncodedPath(), httpClient: &client, KafkaConsumer: kafkaConsumer, KafkaProducer: kafkaProducer, @@ -158,6 +162,15 @@ func (b *BaseDendrite) RoomserverHTTPClient() roomserverAPI.RoomserverInternalAP return rsAPI } +// UserAPIClient returns UserInternalAPI for hitting the userapi over HTTP. +func (b *BaseDendrite) UserAPIClient() userapi.UserInternalAPI { + userAPI, err := userapiinthttp.NewUserAPIClient(b.Cfg.UserAPIURL(), b.httpClient) + if err != nil { + logrus.WithError(err).Panic("UserAPIClient failed", b.httpClient) + } + return userAPI +} + // EDUServerClient returns EDUServerInputAPI for hitting the EDU server over HTTP func (b *BaseDendrite) EDUServerClient() eduServerAPI.EDUServerInputAPI { e, err := eduinthttp.NewEDUServerClient(b.Cfg.EDUServerURL(), b.httpClient) @@ -237,13 +250,14 @@ func (b *BaseDendrite) SetupAndServeHTTP(bindaddr string, listenaddr string) { WriteTimeout: HTTPServerTimeout, } - internal.SetupHTTPAPI( - http.DefaultServeMux, + httputil.SetupHTTPAPI( + b.BaseMux, b.PublicAPIMux, b.InternalAPIMux, b.Cfg, b.UseHTTPAPIs, ) + serv.Handler = b.BaseMux logrus.Infof("Starting %s server on %s", b.componentName, serv.Addr) err := serv.ListenAndServe() @@ -282,7 +296,7 @@ func setupNaffka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) { if err != nil { logrus.WithError(err).Panic("Failed to parse naffka database file URI") } - db, err = sqlutil.Open(internal.SQLiteDriverName(), cs, nil) + db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil) if err != nil { logrus.WithError(err).Panic("Failed to open naffka database") } diff --git a/internal/basecomponent/flags.go b/internal/setup/flags.go similarity index 94% rename from internal/basecomponent/flags.go rename to internal/setup/flags.go index 117df0796..e4fc58d60 100644 --- a/internal/basecomponent/flags.go +++ b/internal/setup/flags.go @@ -1,4 +1,4 @@ -// Copyright 2017 New Vector Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package basecomponent +package setup import ( "flag" diff --git a/internal/setup/monolith.go b/internal/setup/monolith.go index 35fcd3119..24bee9502 100644 --- a/internal/setup/monolith.go +++ b/internal/setup/monolith.go @@ -1,3 +1,17 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package setup import ( @@ -5,8 +19,6 @@ import ( "github.com/gorilla/mux" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" eduServerAPI "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/federationapi" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" @@ -20,6 +32,9 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" serverKeyAPI "github.com/matrix-org/dendrite/serverkeyapi/api" "github.com/matrix-org/dendrite/syncapi" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" ) @@ -30,6 +45,7 @@ type Monolith struct { DeviceDB devices.Database AccountDB accounts.Database KeyRing *gomatrixserverlib.KeyRing + Client *gomatrixserverlib.Client FedClient *gomatrixserverlib.FederationClient KafkaConsumer sarama.Consumer KafkaProducer sarama.SyncProducer @@ -39,6 +55,7 @@ type Monolith struct { FederationSenderAPI federationSenderAPI.FederationSenderInternalAPI RoomserverAPI roomserverAPI.RoomserverInternalAPI ServerKeyAPI serverKeyAPI.ServerKeyInternalAPI + UserAPI userapi.UserInternalAPI // TODO: can we remove this? It's weird that we are required the database // yet every other component can do that on its own. libp2p-demo uses a custom @@ -53,23 +70,23 @@ type Monolith struct { func (m *Monolith) AddAllPublicRoutes(publicMux *mux.Router) { clientapi.AddPublicRoutes( publicMux, m.Config, m.KafkaConsumer, m.KafkaProducer, m.DeviceDB, m.AccountDB, - m.FedClient, m.KeyRing, m.RoomserverAPI, + m.FedClient, m.RoomserverAPI, m.EDUInternalAPI, m.AppserviceAPI, transactions.New(), - m.FederationSenderAPI, + m.FederationSenderAPI, m.UserAPI, ) - keyserver.AddPublicRoutes(publicMux, m.Config, m.DeviceDB, m.AccountDB) + keyserver.AddPublicRoutes(publicMux, m.Config, m.UserAPI) federationapi.AddPublicRoutes( - publicMux, m.Config, m.AccountDB, m.DeviceDB, m.FedClient, - m.KeyRing, m.RoomserverAPI, m.AppserviceAPI, m.FederationSenderAPI, + publicMux, m.Config, m.UserAPI, m.FedClient, + m.KeyRing, m.RoomserverAPI, m.FederationSenderAPI, m.EDUInternalAPI, ) - mediaapi.AddPublicRoutes(publicMux, m.Config, m.DeviceDB) + mediaapi.AddPublicRoutes(publicMux, m.Config, m.UserAPI, m.Client) publicroomsapi.AddPublicRoutes( - publicMux, m.Config, m.KafkaConsumer, m.DeviceDB, m.PublicRoomsDB, m.RoomserverAPI, m.FedClient, + publicMux, m.Config, m.KafkaConsumer, m.UserAPI, m.PublicRoomsDB, m.RoomserverAPI, m.FedClient, m.ExtPublicRoomsProvider, ) syncapi.AddPublicRoutes( - publicMux, m.KafkaConsumer, m.DeviceDB, m.AccountDB, m.RoomserverAPI, m.FedClient, m.Config, + publicMux, m.KafkaConsumer, m.UserAPI, m.RoomserverAPI, m.FedClient, m.Config, ) } diff --git a/internal/partition_offset_table.go b/internal/sqlutil/partition_offset_table.go similarity index 88% rename from internal/partition_offset_table.go rename to internal/sqlutil/partition_offset_table.go index 8b72819b7..348829025 100644 --- a/internal/partition_offset_table.go +++ b/internal/sqlutil/partition_offset_table.go @@ -1,4 +1,4 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,14 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -package internal +package sqlutil import ( "context" "database/sql" "strings" + + "github.com/matrix-org/util" ) +// A PartitionOffset is the offset into a partition of the input log. +type PartitionOffset struct { + // The ID of the partition. + Partition int32 + // The offset into the partition. + Offset int64 +} + const partitionOffsetsSchema = ` -- The offsets that the server has processed up to. CREATE TABLE IF NOT EXISTS ${prefix}_partition_offsets ( @@ -90,7 +100,12 @@ func (s *PartitionOffsetStatements) selectPartitionOffsets( if err != nil { return nil, err } - defer CloseAndLogIfError(ctx, rows, "selectPartitionOffsets: rows.close() failed") + defer func() { + err2 := rows.Close() + if err2 != nil { + util.GetLogger(ctx).WithError(err2).Error("selectPartitionOffsets: rows.close() failed") + } + }() var results []PartitionOffset for rows.Next() { var offset PartitionOffset diff --git a/internal/postgres.go b/internal/sqlutil/postgres.go similarity index 98% rename from internal/postgres.go rename to internal/sqlutil/postgres.go index 7ae40d8fd..41a5508a1 100644 --- a/internal/postgres.go +++ b/internal/sqlutil/postgres.go @@ -14,7 +14,7 @@ // +build !wasm -package internal +package sqlutil import "github.com/lib/pq" diff --git a/internal/postgres_wasm.go b/internal/sqlutil/postgres_wasm.go similarity index 97% rename from internal/postgres_wasm.go rename to internal/sqlutil/postgres_wasm.go index 64d328291..c45842f0c 100644 --- a/internal/postgres_wasm.go +++ b/internal/sqlutil/postgres_wasm.go @@ -14,7 +14,7 @@ // +build wasm -package internal +package sqlutil // IsUniqueConstraintViolationErr no-ops for this architecture func IsUniqueConstraintViolationErr(err error) bool { diff --git a/internal/sql.go b/internal/sqlutil/sql.go similarity index 96% rename from internal/sql.go rename to internal/sqlutil/sql.go index e3c10afca..a25a4a5b6 100644 --- a/internal/sql.go +++ b/internal/sqlutil/sql.go @@ -1,6 +1,4 @@ -// Copyright 2017 Vector Creations Ltd -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package internal +package sqlutil import ( "database/sql" diff --git a/internal/sqlutil/trace.go b/internal/sqlutil/trace.go index 1b008e1b7..f6644d591 100644 --- a/internal/sqlutil/trace.go +++ b/internal/sqlutil/trace.go @@ -25,7 +25,6 @@ import ( "strings" "time" - "github.com/matrix-org/dendrite/internal" "github.com/ngrok/sqlmw" "github.com/sirupsen/logrus" ) @@ -78,7 +77,7 @@ func (in *traceInterceptor) RowsNext(c context.Context, rows driver.Rows, dest [ // Open opens a database specified by its database driver name and a driver-specific data source name, // usually consisting of at least a database name and connection information. Includes tracing driver // if DENDRITE_TRACE_SQL=1 -func Open(driverName, dsn string, dbProperties internal.DbProperties) (*sql.DB, error) { +func Open(driverName, dsn string, dbProperties DbProperties) (*sql.DB, error) { if tracingEnabled { // install the wrapped driver driverName += "-trace" @@ -87,7 +86,7 @@ func Open(driverName, dsn string, dbProperties internal.DbProperties) (*sql.DB, if err != nil { return nil, err } - if driverName != internal.SQLiteDriverName() && dbProperties != nil { + if driverName != SQLiteDriverName() && dbProperties != nil { logrus.WithFields(logrus.Fields{ "MaxOpenConns": dbProperties.MaxOpenConns(), "MaxIdleConns": dbProperties.MaxIdleConns(), diff --git a/internal/sqlutil/uri.go b/internal/sqlutil/uri.go index f72e02426..703258e6c 100644 --- a/internal/sqlutil/uri.go +++ b/internal/sqlutil/uri.go @@ -1,3 +1,17 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package sqlutil import ( diff --git a/clientapi/auth/authtypes/account.go b/internal/test/keyring.go similarity index 56% rename from clientapi/auth/authtypes/account.go rename to internal/test/keyring.go index fd3c15a84..ed9c34849 100644 --- a/clientapi/auth/authtypes/account.go +++ b/internal/test/keyring.go @@ -1,4 +1,4 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,20 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -package authtypes +package test import ( + "context" + "github.com/matrix-org/gomatrixserverlib" ) -// Account represents a Matrix account on this home server. -type Account struct { - UserID string - Localpart string - ServerName gomatrixserverlib.ServerName - Profile *Profile - AppServiceID string - // TODO: Other flags like IsAdmin, IsGuest - // TODO: Devices - // TODO: Associations (e.g. with application services) +// NopJSONVerifier is a JSONVerifier that verifies nothing and returns no errors. +type NopJSONVerifier struct { + // this verifier verifies nothing +} + +func (t *NopJSONVerifier) VerifyJSONs(ctx context.Context, requests []gomatrixserverlib.VerifyJSONRequest) ([]gomatrixserverlib.VerifyJSONResult, error) { + result := make([]gomatrixserverlib.VerifyJSONResult, len(requests)) + return result, nil } diff --git a/internal/test/server.go b/internal/test/server.go index 1493dac6f..c3348d533 100644 --- a/internal/test/server.go +++ b/internal/test/server.go @@ -1,4 +1,4 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,11 +15,16 @@ package test import ( + "context" "fmt" + "net" + "net/http" "os" "os/exec" "path/filepath" "strings" + "sync" + "testing" "github.com/matrix-org/dendrite/internal/config" ) @@ -103,3 +108,46 @@ func StartProxy(bindAddr string, cfg *config.Dendrite) (*exec.Cmd, chan error) { proxyArgs, ) } + +// ListenAndServe will listen on a random high-numbered port and attach the given router. +// Returns the base URL to send requests to. Call `cancel` to shutdown the server, which will block until it has closed. +func ListenAndServe(t *testing.T, router http.Handler, useTLS bool) (apiURL string, cancel func()) { + listener, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("failed to listen: %s", err) + } + port := listener.Addr().(*net.TCPAddr).Port + srv := http.Server{} + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + srv.Handler = router + var err error + if useTLS { + certFile := filepath.Join(os.TempDir(), "dendrite.cert") + keyFile := filepath.Join(os.TempDir(), "dendrite.key") + err = NewTLSKey(keyFile, certFile) + if err != nil { + t.Logf("failed to generate tls key/cert: %s", err) + return + } + err = srv.ServeTLS(listener, certFile, keyFile) + } else { + err = srv.Serve(listener) + } + if err != nil && err != http.ErrServerClosed { + t.Logf("Listen failed: %s", err) + } + }() + + secure := "" + if useTLS { + secure = "s" + } + return fmt.Sprintf("http%s://localhost:%d", secure, port), func() { + _ = srv.Shutdown(context.Background()) + wg.Wait() + } +} diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 1eb730541..bedc4dfb8 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -16,17 +16,14 @@ package keyserver import ( "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/keyserver/routing" + userapi "github.com/matrix-org/dendrite/userapi/api" ) // AddPublicRoutes registers HTTP handlers for CS API calls func AddPublicRoutes( - router *mux.Router, cfg *config.Dendrite, - deviceDB devices.Database, - accountsDB accounts.Database, + router *mux.Router, cfg *config.Dendrite, userAPI userapi.UserInternalAPI, ) { - routing.Setup(router, cfg, accountsDB, deviceDB) + routing.Setup(router, cfg, userAPI) } diff --git a/keyserver/routing/routing.go b/keyserver/routing/routing.go index 2d1122c62..dba43528f 100644 --- a/keyserver/routing/routing.go +++ b/keyserver/routing/routing.go @@ -18,12 +18,9 @@ import ( "net/http" "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/httputil" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" ) @@ -36,21 +33,22 @@ const pathPrefixR0 = "/client/r0" // applied: // nolint: gocyclo func Setup( - publicAPIMux *mux.Router, cfg *config.Dendrite, - accountDB accounts.Database, - deviceDB devices.Database, + publicAPIMux *mux.Router, cfg *config.Dendrite, userAPI userapi.UserInternalAPI, ) { r0mux := publicAPIMux.PathPrefix(pathPrefixR0).Subrouter() - authData := auth.Data{ - AccountDB: accountDB, - DeviceDB: deviceDB, - AppServices: cfg.Derived.ApplicationServices, - } - r0mux.Handle("/keys/query", - internal.MakeAuthAPI("queryKeys", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + httputil.MakeAuthAPI("queryKeys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return QueryKeys(req) }), ).Methods(http.MethodPost, http.MethodOptions) + + r0mux.Handle("/keys/upload/{keyID}", + httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return util.JSONResponse{ + Code: 200, + JSON: map[string]interface{}{}, + } + }), + ).Methods(http.MethodPost, http.MethodOptions) } diff --git a/mediaapi/mediaapi.go b/mediaapi/mediaapi.go index d4e260ea4..290ef46e1 100644 --- a/mediaapi/mediaapi.go +++ b/mediaapi/mediaapi.go @@ -16,10 +16,10 @@ package mediaapi import ( "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/mediaapi/routing" "github.com/matrix-org/dendrite/mediaapi/storage" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) @@ -27,7 +27,8 @@ import ( // AddPublicRoutes sets up and registers HTTP handlers for the MediaAPI component. func AddPublicRoutes( router *mux.Router, cfg *config.Dendrite, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, + client *gomatrixserverlib.Client, ) { mediaDB, err := storage.Open(string(cfg.Database.MediaAPI), cfg.DbProperties()) if err != nil { @@ -35,6 +36,6 @@ func AddPublicRoutes( } routing.Setup( - router, cfg, mediaDB, deviceDB, gomatrixserverlib.NewClient(), + router, cfg, mediaDB, userAPI, client, ) } diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go index 1a025f6f0..7e121de3e 100644 --- a/mediaapi/routing/download.go +++ b/mediaapi/routing/download.go @@ -21,12 +21,14 @@ import ( "io" "mime" "net/http" + "net/url" "os" "path/filepath" "regexp" "strconv" "strings" "sync" + "unicode" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/config" @@ -45,6 +47,10 @@ const mediaIDCharacters = "A-Za-z0-9_=-" // Note: unfortunately regex.MustCompile() cannot be assigned to a const var mediaIDRegex = regexp.MustCompile("^[" + mediaIDCharacters + "]+$") +// Regular expressions to help us cope with Content-Disposition parsing +var rfc2183 = regexp.MustCompile(`filename\=utf-8\"(.*)\"`) +var rfc6266 = regexp.MustCompile(`filename\*\=utf-8\'\'(.*)`) + // downloadRequest metadata included in or derivable from a download or thumbnail request // https://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-media-r0-download-servername-mediaid // http://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-media-r0-thumbnail-servername-mediaid @@ -53,6 +59,7 @@ type downloadRequest struct { IsThumbnailRequest bool ThumbnailSize types.ThumbnailSize Logger *log.Entry + DownloadFilename string } // Download implements GET /download and GET /thumbnail @@ -72,6 +79,7 @@ func Download( activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration, isThumbnailRequest bool, + customFilename string, ) { dReq := &downloadRequest{ MediaMetadata: &types.MediaMetadata{ @@ -83,6 +91,7 @@ func Download( "Origin": origin, "MediaID": mediaID, }), + DownloadFilename: customFilename, } if dReq.IsThumbnailRequest { @@ -300,9 +309,8 @@ func (r *downloadRequest) respondFromLocalFile( }).Info("Responding with file") responseFile = file responseMetadata = r.MediaMetadata - - if len(responseMetadata.UploadName) > 0 { - w.Header().Set("Content-Disposition", fmt.Sprintf(`inline; filename*=utf-8"%s"`, responseMetadata.UploadName)) + if err := r.addDownloadFilenameToHeaders(w, responseMetadata); err != nil { + return nil, err } } @@ -321,6 +329,67 @@ func (r *downloadRequest) respondFromLocalFile( return responseMetadata, nil } +func (r *downloadRequest) addDownloadFilenameToHeaders( + w http.ResponseWriter, + responseMetadata *types.MediaMetadata, +) error { + // If the requestor supplied a filename to name the download then + // use that, otherwise use the filename from the response metadata. + filename := string(responseMetadata.UploadName) + if r.DownloadFilename != "" { + filename = r.DownloadFilename + } + + if len(filename) == 0 { + return nil + } + + unescaped, err := url.PathUnescape(filename) + if err != nil { + return fmt.Errorf("url.PathUnescape: %w", err) + } + + isASCII := true // Is the string ASCII or UTF-8? + quote := `` // Encloses the string (ASCII only) + for i := 0; i < len(unescaped); i++ { + if unescaped[i] > unicode.MaxASCII { + isASCII = false + } + if unescaped[i] == 0x20 || unescaped[i] == 0x3B { + // If the filename contains a space or a semicolon, which + // are special characters in Content-Disposition + quote = `"` + } + } + + // We don't necessarily want a full escape as the Content-Disposition + // can take many of the characters that PathEscape would otherwise and + // browser support for encoding is a bit wild, so we'll escape only + // the characters that we know will mess up the parsing of the + // Content-Disposition header elements themselves + unescaped = strings.ReplaceAll(unescaped, `\`, `\\"`) + unescaped = strings.ReplaceAll(unescaped, `"`, `\"`) + + if isASCII { + // For ASCII filenames, we should only quote the filename if + // it needs to be done, e.g. it contains a space or a character + // that would otherwise be parsed as a control character in the + // Content-Disposition header + w.Header().Set("Content-Disposition", fmt.Sprintf( + `inline; filename=%s%s%s`, + quote, unescaped, quote, + )) + } else { + // For UTF-8 filenames, we quote always, as that's the standard + w.Header().Set("Content-Disposition", fmt.Sprintf( + `inline; filename*=utf-8''%s`, + url.QueryEscape(unescaped), + )) + } + + return nil +} + // Note: Thumbnail generation may be ongoing asynchronously. // If no thumbnail was found then returns nil, nil, nil func (r *downloadRequest) getThumbnailFile( @@ -635,9 +704,22 @@ func (r *downloadRequest) fetchRemoteFile( } r.MediaMetadata.FileSizeBytes = types.FileSizeBytes(contentLength) r.MediaMetadata.ContentType = types.ContentType(resp.Header.Get("Content-Type")) - _, params, err := mime.ParseMediaType(resp.Header.Get("Content-Disposition")) - if err == nil && params["filename"] != "" { - r.MediaMetadata.UploadName = types.Filename(params["filename"]) + + dispositionHeader := resp.Header.Get("Content-Disposition") + if _, params, e := mime.ParseMediaType(dispositionHeader); e == nil { + if params["filename"] != "" { + r.MediaMetadata.UploadName = types.Filename(params["filename"]) + } else if params["filename*"] != "" { + r.MediaMetadata.UploadName = types.Filename(params["filename*"]) + } + } else { + if matches := rfc6266.FindStringSubmatch(dispositionHeader); len(matches) > 1 { + // Always prefer the RFC6266 UTF-8 name if possible + r.MediaMetadata.UploadName = types.Filename(matches[1]) + } else if matches := rfc2183.FindStringSubmatch(dispositionHeader); len(matches) > 1 { + // Otherwise, see if an RFC2183 name was provided (ASCII only) + r.MediaMetadata.UploadName = types.Filename(matches[1]) + } } r.Logger.Info("Transferring remote file") diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go index 11e4ecad5..bc0de0f45 100644 --- a/mediaapi/routing/routing.go +++ b/mediaapi/routing/routing.go @@ -16,14 +16,13 @@ package routing import ( "net/http" + "strings" - "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -34,6 +33,7 @@ import ( ) const pathPrefixR0 = "/media/r0" +const pathPrefixV1 = "/media/v1" // TODO: remove when synapse is fixed // Setup registers the media API HTTP handlers // @@ -44,24 +44,18 @@ func Setup( publicAPIMux *mux.Router, cfg *config.Dendrite, db storage.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, client *gomatrixserverlib.Client, ) { r0mux := publicAPIMux.PathPrefix(pathPrefixR0).Subrouter() + v1mux := publicAPIMux.PathPrefix(pathPrefixV1).Subrouter() activeThumbnailGeneration := &types.ActiveThumbnailGeneration{ PathToResult: map[string]*types.ThumbnailGenerationResult{}, } - authData := auth.Data{ - AccountDB: nil, - DeviceDB: deviceDB, - AppServices: nil, - } - - // TODO: Add AS support - r0mux.Handle("/upload", internal.MakeAuthAPI( - "upload", authData, - func(req *http.Request, _ *authtypes.Device) util.JSONResponse { + r0mux.Handle("/upload", httputil.MakeAuthAPI( + "upload", userAPI, + func(req *http.Request, _ *userapi.Device) util.JSONResponse { return Upload(req, cfg, db, activeThumbnailGeneration) }, )).Methods(http.MethodPost, http.MethodOptions) @@ -69,9 +63,13 @@ func Setup( activeRemoteRequests := &types.ActiveRemoteRequests{ MXCToResult: map[string]*types.RemoteRequestResult{}, } - r0mux.Handle("/download/{serverName}/{mediaId}", - makeDownloadAPI("download", cfg, db, client, activeRemoteRequests, activeThumbnailGeneration), - ).Methods(http.MethodGet, http.MethodOptions) + + downloadHandler := makeDownloadAPI("download", cfg, db, client, activeRemoteRequests, activeThumbnailGeneration) + r0mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) + v1mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) // TODO: remove when synapse is fixed + v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) // TODO: remove when synapse is fixed + r0mux.Handle("/thumbnail/{serverName}/{mediaId}", makeDownloadAPI("thumbnail", cfg, db, client, activeRemoteRequests, activeThumbnailGeneration), ).Methods(http.MethodGet, http.MethodOptions) @@ -99,11 +97,24 @@ func makeDownloadAPI( util.SetCORSHeaders(w) // Content-Type will be overridden in case of returning file data, else we respond with JSON-formatted errors w.Header().Set("Content-Type", "application/json") - vars, _ := internal.URLDecodeMapValues(mux.Vars(req)) + + vars, _ := httputil.URLDecodeMapValues(mux.Vars(req)) + serverName := gomatrixserverlib.ServerName(vars["serverName"]) + + // For the purposes of loop avoidance, we will return a 404 if allow_remote is set to + // false in the query string and the target server name isn't our own. + // https://github.com/matrix-org/matrix-doc/pull/1265 + if allowRemote := req.URL.Query().Get("allow_remote"); strings.ToLower(allowRemote) == "false" { + if serverName != cfg.Matrix.ServerName { + w.WriteHeader(http.StatusNotFound) + return + } + } + Download( w, req, - gomatrixserverlib.ServerName(vars["serverName"]), + serverName, types.MediaID(vars["mediaId"]), cfg, db, @@ -111,6 +122,7 @@ func makeDownloadAPI( activeRemoteRequests, activeThumbnailGeneration, name == "thumbnail", + vars["downloadName"], ) } return promhttp.InstrumentHandlerCounter(counterVec, http.HandlerFunc(httpHandler)) diff --git a/mediaapi/routing/upload.go b/mediaapi/routing/upload.go index 022f978d6..9b5dc3df8 100644 --- a/mediaapi/routing/upload.go +++ b/mediaapi/routing/upload.go @@ -16,6 +16,7 @@ package routing import ( "context" + "encoding/base64" "fmt" "io" "net/http" @@ -123,7 +124,9 @@ func (r *uploadRequest) doUpload( r.MediaMetadata.FileSizeBytes = bytesWritten r.MediaMetadata.Base64Hash = hash - r.MediaMetadata.MediaID = types.MediaID(hash) + r.MediaMetadata.MediaID = types.MediaID(base64.RawURLEncoding.EncodeToString( + []byte(string(r.MediaMetadata.UploadName) + string(r.MediaMetadata.Base64Hash)), + )) r.Logger = r.Logger.WithField("MediaID", r.MediaMetadata.MediaID) diff --git a/mediaapi/storage/postgres/storage.go b/mediaapi/storage/postgres/storage.go index 72ee05cd8..e45e08416 100644 --- a/mediaapi/storage/postgres/storage.go +++ b/mediaapi/storage/postgres/storage.go @@ -21,7 +21,6 @@ import ( // Import the postgres database driver. _ "github.com/lib/pq" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -34,7 +33,7 @@ type Database struct { } // Open opens a postgres database. -func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, error) { +func Open(dataSourceName string, dbProperties sqlutil.DbProperties) (*Database, error) { var d Database var err error if d.db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { diff --git a/mediaapi/storage/postgres/thumbnail_table.go b/mediaapi/storage/postgres/thumbnail_table.go index dbc97cf22..3f28cdbbf 100644 --- a/mediaapi/storage/postgres/thumbnail_table.go +++ b/mediaapi/storage/postgres/thumbnail_table.go @@ -21,7 +21,6 @@ import ( "time" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/mediaapi/storage/sqlite3/storage.go b/mediaapi/storage/sqlite3/storage.go index ea81f9120..010c0a66e 100644 --- a/mediaapi/storage/sqlite3/storage.go +++ b/mediaapi/storage/sqlite3/storage.go @@ -20,7 +20,6 @@ import ( "database/sql" // Import the postgres database driver. - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -41,7 +40,7 @@ func Open(dataSourceName string) (*Database, error) { if err != nil { return nil, err } - if d.db, err = sqlutil.Open(internal.SQLiteDriverName(), cs, nil); err != nil { + if d.db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil); err != nil { return nil, err } if err = d.statements.prepare(d.db); err != nil { diff --git a/mediaapi/storage/sqlite3/thumbnail_table.go b/mediaapi/storage/sqlite3/thumbnail_table.go index a0956f68c..432a1590c 100644 --- a/mediaapi/storage/sqlite3/thumbnail_table.go +++ b/mediaapi/storage/sqlite3/thumbnail_table.go @@ -21,7 +21,6 @@ import ( "time" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/mediaapi/storage/storage.go b/mediaapi/storage/storage.go index 2694e197d..5ff114db6 100644 --- a/mediaapi/storage/storage.go +++ b/mediaapi/storage/storage.go @@ -19,13 +19,13 @@ package storage import ( "net/url" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage/postgres" "github.com/matrix-org/dendrite/mediaapi/storage/sqlite3" ) // Open opens a postgres database. -func Open(dataSourceName string, dbProperties internal.DbProperties) (Database, error) { +func Open(dataSourceName string, dbProperties sqlutil.DbProperties) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { return postgres.Open(dataSourceName, dbProperties) diff --git a/mediaapi/storage/storage_wasm.go b/mediaapi/storage/storage_wasm.go index 78de2cb82..a672271f9 100644 --- a/mediaapi/storage/storage_wasm.go +++ b/mediaapi/storage/storage_wasm.go @@ -18,14 +18,14 @@ import ( "fmt" "net/url" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage/sqlite3" ) // Open opens a postgres database. func Open( dataSourceName string, - dbProperties internal.DbProperties, // nolint:unparam + dbProperties sqlutil.DbProperties, // nolint:unparam ) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { diff --git a/publicroomsapi/consumers/roomserver.go b/publicroomsapi/consumers/roomserver.go index c513d3b24..b9686d56d 100644 --- a/publicroomsapi/consumers/roomserver.go +++ b/publicroomsapi/consumers/roomserver.go @@ -78,29 +78,17 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { return nil } - ev := output.NewRoomEvent.Event - log.WithFields(log.Fields{ - "event_id": ev.EventID(), - "room_id": ev.RoomID(), - "type": ev.Type(), - }).Info("received event from roomserver") - - addQueryReq := api.QueryEventsByIDRequest{EventIDs: output.NewRoomEvent.AddsStateEventIDs} - var addQueryRes api.QueryEventsByIDResponse - if err := s.rsAPI.QueryEventsByID(context.TODO(), &addQueryReq, &addQueryRes); err != nil { - log.Warn(err) - return err - } - - remQueryReq := api.QueryEventsByIDRequest{EventIDs: output.NewRoomEvent.RemovesStateEventIDs} var remQueryRes api.QueryEventsByIDResponse - if err := s.rsAPI.QueryEventsByID(context.TODO(), &remQueryReq, &remQueryRes); err != nil { - log.Warn(err) - return err + if len(output.NewRoomEvent.RemovesStateEventIDs) > 0 { + remQueryReq := api.QueryEventsByIDRequest{EventIDs: output.NewRoomEvent.RemovesStateEventIDs} + if err := s.rsAPI.QueryEventsByID(context.TODO(), &remQueryReq, &remQueryRes); err != nil { + log.Warn(err) + return err + } } var addQueryEvents, remQueryEvents []gomatrixserverlib.Event - for _, headeredEvent := range addQueryRes.Events { + for _, headeredEvent := range output.NewRoomEvent.AddsState() { addQueryEvents = append(addQueryEvents, headeredEvent.Event) } for _, headeredEvent := range remQueryRes.Events { diff --git a/publicroomsapi/directory/directory.go b/publicroomsapi/directory/directory.go index fe7a67932..8b68279aa 100644 --- a/publicroomsapi/directory/directory.go +++ b/publicroomsapi/directory/directory.go @@ -17,8 +17,8 @@ package directory import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" @@ -59,7 +59,7 @@ func GetVisibility( // SetVisibility implements PUT /directory/list/room/{roomID} // TODO: Allow admin users to edit the room visibility func SetVisibility( - req *http.Request, publicRoomsDatabase storage.Database, rsAPI api.RoomserverInternalAPI, dev *authtypes.Device, + req *http.Request, publicRoomsDatabase storage.Database, rsAPI api.RoomserverInternalAPI, dev *userapi.Device, roomID string, ) util.JSONResponse { queryMembershipReq := api.QueryMembershipForUserRequest{ diff --git a/publicroomsapi/publicroomsapi.go b/publicroomsapi/publicroomsapi.go index 1f98a4e05..b9baa1056 100644 --- a/publicroomsapi/publicroomsapi.go +++ b/publicroomsapi/publicroomsapi.go @@ -17,13 +17,13 @@ package publicroomsapi import ( "github.com/Shopify/sarama" "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/publicroomsapi/consumers" "github.com/matrix-org/dendrite/publicroomsapi/routing" "github.com/matrix-org/dendrite/publicroomsapi/storage" "github.com/matrix-org/dendrite/publicroomsapi/types" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) @@ -34,7 +34,7 @@ func AddPublicRoutes( router *mux.Router, cfg *config.Dendrite, consumer sarama.Consumer, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, publicRoomsDB storage.Database, rsAPI roomserverAPI.RoomserverInternalAPI, fedClient *gomatrixserverlib.FederationClient, @@ -47,5 +47,5 @@ func AddPublicRoutes( logrus.WithError(err).Panic("failed to start public rooms server consumer") } - routing.Setup(router, deviceDB, publicRoomsDB, rsAPI, fedClient, extRoomsProvider) + routing.Setup(router, userAPI, publicRoomsDB, rsAPI, fedClient, extRoomsProvider) } diff --git a/publicroomsapi/routing/routing.go b/publicroomsapi/routing/routing.go index ee1434392..9c82d3508 100644 --- a/publicroomsapi/routing/routing.go +++ b/publicroomsapi/routing/routing.go @@ -17,13 +17,11 @@ package routing import ( "net/http" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/publicroomsapi/directory" "github.com/matrix-org/dendrite/publicroomsapi/storage" "github.com/matrix-org/dendrite/publicroomsapi/types" @@ -39,20 +37,14 @@ const pathPrefixR0 = "/client/r0" // applied: // nolint: gocyclo func Setup( - publicAPIMux *mux.Router, deviceDB devices.Database, publicRoomsDB storage.Database, rsAPI api.RoomserverInternalAPI, + publicAPIMux *mux.Router, userAPI userapi.UserInternalAPI, publicRoomsDB storage.Database, rsAPI api.RoomserverInternalAPI, fedClient *gomatrixserverlib.FederationClient, extRoomsProvider types.ExternalPublicRoomsProvider, ) { r0mux := publicAPIMux.PathPrefix(pathPrefixR0).Subrouter() - authData := auth.Data{ - AccountDB: nil, - DeviceDB: deviceDB, - AppServices: nil, - } - r0mux.Handle("/directory/list/room/{roomID}", - internal.MakeExternalAPI("directory_list", func(req *http.Request) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeExternalAPI("directory_list", func(req *http.Request) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -61,8 +53,8 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) // TODO: Add AS support r0mux.Handle("/directory/list/room/{roomID}", - internal.MakeAuthAPI("directory_list", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + httputil.MakeAuthAPI("directory_list", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -70,7 +62,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/publicRooms", - internal.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse { + httputil.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse { if extRoomsProvider != nil { return directory.GetPostPublicRoomsWithExternal(req, publicRoomsDB, fedClient, extRoomsProvider) } @@ -80,7 +72,7 @@ func Setup( // Federation - TODO: should this live here or in federation API? It's sure easier if it's here so here it is. publicAPIMux.Handle("/federation/v1/publicRooms", - internal.MakeExternalAPI("federation_public_rooms", func(req *http.Request) util.JSONResponse { + httputil.MakeExternalAPI("federation_public_rooms", func(req *http.Request) util.JSONResponse { return directory.GetPostPublicRooms(req, publicRoomsDB) }), ).Methods(http.MethodGet) diff --git a/publicroomsapi/storage/postgres/storage.go b/publicroomsapi/storage/postgres/storage.go index 5f0f551a7..36c6aec64 100644 --- a/publicroomsapi/storage/postgres/storage.go +++ b/publicroomsapi/storage/postgres/storage.go @@ -20,7 +20,7 @@ import ( "database/sql" "encoding/json" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" @@ -29,7 +29,7 @@ import ( // PublicRoomsServerDatabase represents a public rooms server database. type PublicRoomsServerDatabase struct { db *sql.DB - internal.PartitionOffsetStatements + sqlutil.PartitionOffsetStatements statements publicRoomsStatements localServerName gomatrixserverlib.ServerName } @@ -37,7 +37,7 @@ type PublicRoomsServerDatabase struct { type attributeValue interface{} // NewPublicRoomsServerDatabase creates a new public rooms server database. -func NewPublicRoomsServerDatabase(dataSourceName string, dbProperties internal.DbProperties, localServerName gomatrixserverlib.ServerName) (*PublicRoomsServerDatabase, error) { +func NewPublicRoomsServerDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, localServerName gomatrixserverlib.ServerName) (*PublicRoomsServerDatabase, error) { var db *sql.DB var err error if db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { @@ -138,33 +138,33 @@ func (d *PublicRoomsServerDatabase) UpdateRoomFromEvent( case "m.room.aliases": return d.updateRoomAliases(ctx, event) case "m.room.canonical_alias": - var content internal.CanonicalAliasContent + var content eventutil.CanonicalAliasContent field := &(content.Alias) attrName := "canonical_alias" return d.updateStringAttribute(ctx, attrName, event, &content, field) case "m.room.name": - var content internal.NameContent + var content eventutil.NameContent field := &(content.Name) attrName := "name" return d.updateStringAttribute(ctx, attrName, event, &content, field) case "m.room.topic": - var content internal.TopicContent + var content eventutil.TopicContent field := &(content.Topic) attrName := "topic" return d.updateStringAttribute(ctx, attrName, event, &content, field) case "m.room.avatar": - var content internal.AvatarContent + var content eventutil.AvatarContent field := &(content.URL) attrName := "avatar_url" return d.updateStringAttribute(ctx, attrName, event, &content, field) case "m.room.history_visibility": - var content internal.HistoryVisibilityContent + var content eventutil.HistoryVisibilityContent field := &(content.HistoryVisibility) attrName := "world_readable" strForTrue := "world_readable" return d.updateBooleanAttribute(ctx, attrName, event, &content, field, strForTrue) case "m.room.guest_access": - var content internal.GuestAccessContent + var content eventutil.GuestAccessContent field := &(content.GuestAccess) attrName := "guest_can_join" strForTrue := "can_join" @@ -248,7 +248,7 @@ func (d *PublicRoomsServerDatabase) updateRoomAliases( if aliasesEvent.StateKey() == nil || *aliasesEvent.StateKey() != string(d.localServerName) { return nil // only store our own aliases } - var content internal.AliasesContent + var content eventutil.AliasesContent if err := json.Unmarshal(aliasesEvent.Content(), &content); err != nil { return err } diff --git a/publicroomsapi/storage/sqlite3/storage.go b/publicroomsapi/storage/sqlite3/storage.go index f51051823..5c685d131 100644 --- a/publicroomsapi/storage/sqlite3/storage.go +++ b/publicroomsapi/storage/sqlite3/storage.go @@ -22,7 +22,7 @@ import ( _ "github.com/mattn/go-sqlite3" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" @@ -31,7 +31,7 @@ import ( // PublicRoomsServerDatabase represents a public rooms server database. type PublicRoomsServerDatabase struct { db *sql.DB - internal.PartitionOffsetStatements + sqlutil.PartitionOffsetStatements statements publicRoomsStatements localServerName gomatrixserverlib.ServerName } @@ -46,7 +46,7 @@ func NewPublicRoomsServerDatabase(dataSourceName string, localServerName gomatri if err != nil { return nil, err } - if db, err = sqlutil.Open(internal.SQLiteDriverName(), cs, nil); err != nil { + if db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil); err != nil { return nil, err } storage := PublicRoomsServerDatabase{ @@ -144,33 +144,33 @@ func (d *PublicRoomsServerDatabase) UpdateRoomFromEvent( case "m.room.aliases": return d.updateRoomAliases(ctx, event) case "m.room.canonical_alias": - var content internal.CanonicalAliasContent + var content eventutil.CanonicalAliasContent field := &(content.Alias) attrName := "canonical_alias" return d.updateStringAttribute(ctx, attrName, event, &content, field) case "m.room.name": - var content internal.NameContent + var content eventutil.NameContent field := &(content.Name) attrName := "name" return d.updateStringAttribute(ctx, attrName, event, &content, field) case "m.room.topic": - var content internal.TopicContent + var content eventutil.TopicContent field := &(content.Topic) attrName := "topic" return d.updateStringAttribute(ctx, attrName, event, &content, field) case "m.room.avatar": - var content internal.AvatarContent + var content eventutil.AvatarContent field := &(content.URL) attrName := "avatar_url" return d.updateStringAttribute(ctx, attrName, event, &content, field) case "m.room.history_visibility": - var content internal.HistoryVisibilityContent + var content eventutil.HistoryVisibilityContent field := &(content.HistoryVisibility) attrName := "world_readable" strForTrue := "world_readable" return d.updateBooleanAttribute(ctx, attrName, event, &content, field, strForTrue) case "m.room.guest_access": - var content internal.GuestAccessContent + var content eventutil.GuestAccessContent field := &(content.GuestAccess) attrName := "guest_can_join" strForTrue := "can_join" @@ -254,7 +254,7 @@ func (d *PublicRoomsServerDatabase) updateRoomAliases( if aliasesEvent.StateKey() == nil || *aliasesEvent.StateKey() != string(d.localServerName) { return nil // only store our own aliases } - var content internal.AliasesContent + var content eventutil.AliasesContent if err := json.Unmarshal(aliasesEvent.Content(), &content); err != nil { return err } diff --git a/publicroomsapi/storage/storage.go b/publicroomsapi/storage/storage.go index 507b9fa6c..f66188040 100644 --- a/publicroomsapi/storage/storage.go +++ b/publicroomsapi/storage/storage.go @@ -19,7 +19,7 @@ package storage import ( "net/url" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/publicroomsapi/storage/postgres" "github.com/matrix-org/dendrite/publicroomsapi/storage/sqlite3" "github.com/matrix-org/gomatrixserverlib" @@ -29,7 +29,7 @@ const schemePostgres = "postgres" const schemeFile = "file" // NewPublicRoomsServerDatabase opens a database connection. -func NewPublicRoomsServerDatabase(dataSourceName string, dbProperties internal.DbProperties, localServerName gomatrixserverlib.ServerName) (Database, error) { +func NewPublicRoomsServerDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, localServerName gomatrixserverlib.ServerName) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { return postgres.NewPublicRoomsServerDatabase(dataSourceName, dbProperties, localServerName) diff --git a/roomserver/api/api.go b/roomserver/api/api.go index aefe55bcd..967f58baf 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -65,13 +65,6 @@ type RoomserverInternalAPI interface { response *QueryMembershipsForRoomResponse, ) error - // Query a list of invite event senders for a user in a room. - QueryInvitesForUser( - ctx context.Context, - request *QueryInvitesForUserRequest, - response *QueryInvitesForUserResponse, - ) error - // Query whether a server is allowed to see an event QueryServerAllowedToSeeEvent( ctx context.Context, @@ -96,10 +89,10 @@ type RoomserverInternalAPI interface { ) error // Query a given amount (or less) of events prior to a given set of events. - QueryBackfill( + PerformBackfill( ctx context.Context, - request *QueryBackfillRequest, - response *QueryBackfillResponse, + request *PerformBackfillRequest, + response *PerformBackfillResponse, ) error // Asks for the default room version as preferred by the server. diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go new file mode 100644 index 000000000..a478eeb9a --- /dev/null +++ b/roomserver/api/api_trace.go @@ -0,0 +1,218 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + + fsAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/util" +) + +// RoomserverInternalAPITrace wraps a RoomserverInternalAPI and logs the +// complete request/response/error +type RoomserverInternalAPITrace struct { + Impl RoomserverInternalAPI +} + +func (t *RoomserverInternalAPITrace) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) { + t.Impl.SetFederationSenderAPI(fsAPI) +} + +func (t *RoomserverInternalAPITrace) InputRoomEvents( + ctx context.Context, + req *InputRoomEventsRequest, + res *InputRoomEventsResponse, +) error { + err := t.Impl.InputRoomEvents(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("InputRoomEvents req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) PerformJoin( + ctx context.Context, + req *PerformJoinRequest, + res *PerformJoinResponse, +) error { + err := t.Impl.PerformJoin(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformJoin req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) PerformLeave( + ctx context.Context, + req *PerformLeaveRequest, + res *PerformLeaveResponse, +) error { + err := t.Impl.PerformLeave(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformLeave req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) QueryLatestEventsAndState( + ctx context.Context, + req *QueryLatestEventsAndStateRequest, + res *QueryLatestEventsAndStateResponse, +) error { + err := t.Impl.QueryLatestEventsAndState(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryLatestEventsAndState req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) QueryStateAfterEvents( + ctx context.Context, + req *QueryStateAfterEventsRequest, + res *QueryStateAfterEventsResponse, +) error { + err := t.Impl.QueryStateAfterEvents(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryStateAfterEvents req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) QueryEventsByID( + ctx context.Context, + req *QueryEventsByIDRequest, + res *QueryEventsByIDResponse, +) error { + err := t.Impl.QueryEventsByID(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryEventsByID req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) QueryMembershipForUser( + ctx context.Context, + req *QueryMembershipForUserRequest, + res *QueryMembershipForUserResponse, +) error { + err := t.Impl.QueryMembershipForUser(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryMembershipForUser req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) QueryMembershipsForRoom( + ctx context.Context, + req *QueryMembershipsForRoomRequest, + res *QueryMembershipsForRoomResponse, +) error { + err := t.Impl.QueryMembershipsForRoom(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryMembershipsForRoom req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) QueryServerAllowedToSeeEvent( + ctx context.Context, + req *QueryServerAllowedToSeeEventRequest, + res *QueryServerAllowedToSeeEventResponse, +) error { + err := t.Impl.QueryServerAllowedToSeeEvent(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryServerAllowedToSeeEvent req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) QueryMissingEvents( + ctx context.Context, + req *QueryMissingEventsRequest, + res *QueryMissingEventsResponse, +) error { + err := t.Impl.QueryMissingEvents(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryMissingEvents req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) QueryStateAndAuthChain( + ctx context.Context, + req *QueryStateAndAuthChainRequest, + res *QueryStateAndAuthChainResponse, +) error { + err := t.Impl.QueryStateAndAuthChain(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryStateAndAuthChain req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) PerformBackfill( + ctx context.Context, + req *PerformBackfillRequest, + res *PerformBackfillResponse, +) error { + err := t.Impl.PerformBackfill(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformBackfill req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) QueryRoomVersionCapabilities( + ctx context.Context, + req *QueryRoomVersionCapabilitiesRequest, + res *QueryRoomVersionCapabilitiesResponse, +) error { + err := t.Impl.QueryRoomVersionCapabilities(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryRoomVersionCapabilities req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) QueryRoomVersionForRoom( + ctx context.Context, + req *QueryRoomVersionForRoomRequest, + res *QueryRoomVersionForRoomResponse, +) error { + err := t.Impl.QueryRoomVersionForRoom(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryRoomVersionForRoom req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) SetRoomAlias( + ctx context.Context, + req *SetRoomAliasRequest, + res *SetRoomAliasResponse, +) error { + err := t.Impl.SetRoomAlias(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("SetRoomAlias req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) GetRoomIDForAlias( + ctx context.Context, + req *GetRoomIDForAliasRequest, + res *GetRoomIDForAliasResponse, +) error { + err := t.Impl.GetRoomIDForAlias(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("GetRoomIDForAlias req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) GetAliasesForRoomID( + ctx context.Context, + req *GetAliasesForRoomIDRequest, + res *GetAliasesForRoomIDResponse, +) error { + err := t.Impl.GetAliasesForRoomID(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("GetAliasesForRoomID req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) GetCreatorIDForAlias( + ctx context.Context, + req *GetCreatorIDForAliasRequest, + res *GetCreatorIDForAliasResponse, +) error { + err := t.Impl.GetCreatorIDForAlias(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("GetCreatorIDForAlias req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *RoomserverInternalAPITrace) RemoveRoomAlias( + ctx context.Context, + req *RemoveRoomAliasRequest, + res *RemoveRoomAliasResponse, +) error { + err := t.Impl.RemoveRoomAlias(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("RemoveRoomAlias req=%+v res=%+v", js(req), js(res)) + return err +} + +func js(thing interface{}) string { + b, err := json.Marshal(thing) + if err != nil { + return fmt.Sprintf("Marshal error:%s", err) + } + return string(b) +} diff --git a/roomserver/api/output.go b/roomserver/api/output.go index 92a468a96..2bbd97af8 100644 --- a/roomserver/api/output.go +++ b/roomserver/api/output.go @@ -63,6 +63,13 @@ type OutputNewRoomEvent struct { // Together with RemovesStateEventIDs this allows the receiver to keep an up to date // view of the current state of the room. AddsStateEventIDs []string `json:"adds_state_event_ids"` + // All extra newly added state events. This is only set if there are *extra* events + // other than `Event`. This can happen when forks get merged because state resolution + // may decide a bunch of state events on one branch are now valid, so they will be + // present in this list. This is useful when trying to maintain the current state of a room + // as to do so you need to include both these events and `Event`. + AddStateEvents []gomatrixserverlib.HeaderedEvent `json:"adds_state_events"` + // The state event IDs that were removed from the state of the room by this event. RemovesStateEventIDs []string `json:"removes_state_event_ids"` // The ID of the event that was output before this event. @@ -112,6 +119,26 @@ type OutputNewRoomEvent struct { TransactionID *TransactionID `json:"transaction_id"` } +// AddsState returns all added state events from this event. +// +// This function is needed because `AddStateEvents` will not include a copy of +// the original event to save space, so you cannot use that slice alone. +// Instead, use this function which will add the original event if it is present +// in `AddsStateEventIDs`. +func (ore *OutputNewRoomEvent) AddsState() []gomatrixserverlib.HeaderedEvent { + includeOutputEvent := false + for _, id := range ore.AddsStateEventIDs { + if id == ore.Event.EventID() { + includeOutputEvent = true + break + } + } + if !includeOutputEvent { + return ore.AddStateEvents + } + return append(ore.AddStateEvents, ore.Event) +} + // An OutputNewInviteEvent is written whenever an invite becomes active. // Invite events can be received outside of an existing room so have to be // tracked separately from the room events themselves. diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index 1cf54144e..3e5cae1b6 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -2,6 +2,7 @@ package api import ( "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" ) type PerformJoinRequest struct { @@ -22,3 +23,31 @@ type PerformLeaveRequest struct { type PerformLeaveResponse struct { } + +// PerformBackfillRequest is a request to PerformBackfill. +type PerformBackfillRequest struct { + // The room to backfill + RoomID string `json:"room_id"` + // A map of backwards extremity event ID to a list of its prev_event IDs. + BackwardsExtremities map[string][]string `json:"backwards_extremities"` + // The maximum number of events to retrieve. + Limit int `json:"limit"` + // The server interested in the events. + ServerName gomatrixserverlib.ServerName `json:"server_name"` +} + +// PrevEventIDs returns the prev_event IDs of all backwards extremities, de-duplicated in a lexicographically sorted order. +func (r *PerformBackfillRequest) PrevEventIDs() []string { + var prevEventIDs []string + for _, pes := range r.BackwardsExtremities { + prevEventIDs = append(prevEventIDs, pes...) + } + prevEventIDs = util.UniqueStrings(prevEventIDs) + return prevEventIDs +} + +// PerformBackfillResponse is a response to PerformBackfill. +type PerformBackfillResponse struct { + // Missing events, arbritrary order. + Events []gomatrixserverlib.HeaderedEvent `json:"events"` +} diff --git a/roomserver/api/query.go b/roomserver/api/query.go index dc005c77c..6586b1af3 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -18,7 +18,6 @@ package api import ( "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" ) // QueryLatestEventsAndStateRequest is a request to QueryLatestEventsAndState @@ -34,8 +33,6 @@ type QueryLatestEventsAndStateRequest struct { // This is used when sending events to set the prev_events, auth_events and depth. // It is also used to tell whether the event is allowed by the event auth rules. type QueryLatestEventsAndStateResponse struct { - // Copy of the request for debugging. - QueryLatestEventsAndStateRequest // Does the room exist? // If the room doesn't exist this will be false and LatestEvents will be empty. RoomExists bool `json:"room_exists"` @@ -67,8 +64,6 @@ type QueryStateAfterEventsRequest struct { // QueryStateAfterEventsResponse is a response to QueryStateAfterEvents type QueryStateAfterEventsResponse struct { - // Copy of the request for debugging. - QueryStateAfterEventsRequest // Does the room exist on this roomserver? // If the room doesn't exist this will be false and StateEvents will be empty. RoomExists bool `json:"room_exists"` @@ -90,8 +85,6 @@ type QueryEventsByIDRequest struct { // QueryEventsByIDResponse is a response to QueryEventsByID type QueryEventsByIDResponse struct { - // Copy of the request for debugging. - QueryEventsByIDRequest // A list of events with the requested IDs. // If the roomserver does not have a copy of a requested event // then it will omit that event from the list. @@ -140,23 +133,6 @@ type QueryMembershipsForRoomResponse struct { HasBeenInRoom bool `json:"has_been_in_room"` } -// QueryInvitesForUserRequest is a request to QueryInvitesForUser -type QueryInvitesForUserRequest struct { - // The room ID to look up invites in. - RoomID string `json:"room_id"` - // The User ID to look up invites for. - TargetUserID string `json:"target_user_id"` -} - -// QueryInvitesForUserResponse is a response to QueryInvitesForUser -// This is used when accepting an invite or rejecting a invite to tell which -// remote matrix servers to contact. -type QueryInvitesForUserResponse struct { - // A list of matrix user IDs for each sender of an active invite targeting - // the requested user ID. - InviteSenderUserIDs []string `json:"invite_sender_user_ids"` -} - // QueryServerAllowedToSeeEventRequest is a request to QueryServerAllowedToSeeEvent type QueryServerAllowedToSeeEventRequest struct { // The event ID to look up invites in. @@ -205,8 +181,6 @@ type QueryStateAndAuthChainRequest struct { // QueryStateAndAuthChainResponse is a response to QueryStateAndAuthChain type QueryStateAndAuthChainResponse struct { - // Copy of the request for debugging. - QueryStateAndAuthChainRequest // Does the room exist on this roomserver? // If the room doesn't exist this will be false and StateEvents will be empty. RoomExists bool `json:"room_exists"` @@ -221,34 +195,6 @@ type QueryStateAndAuthChainResponse struct { AuthChainEvents []gomatrixserverlib.HeaderedEvent `json:"auth_chain_events"` } -// QueryBackfillRequest is a request to QueryBackfill. -type QueryBackfillRequest struct { - // The room to backfill - RoomID string `json:"room_id"` - // A map of backwards extremity event ID to a list of its prev_event IDs. - BackwardsExtremities map[string][]string `json:"backwards_extremities"` - // The maximum number of events to retrieve. - Limit int `json:"limit"` - // The server interested in the events. - ServerName gomatrixserverlib.ServerName `json:"server_name"` -} - -// PrevEventIDs returns the prev_event IDs of all backwards extremities, de-duplicated in a lexicographically sorted order. -func (r *QueryBackfillRequest) PrevEventIDs() []string { - var prevEventIDs []string - for _, pes := range r.BackwardsExtremities { - prevEventIDs = append(prevEventIDs, pes...) - } - prevEventIDs = util.UniqueStrings(prevEventIDs) - return prevEventIDs -} - -// QueryBackfillResponse is a response to QueryBackfill. -type QueryBackfillResponse struct { - // Missing events, arbritrary order. - Events []gomatrixserverlib.HeaderedEvent `json:"events"` -} - // QueryRoomVersionCapabilitiesRequest asks for the default room version type QueryRoomVersionCapabilitiesRequest struct{} diff --git a/roomserver/internal/input.go b/roomserver/internal/input.go index 932b4df46..e863af953 100644 --- a/roomserver/internal/input.go +++ b/roomserver/internal/input.go @@ -21,6 +21,7 @@ import ( "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/roomserver/api" + log "github.com/sirupsen/logrus" fsAPI "github.com/matrix-org/dendrite/federationsender/api" ) @@ -40,6 +41,21 @@ func (r *RoomserverInternalAPI) WriteOutputEvents(roomID string, updates []api.O if err != nil { return err } + logger := log.WithFields(log.Fields{ + "room_id": roomID, + "type": updates[i].Type, + }) + if updates[i].NewRoomEvent != nil { + logger = logger.WithFields(log.Fields{ + "event_type": updates[i].NewRoomEvent.Event.Type(), + "event_id": updates[i].NewRoomEvent.Event.EventID(), + "adds_state": len(updates[i].NewRoomEvent.AddsStateEventIDs), + "removes_state": len(updates[i].NewRoomEvent.RemovesStateEventIDs), + "send_as_server": updates[i].NewRoomEvent.SendAsServer, + "sender": updates[i].NewRoomEvent.Event.Sender(), + }) + } + logger.Infof("Producing to topic '%s'", r.OutputRoomEventTopic) messages[i] = &sarama.ProducerMessage{ Topic: r.OutputRoomEventTopic, Key: sarama.StringEncoder(roomID), diff --git a/roomserver/internal/input_events.go b/roomserver/internal/input_events.go index 3a38636f8..4487aea02 100644 --- a/roomserver/internal/input_events.go +++ b/roomserver/internal/input_events.go @@ -21,7 +21,7 @@ import ( "errors" "fmt" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" @@ -177,7 +177,7 @@ func (r *RoomserverInternalAPI) processInviteEvent( } succeeded := false defer func() { - txerr := internal.EndTransaction(updater, &succeeded) + txerr := sqlutil.EndTransaction(updater, &succeeded) if err == nil && txerr != nil { err = txerr } diff --git a/roomserver/internal/input_latest_events.go b/roomserver/internal/input_latest_events.go index aea85ca95..66316ac4f 100644 --- a/roomserver/internal/input_latest_events.go +++ b/roomserver/internal/input_latest_events.go @@ -19,8 +19,9 @@ package internal import ( "bytes" "context" + "fmt" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/types" @@ -59,7 +60,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents( } succeeded := false defer func() { - txerr := internal.EndTransaction(updater, &succeeded) + txerr := sqlutil.EndTransaction(updater, &succeeded) if err == nil && txerr != nil { err = txerr } @@ -310,24 +311,11 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) TransactionID: u.transactionID, } - var stateEventNIDs []types.EventNID - for _, entry := range u.added { - stateEventNIDs = append(stateEventNIDs, entry.EventNID) - } - for _, entry := range u.removed { - stateEventNIDs = append(stateEventNIDs, entry.EventNID) - } - for _, entry := range u.stateBeforeEventRemoves { - stateEventNIDs = append(stateEventNIDs, entry.EventNID) - } - for _, entry := range u.stateBeforeEventAdds { - stateEventNIDs = append(stateEventNIDs, entry.EventNID) - } - stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))] - eventIDMap, err := u.api.DB.EventIDs(u.ctx, stateEventNIDs) + eventIDMap, err := u.stateEventMap() if err != nil { return nil, err } + for _, entry := range u.added { ore.AddsStateEventIDs = append(ore.AddsStateEventIDs, eventIDMap[entry.EventNID]) } @@ -342,12 +330,60 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) } ore.SendAsServer = u.sendAsServer + // include extra state events if they were added as nearly every downstream component will care about it + // and we'd rather not have them all hit QueryEventsByID at the same time! + if len(ore.AddsStateEventIDs) > 0 { + ore.AddStateEvents, err = u.extraEventsForIDs(roomVersion, ore.AddsStateEventIDs) + if err != nil { + return nil, fmt.Errorf("failed to load add_state_events from db: %w", err) + } + } + return &api.OutputEvent{ Type: api.OutputTypeNewRoomEvent, NewRoomEvent: &ore, }, nil } +// extraEventsForIDs returns the full events for the event IDs given, but does not include the current event being +// updated. +func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.RoomVersion, eventIDs []string) ([]gomatrixserverlib.HeaderedEvent, error) { + var extraEventIDs []string + for _, e := range eventIDs { + if e == u.event.EventID() { + continue + } + extraEventIDs = append(extraEventIDs, e) + } + if len(extraEventIDs) == 0 { + return nil, nil + } + extraEvents, err := u.api.DB.EventsFromIDs(u.ctx, extraEventIDs) + if err != nil { + return nil, err + } + var h []gomatrixserverlib.HeaderedEvent + for _, e := range extraEvents { + h = append(h, e.Headered(roomVersion)) + } + return h, nil +} + +// retrieve an event nid -> event ID map for all events that need updating +func (u *latestEventsUpdater) stateEventMap() (map[types.EventNID]string, error) { + var stateEventNIDs []types.EventNID + var allStateEntries []types.StateEntry + allStateEntries = append(allStateEntries, u.added...) + allStateEntries = append(allStateEntries, u.removed...) + allStateEntries = append(allStateEntries, u.stateBeforeEventRemoves...) + allStateEntries = append(allStateEntries, u.stateBeforeEventAdds...) + for _, entry := range allStateEntries { + stateEventNIDs = append(stateEventNIDs, entry.EventNID) + } + stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))] + return u.api.DB.EventIDs(u.ctx, stateEventNIDs) +} + type eventNIDSorter []types.EventNID func (s eventNIDSorter) Len() int { return len(s) } diff --git a/roomserver/internal/query_backfill.go b/roomserver/internal/perform_backfill.go similarity index 100% rename from roomserver/internal/query_backfill.go rename to roomserver/internal/perform_backfill.go diff --git a/roomserver/internal/perform_join.go b/roomserver/internal/perform_join.go index 75dfdd19c..1c951bb15 100644 --- a/roomserver/internal/perform_join.go +++ b/roomserver/internal/perform_join.go @@ -7,7 +7,7 @@ import ( "time" fsAPI "github.com/matrix-org/dendrite/federationsender/api" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" @@ -159,7 +159,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByID( // TODO: Check what happens if the room exists on the server // but everyone has since left. I suspect it does the wrong thing. buildRes := api.QueryLatestEventsAndStateResponse{} - event, err := internal.BuildEvent( + event, err := eventutil.BuildEvent( ctx, // the request context &eb, // the template join event r.Cfg, // the server configuration @@ -202,7 +202,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByID( } } - case internal.ErrRoomNoExists: + case eventutil.ErrRoomNoExists: // The room doesn't exist. First of all check if the room is a local // room. If it is then there's nothing more to do - the room just // hasn't been created yet. diff --git a/roomserver/internal/perform_leave.go b/roomserver/internal/perform_leave.go index b60c2e99f..880c8b203 100644 --- a/roomserver/internal/perform_leave.go +++ b/roomserver/internal/perform_leave.go @@ -7,7 +7,7 @@ import ( "time" fsAPI "github.com/matrix-org/dendrite/federationsender/api" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" ) @@ -97,7 +97,7 @@ func (r *RoomserverInternalAPI) performLeaveRoomByID( // TODO: Check what happens if the room exists on the server // but everyone has since left. I suspect it does the wrong thing. buildRes := api.QueryLatestEventsAndStateResponse{} - event, err := internal.BuildEvent( + event, err := eventutil.BuildEvent( ctx, // the request context &eb, // the template leave event r.Cfg, // the server configuration @@ -106,7 +106,7 @@ func (r *RoomserverInternalAPI) performLeaveRoomByID( &buildRes, // the query response ) if err != nil { - return fmt.Errorf("internal.BuildEvent: %w", err) + return fmt.Errorf("eventutil.BuildEvent: %w", err) } // Give our leave event to the roomserver input stream. The diff --git a/roomserver/internal/query.go b/roomserver/internal/query.go index 9fb67e7eb..4fc8e4c25 100644 --- a/roomserver/internal/query.go +++ b/roomserver/internal/query.go @@ -45,7 +45,6 @@ func (r *RoomserverInternalAPI) QueryLatestEventsAndState( roomState := state.NewStateResolution(r.DB) - response.QueryLatestEventsAndStateRequest = *request roomNID, err := r.DB.RoomNIDExcludingStubs(ctx, request.RoomID) if err != nil { return err @@ -105,7 +104,6 @@ func (r *RoomserverInternalAPI) QueryStateAfterEvents( roomState := state.NewStateResolution(r.DB) - response.QueryStateAfterEventsRequest = *request roomNID, err := r.DB.RoomNIDExcludingStubs(ctx, request.RoomID) if err != nil { return err @@ -153,8 +151,6 @@ func (r *RoomserverInternalAPI) QueryEventsByID( request *api.QueryEventsByIDRequest, response *api.QueryEventsByIDResponse, ) error { - response.QueryEventsByIDRequest = *request - eventNIDMap, err := r.DB.EventNIDs(ctx, request.EventIDs) if err != nil { return err @@ -353,40 +349,6 @@ func getMembershipsAtState( return events, nil } -// QueryInvitesForUser implements api.RoomserverInternalAPI -func (r *RoomserverInternalAPI) QueryInvitesForUser( - ctx context.Context, - request *api.QueryInvitesForUserRequest, - response *api.QueryInvitesForUserResponse, -) error { - roomNID, err := r.DB.RoomNID(ctx, request.RoomID) - if err != nil { - return err - } - - targetUserNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.TargetUserID}) - if err != nil { - return err - } - targetUserNID := targetUserNIDs[request.TargetUserID] - - senderUserNIDs, err := r.DB.GetInvitesForUser(ctx, roomNID, targetUserNID) - if err != nil { - return err - } - - senderUserIDs, err := r.DB.EventStateKeys(ctx, senderUserNIDs) - if err != nil { - return err - } - - for _, senderUserID := range senderUserIDs { - response.InviteSenderUserIDs = append(response.InviteSenderUserIDs, senderUserID) - } - - return nil -} - // QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent( ctx context.Context, @@ -475,11 +437,11 @@ func (r *RoomserverInternalAPI) QueryMissingEvents( return err } -// QueryBackfill implements api.RoomServerQueryAPI -func (r *RoomserverInternalAPI) QueryBackfill( +// PerformBackfill implements api.RoomServerQueryAPI +func (r *RoomserverInternalAPI) PerformBackfill( ctx context.Context, - request *api.QueryBackfillRequest, - response *api.QueryBackfillResponse, + request *api.PerformBackfillRequest, + response *api.PerformBackfillResponse, ) error { // if we are requesting the backfill then we need to do a federation hit // TODO: we could be more sensible and fetch as many events we already have then request the rest @@ -523,7 +485,7 @@ func (r *RoomserverInternalAPI) QueryBackfill( return err } -func (r *RoomserverInternalAPI) backfillViaFederation(ctx context.Context, req *api.QueryBackfillRequest, res *api.QueryBackfillResponse) error { +func (r *RoomserverInternalAPI) backfillViaFederation(ctx context.Context, req *api.PerformBackfillRequest, res *api.PerformBackfillResponse) error { roomVer, err := r.DB.GetRoomVersionForRoom(ctx, req.RoomID) if err != nil { return fmt.Errorf("backfillViaFederation: unknown room version for room %s : %w", req.RoomID, err) @@ -681,7 +643,7 @@ func (r *RoomserverInternalAPI) scanEventTree( var pre string // TODO: add tests for this function to ensure it meets the contract that callers expect (and doc what that is supposed to be) - // Currently, callers like QueryBackfill will call scanEventTree with a pre-populated `visited` map, assuming that by doing + // Currently, callers like PerformBackfill will call scanEventTree with a pre-populated `visited` map, assuming that by doing // so means that the events in that map will NOT be returned from this function. That is not currently true, resulting in // duplicate events being sent in response to /backfill requests. initialIgnoreList := make(map[string]bool, len(visited)) @@ -768,7 +730,6 @@ func (r *RoomserverInternalAPI) QueryStateAndAuthChain( request *api.QueryStateAndAuthChainRequest, response *api.QueryStateAndAuthChainResponse, ) error { - response.QueryStateAndAuthChainRequest = *request roomNID, err := r.DB.RoomNIDExcludingStubs(ctx, request.RoomID) if err != nil { return err diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 5cc2537e3..e41adb993 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -7,7 +7,7 @@ import ( fsInputAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/caching" - internalHTTP "github.com/matrix-org/dendrite/internal/http" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/opentracing/opentracing-go" ) @@ -24,8 +24,9 @@ const ( RoomserverInputRoomEventsPath = "/roomserver/inputRoomEvents" // Perform operations - RoomserverPerformJoinPath = "/roomserver/performJoin" - RoomserverPerformLeavePath = "/roomserver/performLeave" + RoomserverPerformJoinPath = "/roomserver/performJoin" + RoomserverPerformLeavePath = "/roomserver/performLeave" + RoomserverPerformBackfillPath = "/roomserver/performBackfill" // Query operations RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState" @@ -33,11 +34,9 @@ const ( RoomserverQueryEventsByIDPath = "/roomserver/queryEventsByID" RoomserverQueryMembershipForUserPath = "/roomserver/queryMembershipForUser" RoomserverQueryMembershipsForRoomPath = "/roomserver/queryMembershipsForRoom" - RoomserverQueryInvitesForUserPath = "/roomserver/queryInvitesForUser" RoomserverQueryServerAllowedToSeeEventPath = "/roomserver/queryServerAllowedToSeeEvent" RoomserverQueryMissingEventsPath = "/roomserver/queryMissingEvents" RoomserverQueryStateAndAuthChainPath = "/roomserver/queryStateAndAuthChain" - RoomserverQueryBackfillPath = "/roomserver/queryBackfill" RoomserverQueryRoomVersionCapabilitiesPath = "/roomserver/queryRoomVersionCapabilities" RoomserverQueryRoomVersionForRoomPath = "/roomserver/queryRoomVersionForRoom" ) @@ -79,7 +78,7 @@ func (h *httpRoomserverInternalAPI) SetRoomAlias( defer span.Finish() apiURL := h.roomserverURL + RoomserverSetRoomAliasPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } // GetRoomIDForAlias implements RoomserverAliasAPI @@ -92,7 +91,7 @@ func (h *httpRoomserverInternalAPI) GetRoomIDForAlias( defer span.Finish() apiURL := h.roomserverURL + RoomserverGetRoomIDForAliasPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } // GetAliasesForRoomID implements RoomserverAliasAPI @@ -105,7 +104,7 @@ func (h *httpRoomserverInternalAPI) GetAliasesForRoomID( defer span.Finish() apiURL := h.roomserverURL + RoomserverGetAliasesForRoomIDPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } // GetCreatorIDForAlias implements RoomserverAliasAPI @@ -118,7 +117,7 @@ func (h *httpRoomserverInternalAPI) GetCreatorIDForAlias( defer span.Finish() apiURL := h.roomserverURL + RoomserverGetCreatorIDForAliasPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } // RemoveRoomAlias implements RoomserverAliasAPI @@ -131,7 +130,7 @@ func (h *httpRoomserverInternalAPI) RemoveRoomAlias( defer span.Finish() apiURL := h.roomserverURL + RoomserverRemoveRoomAliasPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } // InputRoomEvents implements RoomserverInputAPI @@ -144,7 +143,7 @@ func (h *httpRoomserverInternalAPI) InputRoomEvents( defer span.Finish() apiURL := h.roomserverURL + RoomserverInputRoomEventsPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } func (h *httpRoomserverInternalAPI) PerformJoin( @@ -156,7 +155,7 @@ func (h *httpRoomserverInternalAPI) PerformJoin( defer span.Finish() apiURL := h.roomserverURL + RoomserverPerformJoinPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } func (h *httpRoomserverInternalAPI) PerformLeave( @@ -168,7 +167,7 @@ func (h *httpRoomserverInternalAPI) PerformLeave( defer span.Finish() apiURL := h.roomserverURL + RoomserverPerformLeavePath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } // QueryLatestEventsAndState implements RoomserverQueryAPI @@ -181,7 +180,7 @@ func (h *httpRoomserverInternalAPI) QueryLatestEventsAndState( defer span.Finish() apiURL := h.roomserverURL + RoomserverQueryLatestEventsAndStatePath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } // QueryStateAfterEvents implements RoomserverQueryAPI @@ -194,7 +193,7 @@ func (h *httpRoomserverInternalAPI) QueryStateAfterEvents( defer span.Finish() apiURL := h.roomserverURL + RoomserverQueryStateAfterEventsPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } // QueryEventsByID implements RoomserverQueryAPI @@ -207,7 +206,7 @@ func (h *httpRoomserverInternalAPI) QueryEventsByID( defer span.Finish() apiURL := h.roomserverURL + RoomserverQueryEventsByIDPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } // QueryMembershipForUser implements RoomserverQueryAPI @@ -220,7 +219,7 @@ func (h *httpRoomserverInternalAPI) QueryMembershipForUser( defer span.Finish() apiURL := h.roomserverURL + RoomserverQueryMembershipForUserPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } // QueryMembershipsForRoom implements RoomserverQueryAPI @@ -233,20 +232,7 @@ func (h *httpRoomserverInternalAPI) QueryMembershipsForRoom( defer span.Finish() apiURL := h.roomserverURL + RoomserverQueryMembershipsForRoomPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -// QueryInvitesForUser implements RoomserverQueryAPI -func (h *httpRoomserverInternalAPI) QueryInvitesForUser( - ctx context.Context, - request *api.QueryInvitesForUserRequest, - response *api.QueryInvitesForUserResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryInvitesForUser") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryInvitesForUserPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } // QueryServerAllowedToSeeEvent implements RoomserverQueryAPI @@ -259,7 +245,7 @@ func (h *httpRoomserverInternalAPI) QueryServerAllowedToSeeEvent( defer span.Finish() apiURL := h.roomserverURL + RoomserverQueryServerAllowedToSeeEventPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } // QueryMissingEvents implements RoomServerQueryAPI @@ -272,7 +258,7 @@ func (h *httpRoomserverInternalAPI) QueryMissingEvents( defer span.Finish() apiURL := h.roomserverURL + RoomserverQueryMissingEventsPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } // QueryStateAndAuthChain implements RoomserverQueryAPI @@ -285,20 +271,20 @@ func (h *httpRoomserverInternalAPI) QueryStateAndAuthChain( defer span.Finish() apiURL := h.roomserverURL + RoomserverQueryStateAndAuthChainPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } -// QueryBackfill implements RoomServerQueryAPI -func (h *httpRoomserverInternalAPI) QueryBackfill( +// PerformBackfill implements RoomServerQueryAPI +func (h *httpRoomserverInternalAPI) PerformBackfill( ctx context.Context, - request *api.QueryBackfillRequest, - response *api.QueryBackfillResponse, + request *api.PerformBackfillRequest, + response *api.PerformBackfillResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryBackfill") + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformBackfill") defer span.Finish() - apiURL := h.roomserverURL + RoomserverQueryBackfillPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + apiURL := h.roomserverURL + RoomserverPerformBackfillPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } // QueryRoomVersionCapabilities implements RoomServerQueryAPI @@ -311,7 +297,7 @@ func (h *httpRoomserverInternalAPI) QueryRoomVersionCapabilities( defer span.Finish() apiURL := h.roomserverURL + RoomserverQueryRoomVersionCapabilitiesPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } // QueryRoomVersionForRoom implements RoomServerQueryAPI @@ -329,7 +315,7 @@ func (h *httpRoomserverInternalAPI) QueryRoomVersionForRoom( defer span.Finish() apiURL := h.roomserverURL + RoomserverQueryRoomVersionForRoomPath - err := internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) if err == nil { h.cache.StoreRoomVersion(request.RoomID, response.RoomVersion) } diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 9a58a30b0..822acd15b 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -5,7 +5,7 @@ import ( "net/http" "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/util" ) @@ -14,7 +14,7 @@ import ( // nolint: gocyclo func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle(RoomserverInputRoomEventsPath, - internal.MakeInternalAPI("inputRoomEvents", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("inputRoomEvents", func(req *http.Request) util.JSONResponse { var request api.InputRoomEventsRequest var response api.InputRoomEventsResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -27,7 +27,7 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { }), ) internalAPIMux.Handle(RoomserverPerformJoinPath, - internal.MakeInternalAPI("performJoin", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("performJoin", func(req *http.Request) util.JSONResponse { var request api.PerformJoinRequest var response api.PerformJoinResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -40,7 +40,7 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { }), ) internalAPIMux.Handle(RoomserverPerformLeavePath, - internal.MakeInternalAPI("performLeave", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("performLeave", func(req *http.Request) util.JSONResponse { var request api.PerformLeaveRequest var response api.PerformLeaveResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -54,7 +54,7 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { ) internalAPIMux.Handle( RoomserverQueryLatestEventsAndStatePath, - internal.MakeInternalAPI("queryLatestEventsAndState", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("queryLatestEventsAndState", func(req *http.Request) util.JSONResponse { var request api.QueryLatestEventsAndStateRequest var response api.QueryLatestEventsAndStateResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -68,7 +68,7 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { ) internalAPIMux.Handle( RoomserverQueryStateAfterEventsPath, - internal.MakeInternalAPI("queryStateAfterEvents", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("queryStateAfterEvents", func(req *http.Request) util.JSONResponse { var request api.QueryStateAfterEventsRequest var response api.QueryStateAfterEventsResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -82,7 +82,7 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { ) internalAPIMux.Handle( RoomserverQueryEventsByIDPath, - internal.MakeInternalAPI("queryEventsByID", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("queryEventsByID", func(req *http.Request) util.JSONResponse { var request api.QueryEventsByIDRequest var response api.QueryEventsByIDResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -96,7 +96,7 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { ) internalAPIMux.Handle( RoomserverQueryMembershipForUserPath, - internal.MakeInternalAPI("QueryMembershipForUser", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("QueryMembershipForUser", func(req *http.Request) util.JSONResponse { var request api.QueryMembershipForUserRequest var response api.QueryMembershipForUserResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -110,7 +110,7 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { ) internalAPIMux.Handle( RoomserverQueryMembershipsForRoomPath, - internal.MakeInternalAPI("queryMembershipsForRoom", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("queryMembershipsForRoom", func(req *http.Request) util.JSONResponse { var request api.QueryMembershipsForRoomRequest var response api.QueryMembershipsForRoomResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -122,23 +122,9 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - internalAPIMux.Handle( - RoomserverQueryInvitesForUserPath, - internal.MakeInternalAPI("queryInvitesForUser", func(req *http.Request) util.JSONResponse { - var request api.QueryInvitesForUserRequest - var response api.QueryInvitesForUserResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryInvitesForUser(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) internalAPIMux.Handle( RoomserverQueryServerAllowedToSeeEventPath, - internal.MakeInternalAPI("queryServerAllowedToSeeEvent", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("queryServerAllowedToSeeEvent", func(req *http.Request) util.JSONResponse { var request api.QueryServerAllowedToSeeEventRequest var response api.QueryServerAllowedToSeeEventResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -152,7 +138,7 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { ) internalAPIMux.Handle( RoomserverQueryMissingEventsPath, - internal.MakeInternalAPI("queryMissingEvents", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("queryMissingEvents", func(req *http.Request) util.JSONResponse { var request api.QueryMissingEventsRequest var response api.QueryMissingEventsResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -166,7 +152,7 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { ) internalAPIMux.Handle( RoomserverQueryStateAndAuthChainPath, - internal.MakeInternalAPI("queryStateAndAuthChain", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("queryStateAndAuthChain", func(req *http.Request) util.JSONResponse { var request api.QueryStateAndAuthChainRequest var response api.QueryStateAndAuthChainResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -179,14 +165,14 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { }), ) internalAPIMux.Handle( - RoomserverQueryBackfillPath, - internal.MakeInternalAPI("QueryBackfill", func(req *http.Request) util.JSONResponse { - var request api.QueryBackfillRequest - var response api.QueryBackfillResponse + RoomserverPerformBackfillPath, + httputil.MakeInternalAPI("PerformBackfill", func(req *http.Request) util.JSONResponse { + var request api.PerformBackfillRequest + var response api.PerformBackfillResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { return util.ErrorResponse(err) } - if err := r.QueryBackfill(req.Context(), &request, &response); err != nil { + if err := r.PerformBackfill(req.Context(), &request, &response); err != nil { return util.ErrorResponse(err) } return util.JSONResponse{Code: http.StatusOK, JSON: &response} @@ -194,7 +180,7 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { ) internalAPIMux.Handle( RoomserverQueryRoomVersionCapabilitiesPath, - internal.MakeInternalAPI("QueryRoomVersionCapabilities", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("QueryRoomVersionCapabilities", func(req *http.Request) util.JSONResponse { var request api.QueryRoomVersionCapabilitiesRequest var response api.QueryRoomVersionCapabilitiesResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -208,7 +194,7 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { ) internalAPIMux.Handle( RoomserverQueryRoomVersionForRoomPath, - internal.MakeInternalAPI("QueryRoomVersionForRoom", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("QueryRoomVersionForRoom", func(req *http.Request) util.JSONResponse { var request api.QueryRoomVersionForRoomRequest var response api.QueryRoomVersionForRoomResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -222,7 +208,7 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { ) internalAPIMux.Handle( RoomserverSetRoomAliasPath, - internal.MakeInternalAPI("setRoomAlias", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("setRoomAlias", func(req *http.Request) util.JSONResponse { var request api.SetRoomAliasRequest var response api.SetRoomAliasResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -236,7 +222,7 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { ) internalAPIMux.Handle( RoomserverGetRoomIDForAliasPath, - internal.MakeInternalAPI("GetRoomIDForAlias", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("GetRoomIDForAlias", func(req *http.Request) util.JSONResponse { var request api.GetRoomIDForAliasRequest var response api.GetRoomIDForAliasResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -250,7 +236,7 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { ) internalAPIMux.Handle( RoomserverGetCreatorIDForAliasPath, - internal.MakeInternalAPI("GetCreatorIDForAlias", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("GetCreatorIDForAlias", func(req *http.Request) util.JSONResponse { var request api.GetCreatorIDForAliasRequest var response api.GetCreatorIDForAliasResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -264,7 +250,7 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { ) internalAPIMux.Handle( RoomserverGetAliasesForRoomIDPath, - internal.MakeInternalAPI("getAliasesForRoomID", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("getAliasesForRoomID", func(req *http.Request) util.JSONResponse { var request api.GetAliasesForRoomIDRequest var response api.GetAliasesForRoomIDResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -278,7 +264,7 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { ) internalAPIMux.Handle( RoomserverRemoveRoomAliasPath, - internal.MakeInternalAPI("removeRoomAlias", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("removeRoomAlias", func(req *http.Request) util.JSONResponse { var request api.RemoveRoomAliasRequest var response api.RemoveRoomAliasResponse if err := json.NewDecoder(req.Body).Decode(&request); err != nil { diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go index a9db22d7a..427d5ff36 100644 --- a/roomserver/roomserver.go +++ b/roomserver/roomserver.go @@ -20,7 +20,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/inthttp" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal/basecomponent" + "github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/roomserver/internal" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/sirupsen/logrus" @@ -35,7 +35,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.RoomserverInternalAPI) { // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - base *basecomponent.BaseDendrite, + base *setup.BaseDendrite, keyRing gomatrixserverlib.JSONVerifier, fedClient *gomatrixserverlib.FederationClient, ) api.RoomserverInternalAPI { diff --git a/roomserver/storage/postgres/event_json_table.go b/roomserver/storage/postgres/event_json_table.go index 3bce9f086..7df175954 100644 --- a/roomserver/storage/postgres/event_json_table.go +++ b/roomserver/storage/postgres/event_json_table.go @@ -20,7 +20,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" diff --git a/roomserver/storage/postgres/event_state_keys_table.go b/roomserver/storage/postgres/event_state_keys_table.go index b7e2cb6b7..500ff20e4 100644 --- a/roomserver/storage/postgres/event_state_keys_table.go +++ b/roomserver/storage/postgres/event_state_keys_table.go @@ -21,6 +21,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -94,7 +95,7 @@ func (s *eventStateKeyStatements) InsertEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 - stmt := internal.TxStmt(txn, s.insertEventStateKeyNIDStmt) + stmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt) err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID) return types.EventStateKeyNID(eventStateKeyNID), err } @@ -103,7 +104,7 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 - stmt := internal.TxStmt(txn, s.selectEventStateKeyNIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectEventStateKeyNIDStmt) err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID) return types.EventStateKeyNID(eventStateKeyNID), err } diff --git a/roomserver/storage/postgres/event_types_table.go b/roomserver/storage/postgres/event_types_table.go index f8a7bff1b..037d98fe7 100644 --- a/roomserver/storage/postgres/event_types_table.go +++ b/roomserver/storage/postgres/event_types_table.go @@ -19,9 +19,8 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/internal" - "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index d39e9511a..bdbf5e7cb 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -22,6 +22,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -284,13 +285,13 @@ func (s *eventStatements) UpdateEventState( func (s *eventStatements) SelectEventSentToOutput( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (sentToOutput bool, err error) { - stmt := internal.TxStmt(txn, s.selectEventSentToOutputStmt) + stmt := sqlutil.TxStmt(txn, s.selectEventSentToOutputStmt) err = stmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput) return } func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { - stmt := internal.TxStmt(txn, s.updateEventSentToOutputStmt) + stmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt) _, err := stmt.ExecContext(ctx, int64(eventNID)) return err } @@ -298,7 +299,7 @@ func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql. func (s *eventStatements) SelectEventID( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (eventID string, err error) { - stmt := internal.TxStmt(txn, s.selectEventIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectEventIDStmt) err = stmt.QueryRowContext(ctx, int64(eventNID)).Scan(&eventID) return } @@ -306,7 +307,7 @@ func (s *eventStatements) SelectEventID( func (s *eventStatements) BulkSelectStateAtEventAndReference( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]types.StateAtEventAndReference, error) { - stmt := internal.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt) + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt) rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err diff --git a/roomserver/storage/postgres/invite_table.go b/roomserver/storage/postgres/invite_table.go index 55a5bc7d7..048a094dc 100644 --- a/roomserver/storage/postgres/invite_table.go +++ b/roomserver/storage/postgres/invite_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -101,7 +102,7 @@ func (s *inviteStatements) InsertInviteEvent( targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte, ) (bool, error) { - result, err := internal.TxStmt(txn, s.insertInviteEventStmt).ExecContext( + result, err := sqlutil.TxStmt(txn, s.insertInviteEventStmt).ExecContext( ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, ) if err != nil { @@ -118,7 +119,7 @@ func (s *inviteStatements) UpdateInviteRetired( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) ([]string, error) { - stmt := internal.TxStmt(txn, s.updateInviteRetiredStmt) + stmt := sqlutil.TxStmt(txn, s.updateInviteRetiredStmt) rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) if err != nil { return nil, err diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index c635ce28a..13cef638f 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -133,7 +134,7 @@ func (s *membershipStatements) InsertMembership( txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool, ) error { - stmt := internal.TxStmt(txn, s.insertMembershipStmt) + stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt) _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget) return err } @@ -142,7 +143,7 @@ func (s *membershipStatements) SelectMembershipForUpdate( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (membership tables.MembershipState, err error) { - err = internal.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext( + err = sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership) return @@ -216,7 +217,7 @@ func (s *membershipStatements) UpdateMembership( senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, ) error { - _, err := internal.TxStmt(txn, s.updateMembershipStmt).ExecContext( + _, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext( ctx, roomNID, targetUserNID, senderUserNID, membership, eventNID, ) return err diff --git a/roomserver/storage/postgres/previous_events_table.go b/roomserver/storage/postgres/previous_events_table.go index ff0306967..1a4ba6732 100644 --- a/roomserver/storage/postgres/previous_events_table.go +++ b/roomserver/storage/postgres/previous_events_table.go @@ -19,7 +19,7 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -85,7 +85,7 @@ func (s *previousEventStatements) InsertPreviousEvent( previousEventReferenceSHA256 []byte, eventNID types.EventNID, ) error { - stmt := internal.TxStmt(txn, s.insertPreviousEventStmt) + stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) _, err := stmt.ExecContext( ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), ) @@ -98,6 +98,6 @@ func (s *previousEventStatements) SelectPreviousEventExists( ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, ) error { var ok int64 - stmt := internal.TxStmt(txn, s.selectPreviousEventExistsStmt) + stmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt) return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok) } diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index 5a353ed1c..8e00cfdb8 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -21,7 +21,7 @@ import ( "errors" "github.com/lib/pq" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -106,7 +106,7 @@ func (s *roomStatements) InsertRoomNID( roomID string, roomVersion gomatrixserverlib.RoomVersion, ) (types.RoomNID, error) { var roomNID int64 - stmt := internal.TxStmt(txn, s.insertRoomNIDStmt) + stmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt) err := stmt.QueryRowContext(ctx, roomID, roomVersion).Scan(&roomNID) return types.RoomNID(roomNID), err } @@ -115,7 +115,7 @@ func (s *roomStatements) SelectRoomNID( ctx context.Context, txn *sql.Tx, roomID string, ) (types.RoomNID, error) { var roomNID int64 - stmt := internal.TxStmt(txn, s.selectRoomNIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectRoomNIDStmt) err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) return types.RoomNID(roomNID), err } @@ -143,7 +143,7 @@ func (s *roomStatements) SelectLatestEventsNIDsForUpdate( var nids pq.Int64Array var lastEventSentNID int64 var stateSnapshotNID int64 - stmt := internal.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt) + stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt) err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID) if err != nil { return nil, 0, 0, err @@ -163,7 +163,7 @@ func (s *roomStatements) UpdateLatestEventNIDs( lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID, ) error { - stmt := internal.TxStmt(txn, s.updateLatestEventNIDsStmt) + stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt) _, err := stmt.ExecContext( ctx, roomNID, @@ -178,7 +178,7 @@ func (s *roomStatements) SelectRoomVersionForRoomID( ctx context.Context, txn *sql.Tx, roomID string, ) (gomatrixserverlib.RoomVersion, error) { var roomVersion gomatrixserverlib.RoomVersion - stmt := internal.TxStmt(txn, s.selectRoomVersionForRoomIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectRoomVersionForRoomIDStmt) err := stmt.QueryRowContext(ctx, roomID).Scan(&roomVersion) if err == sql.ErrNoRows { return roomVersion, errors.New("room not found") diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go index dbe21d989..d618686f7 100644 --- a/roomserver/storage/postgres/state_block_table.go +++ b/roomserver/storage/postgres/state_block_table.go @@ -21,9 +21,8 @@ import ( "fmt" "sort" - "github.com/matrix-org/dendrite/internal" - "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 971f2b9e6..d76ee0a92 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -18,7 +18,6 @@ package postgres import ( "database/sql" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" // Import the postgres database driver. @@ -33,7 +32,7 @@ type Database struct { // Open a postgres database. // nolint: gocyclo -func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, error) { +func Open(dataSourceName string, dbProperties sqlutil.DbProperties) (*Database, error) { var d Database var db *sql.DB var err error diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index bb5b51562..2751cc557 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -5,7 +5,7 @@ import ( "database/sql" "encoding/json" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -68,7 +68,7 @@ func (d *Database) AddState( stateBlockNIDs []types.StateBlockNID, state []types.StateEntry, ) (stateNID types.StateSnapshotNID, err error) { - err = internal.WithTransaction(d.DB, func(txn *sql.Tx) error { + err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { if len(state) > 0 { var stateBlockNID types.StateBlockNID stateBlockNID, err = d.StateBlockTable.BulkInsertStateData(ctx, txn, state) @@ -158,7 +158,7 @@ func (d *Database) RoomNIDExcludingStubs(ctx context.Context, roomID string) (ro func (d *Database) LatestEventIDs( ctx context.Context, roomNID types.RoomNID, ) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) { - err = internal.WithTransaction(d.DB, func(txn *sql.Tx) error { + err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { var eventNIDs []types.EventNID eventNIDs, currentStateSnapshotNID, err = d.RoomsTable.SelectLatestEventNIDs(ctx, txn, roomNID) if err != nil { @@ -337,7 +337,7 @@ func (d *Database) StoreEvent( err error ) - err = internal.WithTransaction(d.DB, func(txn *sql.Tx) error { + err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { if txnAndSessionID != nil { if err = d.TransactionsTable.InsertTransaction( ctx, txn, txnAndSessionID.TransactionID, diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index 7ff4c1d4d..da0c448dc 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -21,6 +21,7 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -69,7 +70,7 @@ func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) { func (s *eventJSONStatements) InsertEventJSON( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, ) error { - _, err := internal.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON) + _, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON) return err } @@ -80,7 +81,7 @@ func (s *eventJSONStatements) BulkSelectEventJSON( for k, v := range eventNIDs { iEventNIDs[k] = v } - selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", internal.QueryVariadic(len(iEventNIDs)), 1) + selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) rows, err := s.db.QueryContext(ctx, selectOrig, iEventNIDs...) if err != nil { diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index 2d2c04d26..cbea8428c 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -21,6 +21,7 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -90,7 +91,7 @@ func (s *eventStateKeyStatements) InsertEventStateKeyNID( var eventStateKeyNID int64 var err error var res sql.Result - insertStmt := internal.TxStmt(txn, s.insertEventStateKeyNIDStmt) + insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt) if res, err = insertStmt.ExecContext(ctx, eventStateKey); err == nil { eventStateKeyNID, err = res.LastInsertId() } @@ -101,7 +102,7 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 - stmt := internal.TxStmt(txn, s.selectEventStateKeyNIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectEventStateKeyNIDStmt) err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID) return types.EventStateKeyNID(eventStateKeyNID), err } @@ -113,7 +114,7 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( for k, v := range eventStateKeys { iEventStateKeys[k] = v } - selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", internal.QueryVariadic(len(eventStateKeys)), 1) + selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeys)), 1) rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeys...) if err != nil { @@ -139,7 +140,7 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKey( for k, v := range eventStateKeyNIDs { iEventStateKeyNIDs[k] = v } - selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", internal.QueryVariadic(len(eventStateKeyNIDs)), 1) + selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeyNIDs)), 1) rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...) if err != nil { diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index 6d229bbd7..c9a461f99 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -21,6 +21,7 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -104,8 +105,8 @@ func (s *eventTypeStatements) InsertEventTypeNID( ) (types.EventTypeNID, error) { var eventTypeNID int64 var err error - insertStmt := internal.TxStmt(tx, s.insertEventTypeNIDStmt) - resultStmt := internal.TxStmt(tx, s.insertEventTypeNIDResultStmt) + insertStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDStmt) + resultStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDResultStmt) if _, err = insertStmt.ExecContext(ctx, eventType); err == nil { err = resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID) } @@ -116,7 +117,7 @@ func (s *eventTypeStatements) SelectEventTypeNID( ctx context.Context, tx *sql.Tx, eventType string, ) (types.EventTypeNID, error) { var eventTypeNID int64 - selectStmt := internal.TxStmt(tx, s.selectEventTypeNIDStmt) + selectStmt := sqlutil.TxStmt(tx, s.selectEventTypeNIDStmt) err := selectStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID) return types.EventTypeNID(eventTypeNID), err } @@ -129,7 +130,7 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID( for k, v := range eventTypes { iEventTypes[k] = v } - selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", internal.QueryVariadic(len(iEventTypes)), 1) + selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventTypes)), 1) selectPrep, err := s.db.Prepare(selectOrig) if err != nil { return nil, err diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 491399ce7..d66db4694 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -23,6 +23,7 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -150,7 +151,7 @@ func (s *eventStatements) InsertEvent( depth int64, ) (types.EventNID, types.StateSnapshotNID, error) { // attempt to insert: the last_row_id is the event NID - insertStmt := internal.TxStmt(txn, s.insertEventStmt) + insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) result, err := insertStmt.ExecContext( ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, @@ -171,7 +172,7 @@ func (s *eventStatements) SelectEvent( ) (types.EventNID, types.StateSnapshotNID, error) { var eventNID int64 var stateNID int64 - selectStmt := internal.TxStmt(txn, s.selectEventStmt) + selectStmt := sqlutil.TxStmt(txn, s.selectEventStmt) err := selectStmt.QueryRowContext(ctx, eventID).Scan(&eventNID, &stateNID) return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } @@ -186,7 +187,7 @@ func (s *eventStatements) BulkSelectStateEventByID( for k, v := range eventIDs { iEventIDs[k] = v } - selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", internal.QueryVariadic(len(iEventIDs)), 1) + selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err @@ -238,7 +239,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID( for k, v := range eventIDs { iEventIDs[k] = v } - selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", internal.QueryVariadic(len(iEventIDs)), 1) + selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err @@ -285,7 +286,7 @@ func (s *eventStatements) UpdateEventState( func (s *eventStatements) SelectEventSentToOutput( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (sentToOutput bool, err error) { - selectStmt := internal.TxStmt(txn, s.selectEventSentToOutputStmt) + selectStmt := sqlutil.TxStmt(txn, s.selectEventSentToOutputStmt) err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput) //err = s.selectEventSentToOutputStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput) if err != nil { @@ -294,7 +295,7 @@ func (s *eventStatements) SelectEventSentToOutput( } func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { - updateStmt := internal.TxStmt(txn, s.updateEventSentToOutputStmt) + updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt) _, err := updateStmt.ExecContext(ctx, int64(eventNID)) //_, err := s.updateEventSentToOutputStmt.ExecContext(ctx, int64(eventNID)) return err @@ -303,7 +304,7 @@ func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql. func (s *eventStatements) SelectEventID( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (eventID string, err error) { - selectStmt := internal.TxStmt(txn, s.selectEventIDStmt) + selectStmt := sqlutil.TxStmt(txn, s.selectEventIDStmt) err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&eventID) return } @@ -316,7 +317,7 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( for k, v := range eventNIDs { iEventNIDs[k] = v } - selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", internal.QueryVariadic(len(iEventNIDs)), 1) + selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) ////////////// rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...) @@ -362,14 +363,14 @@ func (s *eventStatements) BulkSelectEventReference( for k, v := range eventNIDs { iEventNIDs[k] = v } - selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", internal.QueryVariadic(len(iEventNIDs)), 1) + selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) selectPrep, err := txn.Prepare(selectOrig) if err != nil { return nil, err } /////////////// - selectStmt := internal.TxStmt(txn, selectPrep) + selectStmt := sqlutil.TxStmt(txn, selectPrep) rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) if err != nil { return nil, err @@ -396,7 +397,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ for k, v := range eventNIDs { iEventNIDs[k] = v } - selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", internal.QueryVariadic(len(iEventNIDs)), 1) + selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err @@ -432,7 +433,7 @@ func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []str for k, v := range eventIDs { iEventIDs[k] = v } - selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", internal.QueryVariadic(len(iEventIDs)), 1) + selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err @@ -461,7 +462,7 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, for i, v := range eventNIDs { iEventIDs[i] = v } - sqlStr := strings.Replace(selectMaxEventDepthSQL, "($1)", internal.QueryVariadic(len(iEventIDs)), 1) + sqlStr := strings.Replace(selectMaxEventDepthSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) err := txn.QueryRowContext(ctx, sqlStr, iEventIDs...).Scan(&result) if err != nil { return 0, err diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go index 7dcc2dc0b..21745d1b0 100644 --- a/roomserver/storage/sqlite3/invite_table.go +++ b/roomserver/storage/sqlite3/invite_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -89,7 +90,7 @@ func (s *inviteStatements) InsertInviteEvent( targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte, ) (bool, error) { - stmt := internal.TxStmt(txn, s.insertInviteEventStmt) + stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt) result, err := stmt.ExecContext( ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, ) @@ -108,7 +109,7 @@ func (s *inviteStatements) UpdateInviteRetired( txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventIDs []string, err error) { // gather all the event IDs we will retire - stmt := internal.TxStmt(txn, s.selectInvitesAboutToRetireStmt) + stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt) rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) if err != nil { return nil, err @@ -123,7 +124,7 @@ func (s *inviteStatements) UpdateInviteRetired( } // now retire the invites - stmt = internal.TxStmt(txn, s.updateInviteRetiredStmt) + stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt) _, err = stmt.ExecContext(ctx, roomNID, targetUserNID) return } diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 5f3076c5a..6f0d763e7 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -109,7 +110,7 @@ func (s *membershipStatements) InsertMembership( roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool, ) error { - stmt := internal.TxStmt(txn, s.insertMembershipStmt) + stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt) _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget) return err } @@ -118,7 +119,7 @@ func (s *membershipStatements) SelectMembershipForUpdate( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (membership tables.MembershipState, err error) { - stmt := internal.TxStmt(txn, s.selectMembershipForUpdateStmt) + stmt := sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt) err = stmt.QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership) @@ -193,7 +194,7 @@ func (s *membershipStatements) UpdateMembership( senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, ) error { - stmt := internal.TxStmt(txn, s.updateMembershipStmt) + stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt) _, err := stmt.ExecContext( ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID, ) diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go index 2b187bba3..549aecfb7 100644 --- a/roomserver/storage/sqlite3/previous_events_table.go +++ b/roomserver/storage/sqlite3/previous_events_table.go @@ -19,7 +19,7 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -77,7 +77,7 @@ func (s *previousEventStatements) InsertPreviousEvent( previousEventReferenceSHA256 []byte, eventNID types.EventNID, ) error { - stmt := internal.TxStmt(txn, s.insertPreviousEventStmt) + stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) _, err := stmt.ExecContext( ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), ) @@ -90,6 +90,6 @@ func (s *previousEventStatements) SelectPreviousEventExists( ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, ) error { var ok int64 - stmt := internal.TxStmt(txn, s.selectPreviousEventExistsStmt) + stmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt) return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok) } diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index d22fceb3b..ab695c5d2 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -21,7 +21,7 @@ import ( "encoding/json" "errors" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -95,7 +95,7 @@ func (s *roomStatements) InsertRoomNID( roomID string, roomVersion gomatrixserverlib.RoomVersion, ) (types.RoomNID, error) { var err error - insertStmt := internal.TxStmt(txn, s.insertRoomNIDStmt) + insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt) if _, err = insertStmt.ExecContext(ctx, roomID, roomVersion); err == nil { return s.SelectRoomNID(ctx, txn, roomID) } else { @@ -107,7 +107,7 @@ func (s *roomStatements) SelectRoomNID( ctx context.Context, txn *sql.Tx, roomID string, ) (types.RoomNID, error) { var roomNID int64 - stmt := internal.TxStmt(txn, s.selectRoomNIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectRoomNIDStmt) err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) return types.RoomNID(roomNID), err } @@ -118,7 +118,7 @@ func (s *roomStatements) SelectLatestEventNIDs( var eventNIDs []types.EventNID var nidsJSON string var stateSnapshotNID int64 - stmt := internal.TxStmt(txn, s.selectLatestEventNIDsStmt) + stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsStmt) err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nidsJSON, &stateSnapshotNID) if err != nil { return nil, 0, err @@ -136,7 +136,7 @@ func (s *roomStatements) SelectLatestEventsNIDsForUpdate( var nidsJSON string var lastEventSentNID int64 var stateSnapshotNID int64 - stmt := internal.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt) + stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt) err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nidsJSON, &lastEventSentNID, &stateSnapshotNID) if err != nil { return nil, 0, 0, err @@ -155,7 +155,7 @@ func (s *roomStatements) UpdateLatestEventNIDs( lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID, ) error { - stmt := internal.TxStmt(txn, s.updateLatestEventNIDsStmt) + stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt) _, err := stmt.ExecContext( ctx, eventNIDsAsArray(eventNIDs), @@ -170,7 +170,7 @@ func (s *roomStatements) SelectRoomVersionForRoomID( ctx context.Context, txn *sql.Tx, roomID string, ) (gomatrixserverlib.RoomVersion, error) { var roomVersion gomatrixserverlib.RoomVersion - stmt := internal.TxStmt(txn, s.selectRoomVersionForRoomIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectRoomVersionForRoomIDStmt) err := stmt.QueryRowContext(ctx, roomID).Scan(&roomVersion) if err == sql.ErrNoRows { return roomVersion, errors.New("room not found") diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 399fdbf63..c058c783a 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -23,6 +23,7 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -130,7 +131,7 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries( for k, v := range stateBlockNIDs { nids[k] = v } - selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", internal.QueryVariadic(len(nids)), 1) + selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err @@ -186,9 +187,9 @@ func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries( sort.Sort(tuples) eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() - sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", internal.QueryVariadic(len(stateBlockNIDs)), 1) - sqlStatement = strings.Replace(sqlStatement, "($2)", internal.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1) - sqlStatement = strings.Replace(sqlStatement, "($3)", internal.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1) + sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1) + sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1) + sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1) var params []interface{} for _, val := range stateBlockNIDs { diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 5de3f639a..d077b6171 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -23,6 +23,7 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -92,7 +93,7 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( for k, v := range stateNIDs { nids[k] = v } - selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", internal.QueryVariadic(len(nids)), 1) + selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 5596b5e3c..8e9352192 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -20,8 +20,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" - - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -52,7 +50,7 @@ func Open(dataSourceName string) (*Database, error) { if err != nil { return nil, err } - if d.db, err = sqlutil.Open(internal.SQLiteDriverName(), cs, nil); err != nil { + if d.db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil); err != nil { return nil, err } //d.db.Exec("PRAGMA journal_mode=WAL;") diff --git a/roomserver/storage/sqlite3/transactions_table.go b/roomserver/storage/sqlite3/transactions_table.go index 7b489ac6e..1e8de1ca8 100644 --- a/roomserver/storage/sqlite3/transactions_table.go +++ b/roomserver/storage/sqlite3/transactions_table.go @@ -19,7 +19,7 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -68,7 +68,7 @@ func (s *transactionStatements) InsertTransaction( userID string, eventID string, ) (err error) { - stmt := internal.TxStmt(txn, s.insertTransactionStmt) + stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt) _, err = stmt.ExecContext( ctx, transactionID, sessionID, userID, eventID, ) diff --git a/roomserver/storage/storage.go b/roomserver/storage/storage.go index 071e2cf1d..d7367e4c7 100644 --- a/roomserver/storage/storage.go +++ b/roomserver/storage/storage.go @@ -19,13 +19,13 @@ package storage import ( "net/url" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/postgres" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" ) // Open opens a database connection. -func Open(dataSourceName string, dbProperties internal.DbProperties) (Database, error) { +func Open(dataSourceName string, dbProperties sqlutil.DbProperties) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { return postgres.Open(dataSourceName, dbProperties) diff --git a/roomserver/storage/storage_wasm.go b/roomserver/storage/storage_wasm.go index ef30ca593..78405b20e 100644 --- a/roomserver/storage/storage_wasm.go +++ b/roomserver/storage/storage_wasm.go @@ -18,14 +18,14 @@ import ( "fmt" "net/url" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" ) // NewPublicRoomsServerDatabase opens a database connection. func Open( dataSourceName string, - dbProperties internal.DbProperties, // nolint:unparam + dbProperties sqlutil.DbProperties, // nolint:unparam ) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { diff --git a/roomserver/types/types.go b/roomserver/types/types.go index 87a4f935c..241e1e15d 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -16,7 +16,7 @@ package types import ( - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" ) @@ -174,7 +174,7 @@ type RoomRecentEventsUpdater interface { // It will share the same transaction as this updater. MembershipUpdater(targetUserNID EventStateKeyNID, isTargetLocalUser bool) (MembershipUpdater, error) // Implements Transaction so it can be committed or rolledback - internal.Transaction + sqlutil.Transaction } // A MembershipUpdater is used to update the membership of a user in a room. @@ -199,7 +199,7 @@ type MembershipUpdater interface { // Returns a list of invite event IDs that this state change retired. SetToLeave(senderUserID string, eventID string) (inviteEventIDs []string, err error) // Implements Transaction so it can be committed or rolledback. - internal.Transaction + sqlutil.Transaction } // A MissingEventError is an error that happened because the roomserver was diff --git a/serverkeyapi/internal/api.go b/serverkeyapi/internal/api.go index 92d6a70b2..02028c60e 100644 --- a/serverkeyapi/internal/api.go +++ b/serverkeyapi/internal/api.go @@ -2,16 +2,23 @@ package internal import ( "context" + "crypto/ed25519" "fmt" "time" "github.com/matrix-org/dendrite/serverkeyapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" ) type ServerKeyAPI struct { api.ServerKeyInternalAPI + ServerName gomatrixserverlib.ServerName + ServerPublicKey ed25519.PublicKey + ServerKeyID gomatrixserverlib.KeyID + ServerKeyValidity time.Duration + OurKeyRing gomatrixserverlib.KeyRing FedClient *gomatrixserverlib.FederationClient } @@ -22,7 +29,7 @@ func (s *ServerKeyAPI) KeyRing() *gomatrixserverlib.KeyRing { // and keeping the cache up-to-date. return &gomatrixserverlib.KeyRing{ KeyDatabase: s, - KeyFetchers: []gomatrixserverlib.KeyFetcher{s}, + KeyFetchers: []gomatrixserverlib.KeyFetcher{}, } } @@ -33,6 +40,7 @@ func (s *ServerKeyAPI) StoreKeys( // Run in a background context - we don't want to stop this work just // because the caller gives up waiting. ctx := context.Background() + // Store any keys that we were given in our database. return s.OurKeyRing.KeyDatabase.StoreKeys(ctx, results) } @@ -44,45 +52,57 @@ func (s *ServerKeyAPI) FetchKeys( // Run in a background context - we don't want to stop this work just // because the caller gives up waiting. ctx := context.Background() + now := gomatrixserverlib.AsTimestamp(time.Now()) results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} - // First consult our local database and see if we have the requested + origRequests := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{} + for k, v := range requests { + origRequests[k] = v + } + + // First, check if any of these key checks are for our own keys. If + // they are then we will satisfy them directly. + s.handleLocalKeys(ctx, requests, results) + + // Then consult our local database and see if we have the requested // keys. These might come from a cache, depending on the database // implementation used. - now := gomatrixserverlib.AsTimestamp(time.Now()) - if dbResults, err := s.OurKeyRing.KeyDatabase.FetchKeys(ctx, requests); err == nil { - // We successfully got some keys. Add them to the results and - // remove them from the request list. - for req, res := range dbResults { - if now > res.ValidUntilTS && res.ExpiredTS == gomatrixserverlib.PublicKeyNotExpired { - continue - } - results[req] = res - delete(requests, req) - } + if err := s.handleDatabaseKeys(ctx, now, requests, results); err != nil { + return nil, err } + // For any key requests that we still have outstanding, next try to // fetch them directly. We'll go through each of the key fetchers to - // ask for the remaining keys. + // ask for the remaining keys for _, fetcher := range s.OurKeyRing.KeyFetchers { + // If there are no more keys to look up then stop. if len(requests) == 0 { break } - if fetcherResults, err := fetcher.FetchKeys(ctx, requests); err == nil { - // We successfully got some keys. Add them to the results and - // remove them from the request list. - for req, res := range fetcherResults { - results[req] = res - delete(requests, req) - } - if err = s.OurKeyRing.KeyDatabase.StoreKeys(ctx, fetcherResults); err != nil { - return nil, fmt.Errorf("server key API failed to store retrieved keys: %w", err) - } + + // Ask the fetcher to look up our keys. + if err := s.handleFetcherKeys(ctx, now, fetcher, requests, results); err != nil { + logrus.WithError(err).WithFields(logrus.Fields{ + "fetcher_name": fetcher.FetcherName(), + }).Errorf("Failed to retrieve %d key(s)", len(requests)) + continue } } - // If we failed to fetch any keys then we should report an error. - if len(requests) > 0 { - return results, fmt.Errorf("server key API failed to fetch %d keys", len(requests)) + + // Check that we've actually satisfied all of the key requests that we + // were given. We should report an error if we didn't. + for req := range origRequests { + if _, ok := results[req]; !ok { + // The results don't contain anything for this specific request, so + // we've failed to satisfy it from local keys, database keys or from + // all of the fetchers. Report an error. + logrus.Warnf("Failed to retrieve key %q for server %q", req.KeyID, req.ServerName) + return results, fmt.Errorf( + "server key API failed to satisfy key request for server %q key ID %q", + req.ServerName, req.KeyID, + ) + } } + // Return the keys. return results, nil } @@ -90,3 +110,141 @@ func (s *ServerKeyAPI) FetchKeys( func (s *ServerKeyAPI) FetcherName() string { return fmt.Sprintf("ServerKeyAPI (wrapping %q)", s.OurKeyRing.KeyDatabase.FetcherName()) } + +// handleLocalKeys handles cases where the key request contains +// a request for our own server keys. +func (s *ServerKeyAPI) handleLocalKeys( + _ context.Context, + requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, +) { + for req := range requests { + if req.ServerName == s.ServerName { + // We found a key request that is supposed to be for our own + // keys. Remove it from the request list so we don't hit the + // database or the fetchers for it. + delete(requests, req) + + // Insert our own key into the response. + results[req] = gomatrixserverlib.PublicKeyLookupResult{ + VerifyKey: gomatrixserverlib.VerifyKey{ + Key: gomatrixserverlib.Base64Bytes(s.ServerPublicKey), + }, + ExpiredTS: gomatrixserverlib.PublicKeyNotExpired, + ValidUntilTS: gomatrixserverlib.AsTimestamp(time.Now().Add(s.ServerKeyValidity)), + } + } + } +} + +// handleDatabaseKeys handles cases where the key requests can be +// satisfied from our local database/cache. +func (s *ServerKeyAPI) handleDatabaseKeys( + ctx context.Context, + now gomatrixserverlib.Timestamp, + requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, +) error { + // Ask the database/cache for the keys. + dbResults, err := s.OurKeyRing.KeyDatabase.FetchKeys(ctx, requests) + if err != nil { + return err + } + + // We successfully got some keys. Add them to the results. + for req, res := range dbResults { + // The key we've retrieved from the database/cache might + // have passed its validity period, but right now, it's + // the best thing we've got, and it might be sufficient to + // verify a past event. + results[req] = res + + // If the key is valid right now then we can also remove it + // from the request list as we don't need to fetch it again + // in that case. If the key isn't valid right now, then by + // leaving it in the 'requests' map, we'll try to update the + // key using the fetchers in handleFetcherKeys. + if res.WasValidAt(now, true) { + delete(requests, req) + } + } + return nil +} + +// handleFetcherKeys handles cases where a fetcher can satisfy +// the remaining requests. +func (s *ServerKeyAPI) handleFetcherKeys( + ctx context.Context, + now gomatrixserverlib.Timestamp, + fetcher gomatrixserverlib.KeyFetcher, + requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, +) error { + logrus.WithFields(logrus.Fields{ + "fetcher_name": fetcher.FetcherName(), + }).Infof("Fetching %d key(s)", len(requests)) + + // Create a context that limits our requests to 30 seconds. + fetcherCtx, fetcherCancel := context.WithTimeout(ctx, time.Second*30) + defer fetcherCancel() + + // Try to fetch the keys. + fetcherResults, err := fetcher.FetchKeys(fetcherCtx, requests) + if err != nil { + return err + } + + // Build a map of the results that we want to commit to the + // database. We do this in a separate map because otherwise we + // might end up trying to rewrite database entries. + storeResults := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} + + // Now let's look at the results that we got from this fetcher. + for req, res := range fetcherResults { + if prev, ok := results[req]; ok { + // We've already got a previous entry for this request + // so let's see if the newly retrieved one contains a more + // up-to-date validity period. + if res.ValidUntilTS > prev.ValidUntilTS { + // This key is newer than the one we had so let's store + // it in the database. + if req.ServerName != s.ServerName { + storeResults[req] = res + } + } + } else { + // We didn't already have a previous entry for this request + // so store it in the database anyway for now. + if req.ServerName != s.ServerName { + storeResults[req] = res + } + } + + // Update the results map with this new result. If nothing + // else, we can try verifying against this key. + results[req] = res + + // If the key is valid right now then we can remove it from the + // request list as we won't need to re-fetch it. + if res.WasValidAt(now, true) { + delete(requests, req) + } + } + + // Store the keys from our store map. + if err = s.OurKeyRing.KeyDatabase.StoreKeys(ctx, storeResults); err != nil { + logrus.WithError(err).WithFields(logrus.Fields{ + "fetcher_name": fetcher.FetcherName(), + "database_name": s.OurKeyRing.KeyDatabase.FetcherName(), + }).Errorf("Failed to store keys in the database") + return fmt.Errorf("server key API failed to store retrieved keys: %w", err) + } + + if len(storeResults) > 0 { + logrus.WithFields(logrus.Fields{ + "fetcher_name": fetcher.FetcherName(), + }).Infof("Updated %d of %d key(s) in database", len(storeResults), len(results)) + } + + return nil +} diff --git a/serverkeyapi/inthttp/client.go b/serverkeyapi/inthttp/client.go index f22b0e310..39ab8c6c5 100644 --- a/serverkeyapi/inthttp/client.go +++ b/serverkeyapi/inthttp/client.go @@ -4,10 +4,9 @@ import ( "context" "errors" "net/http" - "time" "github.com/matrix-org/dendrite/internal/caching" - internalHTTP "github.com/matrix-org/dendrite/internal/http" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/serverkeyapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/opentracing/opentracing-go" @@ -50,7 +49,7 @@ func (s *httpServerKeyInternalAPI) KeyRing() *gomatrixserverlib.KeyRing { // the other end of the API. return &gomatrixserverlib.KeyRing{ KeyDatabase: s, - KeyFetchers: []gomatrixserverlib.KeyFetcher{s}, + KeyFetchers: []gomatrixserverlib.KeyFetcher{}, } } @@ -90,12 +89,8 @@ func (s *httpServerKeyInternalAPI) FetchKeys( response := api.QueryPublicKeysResponse{ Results: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult), } - now := gomatrixserverlib.AsTimestamp(time.Now()) for req, ts := range requests { - if res, ok := s.cache.GetServerKey(req); ok { - if now > res.ValidUntilTS && res.ExpiredTS == gomatrixserverlib.PublicKeyNotExpired { - continue - } + if res, ok := s.cache.GetServerKey(req, ts); ok { result[req] = res continue } @@ -121,7 +116,7 @@ func (h *httpServerKeyInternalAPI) InputPublicKeys( defer span.Finish() apiURL := h.serverKeyAPIURL + ServerKeyInputPublicKeyPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } func (h *httpServerKeyInternalAPI) QueryPublicKeys( @@ -133,5 +128,5 @@ func (h *httpServerKeyInternalAPI) QueryPublicKeys( defer span.Finish() apiURL := h.serverKeyAPIURL + ServerKeyQueryPublicKeyPath - return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } diff --git a/serverkeyapi/inthttp/server.go b/serverkeyapi/inthttp/server.go index 9efe7d9df..cd4748392 100644 --- a/serverkeyapi/inthttp/server.go +++ b/serverkeyapi/inthttp/server.go @@ -5,16 +5,15 @@ import ( "net/http" "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/serverkeyapi/api" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) func AddRoutes(s api.ServerKeyInternalAPI, internalAPIMux *mux.Router, cache caching.ServerKeyCache) { internalAPIMux.Handle(ServerKeyQueryPublicKeyPath, - internal.MakeInternalAPI("queryPublicKeys", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("queryPublicKeys", func(req *http.Request) util.JSONResponse { request := api.QueryPublicKeysRequest{} response := api.QueryPublicKeysResponse{} if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -29,18 +28,13 @@ func AddRoutes(s api.ServerKeyInternalAPI, internalAPIMux *mux.Router, cache cac }), ) internalAPIMux.Handle(ServerKeyInputPublicKeyPath, - internal.MakeInternalAPI("inputPublicKeys", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("inputPublicKeys", func(req *http.Request) util.JSONResponse { request := api.InputPublicKeysRequest{} response := api.InputPublicKeysResponse{} if err := json.NewDecoder(req.Body).Decode(&request); err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) } - store := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) - for req, res := range request.Keys { - store[req] = res - cache.StoreServerKey(req, res) - } - if err := s.StoreKeys(req.Context(), store); err != nil { + if err := s.StoreKeys(req.Context(), request.Keys); err != nil { return util.ErrorResponse(err) } return util.JSONResponse{Code: http.StatusOK, JSON: &response} diff --git a/serverkeyapi/serverkeyapi.go b/serverkeyapi/serverkeyapi.go index 58ca00b73..cddd392ed 100644 --- a/serverkeyapi/serverkeyapi.go +++ b/serverkeyapi/serverkeyapi.go @@ -46,7 +46,11 @@ func NewInternalAPI( } internalAPI := internal.ServerKeyAPI{ - FedClient: fedClient, + ServerName: cfg.Matrix.ServerName, + ServerPublicKey: cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey), + ServerKeyID: cfg.Matrix.KeyID, + ServerKeyValidity: cfg.Matrix.KeyValidityPeriod, + FedClient: fedClient, OurKeyRing: gomatrixserverlib.KeyRing{ KeyFetchers: []gomatrixserverlib.KeyFetcher{ &gomatrixserverlib.DirectKeyFetcher{ diff --git a/serverkeyapi/serverkeyapi_test.go b/serverkeyapi/serverkeyapi_test.go new file mode 100644 index 000000000..3368f5b2a --- /dev/null +++ b/serverkeyapi/serverkeyapi_test.go @@ -0,0 +1,315 @@ +package serverkeyapi + +import ( + "bytes" + "context" + "crypto/ed25519" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "os" + "reflect" + "testing" + "time" + + "github.com/matrix-org/dendrite/federationapi/routing" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/serverkeyapi/api" + "github.com/matrix-org/gomatrixserverlib" +) + +type server struct { + name gomatrixserverlib.ServerName // server name + validity time.Duration // key validity duration from now + config *config.Dendrite // skeleton config, from TestMain + fedclient *gomatrixserverlib.FederationClient // uses MockRoundTripper + cache *caching.Caches // server-specific cache + api api.ServerKeyInternalAPI // server-specific server key API +} + +func (s *server) renew() { + // This updates the validity period to be an hour in the + // future, which is particularly useful in server A and + // server C's cases which have validity either as now or + // in the past. + s.validity = time.Hour + s.config.Matrix.KeyValidityPeriod = s.validity +} + +var ( + serverKeyID = gomatrixserverlib.KeyID("ed25519:auto") + serverA = &server{name: "a.com", validity: time.Duration(0)} // expires now + serverB = &server{name: "b.com", validity: time.Hour} // expires in an hour + serverC = &server{name: "c.com", validity: -time.Hour} // expired an hour ago +) + +var servers = map[string]*server{ + "a.com": serverA, + "b.com": serverB, + "c.com": serverC, +} + +func TestMain(m *testing.M) { + // Set up the server key API for each "server" that we + // will use in our tests. + for _, s := range servers { + // Generate a new key. + _, testPriv, err := ed25519.GenerateKey(nil) + if err != nil { + panic("can't generate identity key: " + err.Error()) + } + + // Create a new cache but don't enable prometheus! + s.cache, err = caching.NewInMemoryLRUCache(false) + if err != nil { + panic("can't create cache: " + err.Error()) + } + + // Draw up just enough Dendrite config for the server key + // API to work. + s.config = &config.Dendrite{} + s.config.SetDefaults() + s.config.Matrix.ServerName = gomatrixserverlib.ServerName(s.name) + s.config.Matrix.PrivateKey = testPriv + s.config.Matrix.KeyID = serverKeyID + s.config.Matrix.KeyValidityPeriod = s.validity + s.config.Database.ServerKey = config.DataSource("file::memory:") + + // Create a transport which redirects federation requests to + // the mock round tripper. Since we're not *really* listening for + // federation requests then this will return the key instead. + transport := &http.Transport{} + transport.RegisterProtocol("matrix", &MockRoundTripper{}) + + // Create the federation client. + s.fedclient = gomatrixserverlib.NewFederationClientWithTransport( + s.config.Matrix.ServerName, serverKeyID, testPriv, transport, + ) + + // Finally, build the server key APIs. + s.api = NewInternalAPI(s.config, s.fedclient, s.cache) + } + + // Now that we have built our server key APIs, start the + // rest of the tests. + os.Exit(m.Run()) +} + +type MockRoundTripper struct{} + +func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err error) { + // Check if the request is looking for keys from a server that + // we know about in the test. The only reason this should go wrong + // is if the test is broken. + s, ok := servers[req.Host] + if !ok { + return nil, fmt.Errorf("server not known: %s", req.Host) + } + + // We're intercepting /matrix/key/v2/server requests here, so check + // that the URL supplied in the request is for that. + if req.URL.Path != "/_matrix/key/v2/server" { + return nil, fmt.Errorf("unexpected request path: %s", req.URL.Path) + } + + // Get the keys and JSON-ify them. + keys := routing.LocalKeys(s.config) + body, err := json.MarshalIndent(keys.JSON, "", " ") + if err != nil { + return nil, err + } + + // And respond. + res = &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewReader(body)), + } + return +} + +func TestServersRequestOwnKeys(t *testing.T) { + // Each server will request its own keys. There's no reason + // for this to fail as each server should know its own keys. + + for name, s := range servers { + req := gomatrixserverlib.PublicKeyLookupRequest{ + ServerName: s.name, + KeyID: serverKeyID, + } + res, err := s.api.FetchKeys( + context.Background(), + map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{ + req: gomatrixserverlib.AsTimestamp(time.Now()), + }, + ) + if err != nil { + t.Fatalf("server could not fetch own key: %s", err) + } + if _, ok := res[req]; !ok { + t.Fatalf("server didn't return its own key in the results") + } + t.Logf("%s's key expires at %s\n", name, res[req].ValidUntilTS.Time()) + } +} + +func TestCachingBehaviour(t *testing.T) { + // Server A will request Server B's key, which has a validity + // period of an hour from now. We should retrieve the key and + // it should make it into the cache automatically. + + req := gomatrixserverlib.PublicKeyLookupRequest{ + ServerName: serverB.name, + KeyID: serverKeyID, + } + ts := gomatrixserverlib.AsTimestamp(time.Now()) + + res, err := serverA.api.FetchKeys( + context.Background(), + map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{ + req: ts, + }, + ) + if err != nil { + t.Fatalf("server A failed to retrieve server B key: %s", err) + } + if len(res) != 1 { + t.Fatalf("server B should have returned one key but instead returned %d keys", len(res)) + } + if _, ok := res[req]; !ok { + t.Fatalf("server B isn't included in the key fetch response") + } + + // At this point, if the previous key request was a success, + // then the cache should now contain the key. Check if that's + // the case - if it isn't then there's something wrong with + // the cache implementation or we failed to get the key. + + cres, ok := serverA.cache.GetServerKey(req, ts) + if !ok { + t.Fatalf("server B key should be in cache but isn't") + } + if !reflect.DeepEqual(cres, res[req]) { + t.Fatalf("the cached result from server B wasn't what server B gave us") + } + + // If we ask the cache for the same key but this time for an event + // that happened in +30 minutes. Since the validity period is for + // another hour, then we should get a response back from the cache. + + _, ok = serverA.cache.GetServerKey( + req, + gomatrixserverlib.AsTimestamp(time.Now().Add(time.Minute*30)), + ) + if !ok { + t.Fatalf("server B key isn't in cache when it should be (+30 minutes)") + } + + // If we ask the cache for the same key but this time for an event + // that happened in +90 minutes then we should expect to get no + // cache result. This is because the cache shouldn't return a result + // that is obviously past the validity of the event. + + _, ok = serverA.cache.GetServerKey( + req, + gomatrixserverlib.AsTimestamp(time.Now().Add(time.Minute*90)), + ) + if ok { + t.Fatalf("server B key is in cache when it shouldn't be (+90 minutes)") + } +} + +func TestRenewalBehaviour(t *testing.T) { + // Server A will request Server C's key but their validity period + // is an hour in the past. We'll retrieve the key as, even though it's + // past its validity, it will be able to verify past events. + + req := gomatrixserverlib.PublicKeyLookupRequest{ + ServerName: serverC.name, + KeyID: serverKeyID, + } + + res, err := serverA.api.FetchKeys( + context.Background(), + map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{ + req: gomatrixserverlib.AsTimestamp(time.Now()), + }, + ) + if err != nil { + t.Fatalf("server A failed to retrieve server C key: %s", err) + } + if len(res) != 1 { + t.Fatalf("server C should have returned one key but instead returned %d keys", len(res)) + } + if _, ok := res[req]; !ok { + t.Fatalf("server C isn't included in the key fetch response") + } + + // If we ask the cache for the server key for an event that happened + // 90 minutes ago then we should get a cache result, as the key hadn't + // passed its validity by that point. The fact that the key is now in + // the cache is, in itself, proof that we successfully retrieved the + // key before. + + oldcached, ok := serverA.cache.GetServerKey( + req, + gomatrixserverlib.AsTimestamp(time.Now().Add(-time.Minute*90)), + ) + if !ok { + t.Fatalf("server C key isn't in cache when it should be (-90 minutes)") + } + + // If we now ask the cache for the same key but this time for an event + // that only happened 30 minutes ago then we shouldn't get a cached + // result, as the event happened after the key validity expired. This + // is really just for sanity checking. + + _, ok = serverA.cache.GetServerKey( + req, + gomatrixserverlib.AsTimestamp(time.Now().Add(-time.Minute*30)), + ) + if ok { + t.Fatalf("server B key is in cache when it shouldn't be (-30 minutes)") + } + + // We're now going to kick server C into renewing its key. Since we're + // happy at this point that the key that we already have is from the past + // then repeating a key fetch should cause us to try and renew the key. + // If so, then the new key will end up in our cache. + + serverC.renew() + + res, err = serverA.api.FetchKeys( + context.Background(), + map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{ + req: gomatrixserverlib.AsTimestamp(time.Now()), + }, + ) + if err != nil { + t.Fatalf("server A failed to retrieve server C key: %s", err) + } + if len(res) != 1 { + t.Fatalf("server C should have returned one key but instead returned %d keys", len(res)) + } + if _, ok = res[req]; !ok { + t.Fatalf("server C isn't included in the key fetch response") + } + + // We're now going to ask the cache what the new key validity is. If + // it is still the same as the previous validity then we've failed to + // retrieve the renewed key. If it's newer then we've successfully got + // the renewed key. + + newcached, ok := serverA.cache.GetServerKey( + req, + gomatrixserverlib.AsTimestamp(time.Now().Add(-time.Minute*30)), + ) + if !ok { + t.Fatalf("server B key isn't in cache when it shouldn't be (post-renewal)") + } + if oldcached.ValidUntilTS >= newcached.ValidUntilTS { + t.Fatalf("the server B key should have been renewed but wasn't") + } + t.Log(res) +} diff --git a/serverkeyapi/storage/cache/keydb.go b/serverkeyapi/storage/cache/keydb.go index b662e4fdb..2063dfc55 100644 --- a/serverkeyapi/storage/cache/keydb.go +++ b/serverkeyapi/storage/cache/keydb.go @@ -39,8 +39,8 @@ func (d *KeyDatabase) FetchKeys( requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) - for req := range requests { - if res, cached := d.cache.GetServerKey(req); cached { + for req, ts := range requests { + if res, cached := d.cache.GetServerKey(req, ts); cached { results[req] = res delete(requests, req) } diff --git a/serverkeyapi/storage/keydb.go b/serverkeyapi/storage/keydb.go index b9389bd63..c28c4de1e 100644 --- a/serverkeyapi/storage/keydb.go +++ b/serverkeyapi/storage/keydb.go @@ -21,7 +21,7 @@ import ( "golang.org/x/crypto/ed25519" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/serverkeyapi/storage/postgres" "github.com/matrix-org/dendrite/serverkeyapi/storage/sqlite3" "github.com/matrix-org/gomatrixserverlib" @@ -30,7 +30,7 @@ import ( // NewDatabase opens a database connection. func NewDatabase( dataSourceName string, - dbProperties internal.DbProperties, + dbProperties sqlutil.DbProperties, serverName gomatrixserverlib.ServerName, serverKey ed25519.PublicKey, serverKeyID gomatrixserverlib.KeyID, diff --git a/serverkeyapi/storage/keydb_wasm.go b/serverkeyapi/storage/keydb_wasm.go index 3d2ca5505..de66a1d63 100644 --- a/serverkeyapi/storage/keydb_wasm.go +++ b/serverkeyapi/storage/keydb_wasm.go @@ -22,7 +22,7 @@ import ( "golang.org/x/crypto/ed25519" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/serverkeyapi/storage/sqlite3" "github.com/matrix-org/gomatrixserverlib" ) @@ -30,7 +30,7 @@ import ( // NewDatabase opens a database connection. func NewDatabase( dataSourceName string, - dbProperties internal.DbProperties, // nolint:unparam + dbProperties sqlutil.DbProperties, // nolint:unparam serverName gomatrixserverlib.ServerName, serverKey ed25519.PublicKey, serverKeyID gomatrixserverlib.KeyID, diff --git a/serverkeyapi/storage/postgres/keydb.go b/serverkeyapi/storage/postgres/keydb.go index 6282e13bb..aaa4409be 100644 --- a/serverkeyapi/storage/postgres/keydb.go +++ b/serverkeyapi/storage/postgres/keydb.go @@ -17,11 +17,9 @@ package postgres import ( "context" - "time" "golang.org/x/crypto/ed25519" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" ) @@ -38,7 +36,7 @@ type Database struct { // Returns an error if there was a problem talking to the database. func NewDatabase( dataSourceName string, - dbProperties internal.DbProperties, + dbProperties sqlutil.DbProperties, serverName gomatrixserverlib.ServerName, serverKey ed25519.PublicKey, serverKeyID gomatrixserverlib.KeyID, @@ -52,28 +50,6 @@ func NewDatabase( if err != nil { return nil, err } - // Store our own keys so that we don't end up making HTTP requests to find our - // own keys - index := gomatrixserverlib.PublicKeyLookupRequest{ - ServerName: serverName, - KeyID: serverKeyID, - } - value := gomatrixserverlib.PublicKeyLookupResult{ - VerifyKey: gomatrixserverlib.VerifyKey{ - Key: gomatrixserverlib.Base64Bytes(serverKey), - }, - ValidUntilTS: gomatrixserverlib.AsTimestamp(time.Now().Add(100 * 365 * 24 * time.Hour)), - ExpiredTS: gomatrixserverlib.PublicKeyNotExpired, - } - err = d.StoreKeys( - context.Background(), - map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{ - index: value, - }, - ) - if err != nil { - return nil, err - } return d, nil } diff --git a/serverkeyapi/storage/postgres/server_key_table.go b/serverkeyapi/storage/postgres/server_key_table.go index b3c26a480..87f1c211c 100644 --- a/serverkeyapi/storage/postgres/server_key_table.go +++ b/serverkeyapi/storage/postgres/server_key_table.go @@ -19,9 +19,8 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/internal" - "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/serverkeyapi/storage/sqlite3/keydb.go b/serverkeyapi/storage/sqlite3/keydb.go index ab423cff6..dc72b79eb 100644 --- a/serverkeyapi/storage/sqlite3/keydb.go +++ b/serverkeyapi/storage/sqlite3/keydb.go @@ -17,11 +17,9 @@ package sqlite3 import ( "context" - "time" "golang.org/x/crypto/ed25519" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" @@ -48,7 +46,7 @@ func NewDatabase( if err != nil { return nil, err } - db, err := sqlutil.Open(internal.SQLiteDriverName(), cs, nil) + db, err := sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil) if err != nil { return nil, err } @@ -57,25 +55,6 @@ func NewDatabase( if err != nil { return nil, err } - // Store our own keys so that we don't end up making HTTP requests to find our - // own keys - index := gomatrixserverlib.PublicKeyLookupRequest{ - ServerName: serverName, - KeyID: serverKeyID, - } - value := gomatrixserverlib.PublicKeyLookupResult{ - VerifyKey: gomatrixserverlib.VerifyKey{ - Key: gomatrixserverlib.Base64Bytes(serverKey), - }, - ValidUntilTS: gomatrixserverlib.AsTimestamp(time.Now().Add(100 * 365 * 24 * time.Hour)), - ExpiredTS: gomatrixserverlib.PublicKeyNotExpired, - } - err = d.StoreKeys( - context.Background(), - map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{ - index: value, - }, - ) if err != nil { return nil, err } diff --git a/serverkeyapi/storage/sqlite3/server_key_table.go b/serverkeyapi/storage/sqlite3/server_key_table.go index ae24a14d4..4f03dccbb 100644 --- a/serverkeyapi/storage/sqlite3/server_key_table.go +++ b/serverkeyapi/storage/sqlite3/server_key_table.go @@ -21,6 +21,7 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" ) @@ -90,7 +91,7 @@ func (s *serverKeyStatements) bulkSelectServerKeys( nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request)) } - query := strings.Replace(bulkSelectServerKeysSQL, "($1)", internal.QueryVariadic(len(nameAndKeyIDs)), 1) + query := strings.Replace(bulkSelectServerKeysSQL, "($1)", sqlutil.QueryVariadic(len(nameAndKeyIDs)), 1) iKeyIDs := make([]interface{}, len(nameAndKeyIDs)) for i, v := range nameAndKeyIDs { diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index 4d43e8111..ad6290e3f 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -21,6 +21,7 @@ import ( "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" "github.com/matrix-org/dendrite/syncapi/types" @@ -67,7 +68,7 @@ func (s *OutputClientDataConsumer) Start() error { // sync stream position may race and be incorrectly calculated. func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error { // Parse out the event JSON - var output internal.AccountData + var output eventutil.AccountData if err := json.Unmarshal(msg.Value, &output); err != nil { // If the message was invalid, log it and move on to the next message in the stream log.WithError(err).Errorf("client API server output log: message parse failure") diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 055f76603..98be5bb73 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -17,7 +17,6 @@ package consumers import ( "context" "encoding/json" - "fmt" "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/internal" @@ -99,23 +98,10 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( ctx context.Context, msg api.OutputNewRoomEvent, ) error { ev := msg.Event - log.WithFields(log.Fields{ - "event_id": ev.EventID(), - "room_id": ev.RoomID(), - "room_version": ev.RoomVersion, - }).Info("received event from roomserver") - addsStateEvents, err := s.lookupStateEvents(msg.AddsStateEventIDs, ev) - if err != nil { - log.WithFields(log.Fields{ - "event": string(ev.JSON()), - log.ErrorKey: err, - "add": msg.AddsStateEventIDs, - "del": msg.RemovesStateEventIDs, - }).Panicf("roomserver output log: state event lookup failure") - } + addsStateEvents := msg.AddsState() - ev, err = s.updateStateEvent(ev) + ev, err := s.updateStateEvent(ev) if err != nil { return err } @@ -185,63 +171,6 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent( return nil } -// lookupStateEvents looks up the state events that are added by a new event. -func (s *OutputRoomEventConsumer) lookupStateEvents( - addsStateEventIDs []string, event gomatrixserverlib.HeaderedEvent, -) ([]gomatrixserverlib.HeaderedEvent, error) { - // Fast path if there aren't any new state events. - if len(addsStateEventIDs) == 0 { - return nil, nil - } - - // Fast path if the only state event added is the event itself. - if len(addsStateEventIDs) == 1 && addsStateEventIDs[0] == event.EventID() { - return []gomatrixserverlib.HeaderedEvent{event}, nil - } - - // Check if this is re-adding a state events that we previously processed - // If we have previously received a state event it may still be in - // our event database. - result, err := s.db.Events(context.TODO(), addsStateEventIDs) - if err != nil { - return nil, err - } - missing := missingEventsFrom(result, addsStateEventIDs) - - // Check if event itself is being added. - for _, eventID := range missing { - if eventID == event.EventID() { - result = append(result, event) - break - } - } - missing = missingEventsFrom(result, addsStateEventIDs) - - if len(missing) == 0 { - return result, nil - } - - // At this point the missing events are neither the event itself nor are - // they present in our local database. Our only option is to fetch them - // from the roomserver using the query API. - eventReq := api.QueryEventsByIDRequest{EventIDs: missing} - var eventResp api.QueryEventsByIDResponse - if err := s.rsAPI.QueryEventsByID(context.TODO(), &eventReq, &eventResp); err != nil { - return nil, err - } - - result = append(result, eventResp.Events...) - missing = missingEventsFrom(result, addsStateEventIDs) - - if len(missing) != 0 { - return nil, fmt.Errorf( - "missing %d state events IDs at event %q", len(missing), event.EventID(), - ) - } - - return result, nil -} - func (s *OutputRoomEventConsumer) updateStateEvent(event gomatrixserverlib.HeaderedEvent) (gomatrixserverlib.HeaderedEvent, error) { var stateKey string if event.StateKey() == nil { @@ -270,17 +199,3 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event gomatrixserverlib.Heade event.Event, err = event.SetUnsigned(prev) return event, err } - -func missingEventsFrom(events []gomatrixserverlib.HeaderedEvent, required []string) []string { - have := map[string]bool{} - for _, event := range events { - have[event.EventID()] = true - } - var missing []string - for _, eventID := range required { - if !have[eventID] { - missing = append(missing, eventID) - } - } - return missing -} diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 8c8976345..15add1b45 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -158,6 +158,7 @@ func OnIncomingMessagesRequest( util.GetLogger(req.Context()).WithError(err).Error("mreq.retrieveEvents failed") return jsonerror.InternalServerError() } + util.GetLogger(req.Context()).WithFields(logrus.Fields{ "from": from.String(), "to": to.String(), @@ -246,6 +247,12 @@ func (r *messagesReq) retrieveEvents() ( // change the way topological positions are defined (as depth isn't the most // reliable way to define it), it would be easier and less troublesome to // only have to change it in one place, i.e. the database. + start, end, err = r.getStartEnd(events) + + return clientEvents, start, end, err +} + +func (r *messagesReq) getStartEnd(events []gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) { start, err = r.db.EventPositionInTopology( r.ctx, events[0].EventID(), ) @@ -253,24 +260,28 @@ func (r *messagesReq) retrieveEvents() ( err = fmt.Errorf("EventPositionInTopology: for start event %s: %w", events[0].EventID(), err) return } - end, err = r.db.EventPositionInTopology( - r.ctx, events[len(events)-1].EventID(), - ) - if err != nil { - err = fmt.Errorf("EventPositionInTopology: for end event %s: %w", events[len(events)-1].EventID(), err) - return + if r.backwardOrdering && events[len(events)-1].Type() == gomatrixserverlib.MRoomCreate { + // We've hit the beginning of the room so there's really nowhere else + // to go. This seems to fix Riot iOS from looping on /messages endlessly. + end = types.NewTopologyToken(0, 0) + } else { + end, err = r.db.EventPositionInTopology( + r.ctx, events[len(events)-1].EventID(), + ) + if err != nil { + err = fmt.Errorf("EventPositionInTopology: for end event %s: %w", events[len(events)-1].EventID(), err) + return + } + if r.backwardOrdering { + // A stream/topological position is a cursor located between two events. + // While they are identified in the code by the event on their right (if + // we consider a left to right chronological order), tokens need to refer + // to them by the event on their left, therefore we need to decrement the + // end position we send in the response if we're going backward. + end.Decrement() + } } - - if r.backwardOrdering { - // A stream/topological position is a cursor located between two events. - // While they are identified in the code by the event on their right (if - // we consider a left to right chronological order), tokens need to refer - // to them by the event on their left, therefore we need to decrement the - // end position we send in the response if we're going backward. - end.Decrement() - } - - return clientEvents, start, end, err + return } // handleEmptyEventsSlice handles the case where the initial request to the @@ -375,15 +386,15 @@ func (e eventsByDepth) Less(i, j int) bool { // Returns an error if there was an issue with retrieving the list of servers in // the room or sending the request. func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]string, limit int) ([]gomatrixserverlib.HeaderedEvent, error) { - var res api.QueryBackfillResponse - err := r.rsAPI.QueryBackfill(context.Background(), &api.QueryBackfillRequest{ + var res api.PerformBackfillResponse + err := r.rsAPI.PerformBackfill(context.Background(), &api.PerformBackfillRequest{ RoomID: roomID, BackwardsExtremities: backwardsExtremities, Limit: limit, ServerName: r.cfg.Matrix.ServerName, }, &res) if err != nil { - return nil, fmt.Errorf("QueryBackfill failed: %w", err) + return nil, fmt.Errorf("PerformBackfill failed: %w", err) } util.GetLogger(r.ctx).WithField("new_events", len(res.Events)).Info("Storing new events from backfill") diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index c876164d7..5744de05a 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -18,14 +18,12 @@ import ( "net/http" "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -39,25 +37,19 @@ const pathPrefixR0 = "/client/r0" // nolint: gocyclo func Setup( publicAPIMux *mux.Router, srp *sync.RequestPool, syncDB storage.Database, - deviceDB devices.Database, federation *gomatrixserverlib.FederationClient, + userAPI userapi.UserInternalAPI, federation *gomatrixserverlib.FederationClient, rsAPI api.RoomserverInternalAPI, cfg *config.Dendrite, ) { r0mux := publicAPIMux.PathPrefix(pathPrefixR0).Subrouter() - authData := auth.Data{ - AccountDB: nil, - DeviceDB: deviceDB, - AppServices: nil, - } - // TODO: Add AS support for all handlers below. - r0mux.Handle("/sync", internal.MakeAuthAPI("sync", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + r0mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return srp.OnIncomingSyncRequest(req, device) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/messages", internal.MakeAuthAPI("room_messages", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + r0mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 566e5d589..7b3bd6785 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -18,11 +18,11 @@ import ( "context" "time" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" ) @@ -57,10 +57,10 @@ type Database interface { // from when the device sent the event via an API that included a transaction // ID. A response object must be provided for IncrementaSync to populate - it // will not create one. - IncrementalSync(ctx context.Context, res *types.Response, device authtypes.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) + IncrementalSync(ctx context.Context, res *types.Response, device userapi.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) // CompleteSync returns a complete /sync API response for the given user. A response object // must be provided for CompleteSync to populate - it will not create one. - CompleteSync(ctx context.Context, res *types.Response, device authtypes.Device, numRecentEventsPerRoom int) (*types.Response, error) + CompleteSync(ctx context.Context, res *types.Response, device userapi.Device, numRecentEventsPerRoom int) (*types.Response, error) // GetAccountDataInRange returns all account data for a given user inserted or // updated between two given positions // Returns a map following the format data[roomID] = []dataTypes @@ -103,7 +103,7 @@ type Database interface { // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and // matches the streamevent.transactionID device then the transaction ID gets // added to the unsigned section of the output event. - StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.HeaderedEvent + StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []gomatrixserverlib.HeaderedEvent // SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) // AddSendToDevice increases the EDU position in the cache and returns the stream position. diff --git a/syncapi/storage/postgres/account_data_table.go b/syncapi/storage/postgres/account_data_table.go index e0b87bd6d..67eb1e863 100644 --- a/syncapi/storage/postgres/account_data_table.go +++ b/syncapi/storage/postgres/account_data_table.go @@ -21,6 +21,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -136,7 +137,7 @@ func (s *accountDataStatements) SelectMaxAccountDataID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 - stmt := internal.TxStmt(txn, s.selectMaxAccountDataIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectMaxAccountDataIDStmt) err = stmt.QueryRowContext(ctx).Scan(&nullableID) if nullableID.Valid { id = nullableID.Int64 diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index cc0fb6b1e..25906edb4 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -22,6 +22,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -165,7 +166,7 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( userID string, membership string, // nolint: unparam ) ([]string, error) { - stmt := internal.TxStmt(txn, s.selectRoomIDsWithMembershipStmt) + stmt := sqlutil.TxStmt(txn, s.selectRoomIDsWithMembershipStmt) rows, err := stmt.QueryContext(ctx, userID, membership) if err != nil { return nil, err @@ -188,7 +189,7 @@ func (s *currentRoomStateStatements) SelectCurrentState( ctx context.Context, txn *sql.Tx, roomID string, stateFilter *gomatrixserverlib.StateFilter, ) ([]gomatrixserverlib.HeaderedEvent, error) { - stmt := internal.TxStmt(txn, s.selectCurrentStateStmt) + stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt) rows, err := stmt.QueryContext(ctx, roomID, pq.StringArray(stateFilter.Senders), pq.StringArray(stateFilter.NotSenders), @@ -208,7 +209,7 @@ func (s *currentRoomStateStatements) SelectCurrentState( func (s *currentRoomStateStatements) DeleteRoomStateByEventID( ctx context.Context, txn *sql.Tx, eventID string, ) error { - stmt := internal.TxStmt(txn, s.deleteRoomStateByEventIDStmt) + stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) _, err := stmt.ExecContext(ctx, eventID) return err } @@ -231,7 +232,7 @@ func (s *currentRoomStateStatements) UpsertRoomState( } // upsert state event - stmt := internal.TxStmt(txn, s.upsertRoomStateStmt) + stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) _, err = stmt.ExecContext( ctx, event.RoomID(), @@ -250,7 +251,7 @@ func (s *currentRoomStateStatements) UpsertRoomState( func (s *currentRoomStateStatements) SelectEventsWithEventIDs( ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StreamEvent, error) { - stmt := internal.TxStmt(txn, s.selectEventsWithEventIDsStmt) + stmt := sqlutil.TxStmt(txn, s.selectEventsWithEventIDsStmt) rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go index caa12d2f8..5031d64e5 100644 --- a/syncapi/storage/postgres/invites_table.go +++ b/syncapi/storage/postgres/invites_table.go @@ -21,6 +21,7 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -119,7 +120,7 @@ func (s *inviteEventsStatements) DeleteInviteEvent( func (s *inviteEventsStatements) SelectInviteEventsInRange( ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range, ) (map[string]gomatrixserverlib.HeaderedEvent, error) { - stmt := internal.TxStmt(txn, s.selectInviteEventsInRangeStmt) + stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt) rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High()) if err != nil { return nil, err @@ -149,7 +150,7 @@ func (s *inviteEventsStatements) SelectMaxInviteID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 - stmt := internal.TxStmt(txn, s.selectMaxInviteIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectMaxInviteIDStmt) err = stmt.QueryRowContext(ctx).Scan(&nullableID) if nullableID.Valid { id = nullableID.Int64 diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index ae6dfb381..f01b2eabd 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -21,12 +21,13 @@ import ( "encoding/json" "sort" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/lib/pq" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -158,7 +159,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange( ctx context.Context, txn *sql.Tx, r types.Range, stateFilter *gomatrixserverlib.StateFilter, ) (map[string]map[string]bool, map[string]types.StreamEvent, error) { - stmt := internal.TxStmt(txn, s.selectStateInRangeStmt) + stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt) rows, err := stmt.QueryContext( ctx, r.Low(), r.High(), @@ -239,7 +240,7 @@ func (s *outputRoomEventsStatements) SelectMaxEventID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 - stmt := internal.TxStmt(txn, s.selectMaxEventIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt) err = stmt.QueryRowContext(ctx).Scan(&nullableID) if nullableID.Valid { id = nullableID.Int64 @@ -275,7 +276,7 @@ func (s *outputRoomEventsStatements) InsertEvent( return } - stmt := internal.TxStmt(txn, s.insertEventStmt) + stmt := sqlutil.TxStmt(txn, s.insertEventStmt) err = stmt.QueryRowContext( ctx, event.RoomID(), @@ -303,9 +304,9 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( ) ([]types.StreamEvent, error) { var stmt *sql.Stmt if onlySyncEvents { - stmt = internal.TxStmt(txn, s.selectRecentEventsForSyncStmt) + stmt = sqlutil.TxStmt(txn, s.selectRecentEventsForSyncStmt) } else { - stmt = internal.TxStmt(txn, s.selectRecentEventsStmt) + stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt) } rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit) if err != nil { @@ -333,7 +334,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( ctx context.Context, txn *sql.Tx, roomID string, r types.Range, limit int, ) ([]types.StreamEvent, error) { - stmt := internal.TxStmt(txn, s.selectEarlyEventsStmt) + stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt) rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit) if err != nil { return nil, err @@ -357,7 +358,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( func (s *outputRoomEventsStatements) SelectEvents( ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StreamEvent, error) { - stmt := internal.TxStmt(txn, s.selectEventsStmt) + stmt := sqlutil.TxStmt(txn, s.selectEventsStmt) rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index ffbeece30..1ab3a1dc2 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -19,7 +19,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" diff --git a/syncapi/storage/postgres/send_to_device_table.go b/syncapi/storage/postgres/send_to_device_table.go index 335a05ef1..07af9ad6b 100644 --- a/syncapi/storage/postgres/send_to_device_table.go +++ b/syncapi/storage/postgres/send_to_device_table.go @@ -21,6 +21,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -107,14 +108,14 @@ func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { func (s *sendToDeviceStatements) InsertSendToDeviceMessage( ctx context.Context, txn *sql.Tx, userID, deviceID, content string, ) (err error) { - _, err = internal.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) + _, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) return } func (s *sendToDeviceStatements) CountSendToDeviceMessages( ctx context.Context, txn *sql.Tx, userID, deviceID string, ) (count int, err error) { - row := internal.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID) + row := sqlutil.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID) if err = row.Scan(&count); err != nil { return } @@ -124,7 +125,7 @@ func (s *sendToDeviceStatements) CountSendToDeviceMessages( func (s *sendToDeviceStatements) SelectSendToDeviceMessages( ctx context.Context, txn *sql.Tx, userID, deviceID string, ) (events []types.SendToDeviceEvent, err error) { - rows, err := internal.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) + rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) if err != nil { return } diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 8a8f964a5..573586cc7 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -18,12 +18,10 @@ package postgres import ( "database/sql" - "github.com/matrix-org/dendrite/internal/sqlutil" - // Import the postgres database driver. _ "github.com/lib/pq" "github.com/matrix-org/dendrite/eduserver/cache" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/shared" ) @@ -32,11 +30,11 @@ import ( type SyncServerDatasource struct { shared.Database db *sql.DB - internal.PartitionOffsetStatements + sqlutil.PartitionOffsetStatements } // NewDatabase creates a new sync server database -func NewDatabase(dbDataSourceName string, dbProperties internal.DbProperties) (*SyncServerDatasource, error) { +func NewDatabase(dbDataSourceName string, dbProperties sqlutil.DbProperties) (*SyncServerDatasource, error) { var d SyncServerDatasource var err error if d.db, err = sqlutil.Open("postgres", dbDataSourceName, dbProperties); err != nil { @@ -82,7 +80,7 @@ func NewDatabase(dbDataSourceName string, dbProperties internal.DbProperties) (* CurrentRoomState: currState, BackwardExtremities: backwardExtremities, SendToDevice: sendToDevice, - SendToDeviceWriter: internal.NewTransactionWriter(), + SendToDeviceWriter: sqlutil.NewTransactionWriter(), EDUCache: cache.New(), } return &d, nil diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 497c043a6..74ae3eabd 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -21,9 +21,10 @@ import ( "fmt" "time" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/eduserver/cache" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" @@ -42,7 +43,7 @@ type Database struct { CurrentRoomState tables.CurrentRoomState BackwardExtremities tables.BackwardsExtremities SendToDevice tables.SendToDevice - SendToDeviceWriter *internal.TransactionWriter + SendToDeviceWriter *sqlutil.TransactionWriter EDUCache *cache.EDUCache } @@ -126,7 +127,7 @@ func (d *Database) GetStateEvent( func (d *Database) GetStateEventsForRoom( ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter, ) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) { - err = internal.WithTransaction(d.DB, func(txn *sql.Tx) error { + err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, stateFilter) return err }) @@ -136,7 +137,7 @@ func (d *Database) GetStateEventsForRoom( func (d *Database) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) { var maxID int64 var err error - err = internal.WithTransaction(d.DB, func(txn *sql.Tx) error { + err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { maxID, err = d.OutputEvents.SelectMaxEventID(ctx, txn) if err != nil { return err @@ -168,7 +169,7 @@ func (d *Database) SyncStreamPosition(ctx context.Context) (types.StreamPosition func (d *Database) AddInviteEvent( ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, ) (sp types.StreamPosition, err error) { - err = internal.WithTransaction(d.DB, func(txn *sql.Tx) error { + err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent) return err }) @@ -207,14 +208,14 @@ func (d *Database) GetAccountDataInRange( func (d *Database) UpsertAccountData( ctx context.Context, userID, roomID, dataType string, ) (sp types.StreamPosition, err error) { - err = internal.WithTransaction(d.DB, func(txn *sql.Tx) error { + err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType) return err }) return } -func (d *Database) StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.HeaderedEvent { +func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []gomatrixserverlib.HeaderedEvent { out := make([]gomatrixserverlib.HeaderedEvent, len(in)) for i := 0; i < len(in); i++ { out[i] = in[i].HeaderedEvent @@ -275,7 +276,7 @@ func (d *Database) WriteEvent( addStateEventIDs, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool, ) (pduPosition types.StreamPosition, returnErr error) { - returnErr = internal.WithTransaction(d.DB, func(txn *sql.Tx) error { + returnErr = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { var err error pos, err := d.OutputEvents.InsertEvent( ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, @@ -375,7 +376,7 @@ func (d *Database) GetEventsInTopologicalRange( } func (d *Database) SyncPosition(ctx context.Context) (tok types.StreamingToken, err error) { - err = internal.WithTransaction(d.DB, func(txn *sql.Tx) error { + err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { pos, err := d.syncPositionTx(ctx, txn) if err != nil { return err @@ -442,7 +443,7 @@ func (d *Database) syncPositionTx( // IDs of all rooms the user joined are returned so EDU deltas can be added for them. func (d *Database) addPDUDeltaToResponse( ctx context.Context, - device authtypes.Device, + device userapi.Device, r types.Range, numRecentEventsPerRoom int, wantFullState bool, @@ -454,7 +455,7 @@ func (d *Database) addPDUDeltaToResponse( } var succeeded bool defer func() { - txerr := internal.EndTransaction(txn, &succeeded) + txerr := sqlutil.EndTransaction(txn, &succeeded) if err == nil && txerr != nil { err = txerr } @@ -549,7 +550,7 @@ func (d *Database) addEDUDeltaToResponse( func (d *Database) IncrementalSync( ctx context.Context, res *types.Response, - device authtypes.Device, + device userapi.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool, @@ -608,7 +609,7 @@ func (d *Database) getResponseWithPDUsForCompleteSync( } var succeeded bool defer func() { - txerr := internal.EndTransaction(txn, &succeeded) + txerr := sqlutil.EndTransaction(txn, &succeeded) if err == nil && txerr != nil { err = txerr } @@ -687,7 +688,7 @@ func (d *Database) getResponseWithPDUsForCompleteSync( func (d *Database) CompleteSync( ctx context.Context, res *types.Response, - device authtypes.Device, numRecentEventsPerRoom int, + device userapi.Device, numRecentEventsPerRoom int, ) (*types.Response, error) { toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync( ctx, res, device.UserID, numRecentEventsPerRoom, @@ -758,7 +759,7 @@ func (d *Database) getBackwardTopologyPos( // addRoomDeltaToResponse adds a room state delta to a sync response func (d *Database) addRoomDeltaToResponse( ctx context.Context, - device *authtypes.Device, + device *userapi.Device, txn *sql.Tx, r types.Range, delta stateDelta, @@ -904,7 +905,7 @@ func (d *Database) fetchMissingStateEvents( // the user has new membership events. // A list of joined room IDs is also returned in case the caller needs it. func (d *Database) getStateDeltas( - ctx context.Context, device *authtypes.Device, txn *sql.Tx, + ctx context.Context, device *userapi.Device, txn *sql.Tx, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter, ) ([]stateDelta, []string, error) { @@ -979,7 +980,7 @@ func (d *Database) getStateDeltas( // Fetches full state for all joined rooms and uses selectStateInRange to get // updates for other rooms. func (d *Database) getStateDeltasForFullStateSync( - ctx context.Context, device *authtypes.Device, txn *sql.Tx, + ctx context.Context, device *userapi.Device, txn *sql.Tx, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter, ) ([]stateDelta, []string, error) { diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index a3c797342..ae5caa4e5 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -20,7 +20,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index b0cf1297c..85f212ad8 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -22,6 +22,7 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -152,7 +153,7 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( userID string, membership string, // nolint: unparam ) ([]string, error) { - stmt := internal.TxStmt(txn, s.selectRoomIDsWithMembershipStmt) + stmt := sqlutil.TxStmt(txn, s.selectRoomIDsWithMembershipStmt) rows, err := stmt.QueryContext(ctx, userID, membership) if err != nil { return nil, err @@ -175,7 +176,7 @@ func (s *currentRoomStateStatements) SelectCurrentState( ctx context.Context, txn *sql.Tx, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, ) ([]gomatrixserverlib.HeaderedEvent, error) { - stmt := internal.TxStmt(txn, s.selectCurrentStateStmt) + stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt) rows, err := stmt.QueryContext(ctx, roomID, nil, // FIXME: pq.StringArray(stateFilterPart.Senders), nil, // FIXME: pq.StringArray(stateFilterPart.NotSenders), @@ -195,7 +196,7 @@ func (s *currentRoomStateStatements) SelectCurrentState( func (s *currentRoomStateStatements) DeleteRoomStateByEventID( ctx context.Context, txn *sql.Tx, eventID string, ) error { - stmt := internal.TxStmt(txn, s.deleteRoomStateByEventIDStmt) + stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) _, err := stmt.ExecContext(ctx, eventID) return err } @@ -218,7 +219,7 @@ func (s *currentRoomStateStatements) UpsertRoomState( } // upsert state event - stmt := internal.TxStmt(txn, s.upsertRoomStateStmt) + stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) _, err = stmt.ExecContext( ctx, event.RoomID(), @@ -241,7 +242,7 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs( for k, v := range eventIDs { iEventIDs[k] = v } - query := strings.Replace(selectEventsWithEventIDsSQL, "($1)", internal.QueryVariadic(len(iEventIDs)), 1) + query := strings.Replace(selectEventsWithEventIDsSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) rows, err := txn.QueryContext(ctx, query, iEventIDs...) if err != nil { return nil, err diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index 64b523394..bb58e3456 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -21,6 +21,7 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -123,7 +124,7 @@ func (s *inviteEventsStatements) DeleteInviteEvent( func (s *inviteEventsStatements) SelectInviteEventsInRange( ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range, ) (map[string]gomatrixserverlib.HeaderedEvent, error) { - stmt := internal.TxStmt(txn, s.selectInviteEventsInRangeStmt) + stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt) rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High()) if err != nil { return nil, err @@ -153,7 +154,7 @@ func (s *inviteEventsStatements) SelectMaxInviteID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 - stmt := internal.TxStmt(txn, s.selectMaxInviteIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectMaxInviteIDStmt) err = stmt.QueryRowContext(ctx).Scan(&nullableID) if nullableID.Valid { id = nullableID.Int64 diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 4a84ecc6c..367ab3c9a 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -21,11 +21,12 @@ import ( "encoding/json" "sort" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -149,7 +150,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange( ctx context.Context, txn *sql.Tx, r types.Range, stateFilterPart *gomatrixserverlib.StateFilter, ) (map[string]map[string]bool, map[string]types.StreamEvent, error) { - stmt := internal.TxStmt(txn, s.selectStateInRangeStmt) + stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt) rows, err := stmt.QueryContext( ctx, r.Low(), r.High(), @@ -236,7 +237,7 @@ func (s *outputRoomEventsStatements) SelectMaxEventID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 - stmt := internal.TxStmt(txn, s.selectMaxEventIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt) err = stmt.QueryRowContext(ctx).Scan(&nullableID) if nullableID.Valid { id = nullableID.Int64 @@ -286,7 +287,7 @@ func (s *outputRoomEventsStatements) InsertEvent( return } - insertStmt := internal.TxStmt(txn, s.insertEventStmt) + insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) _, err = insertStmt.ExecContext( ctx, streamPos, @@ -313,9 +314,9 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( ) ([]types.StreamEvent, error) { var stmt *sql.Stmt if onlySyncEvents { - stmt = internal.TxStmt(txn, s.selectRecentEventsForSyncStmt) + stmt = sqlutil.TxStmt(txn, s.selectRecentEventsForSyncStmt) } else { - stmt = internal.TxStmt(txn, s.selectRecentEventsStmt) + stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt) } rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit) @@ -342,7 +343,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( ctx context.Context, txn *sql.Tx, roomID string, r types.Range, limit int, ) ([]types.StreamEvent, error) { - stmt := internal.TxStmt(txn, s.selectEarlyEventsStmt) + stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt) rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit) if err != nil { return nil, err @@ -367,7 +368,7 @@ func (s *outputRoomEventsStatements) SelectEvents( ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StreamEvent, error) { var returnEvents []types.StreamEvent - stmt := internal.TxStmt(txn, s.selectEventsStmt) + stmt := sqlutil.TxStmt(txn, s.selectEventsStmt) for _, eventID := range eventIDs { rows, err := stmt.QueryContext(ctx, eventID) if err != nil { diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index 0d727cb18..811dfa4f3 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -18,7 +18,7 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -102,7 +102,7 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition, ) (err error) { - stmt := internal.TxStmt(txn, s.insertEventInTopologyStmt) + stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt) _, err = stmt.ExecContext( ctx, event.EventID(), event.Depth(), event.RoomID(), pos, ) @@ -118,9 +118,9 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( // is requested or not. var stmt *sql.Stmt if chronologicalOrder { - stmt = internal.TxStmt(txn, s.selectEventIDsInRangeASCStmt) + stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeASCStmt) } else { - stmt = internal.TxStmt(txn, s.selectEventIDsInRangeDESCStmt) + stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeDESCStmt) } // Query the event IDs. @@ -149,7 +149,7 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( func (s *outputRoomEventsTopologyStatements) SelectPositionInTopology( ctx context.Context, txn *sql.Tx, eventID string, ) (pos types.StreamPosition, spos types.StreamPosition, err error) { - stmt := internal.TxStmt(txn, s.selectPositionInTopologyStmt) + stmt := sqlutil.TxStmt(txn, s.selectPositionInTopologyStmt) err = stmt.QueryRowContext(ctx, eventID).Scan(&pos, &spos) return } @@ -157,7 +157,7 @@ func (s *outputRoomEventsTopologyStatements) SelectPositionInTopology( func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( ctx context.Context, txn *sql.Tx, roomID string, ) (pos types.StreamPosition, spos types.StreamPosition, err error) { - stmt := internal.TxStmt(txn, s.selectMaxPositionInTopologyStmt) + stmt := sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt) err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) return } diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go index 0d03f23ef..42bd3c19a 100644 --- a/syncapi/storage/sqlite3/send_to_device_table.go +++ b/syncapi/storage/sqlite3/send_to_device_table.go @@ -21,6 +21,7 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -97,14 +98,14 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { func (s *sendToDeviceStatements) InsertSendToDeviceMessage( ctx context.Context, txn *sql.Tx, userID, deviceID, content string, ) (err error) { - _, err = internal.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) + _, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) return } func (s *sendToDeviceStatements) CountSendToDeviceMessages( ctx context.Context, txn *sql.Tx, userID, deviceID string, ) (count int, err error) { - row := internal.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID) + row := sqlutil.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID) if err = row.Scan(&count); err != nil { return } @@ -114,7 +115,7 @@ func (s *sendToDeviceStatements) CountSendToDeviceMessages( func (s *sendToDeviceStatements) SelectSendToDeviceMessages( ctx context.Context, txn *sql.Tx, userID, deviceID string, ) (events []types.SendToDeviceEvent, err error) { - rows, err := internal.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) + rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) if err != nil { return } @@ -149,7 +150,7 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages( func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID, ) (err error) { - query := strings.Replace(updateSentSendToDeviceMessagesSQL, "($2)", internal.QueryVariadic(1+len(nids)), 1) + query := strings.Replace(updateSentSendToDeviceMessagesSQL, "($2)", sqlutil.QueryVariadic(1+len(nids)), 1) params := make([]interface{}, 1+len(nids)) params[0] = token for k, v := range nids { @@ -162,7 +163,7 @@ func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID, ) (err error) { - query := strings.Replace(deleteSendToDeviceMessagesSQL, "($1)", internal.QueryVariadic(len(nids)), 1) + query := strings.Replace(deleteSendToDeviceMessagesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) params := make([]interface{}, 1+len(nids)) for k, v := range nids { params[k] = v diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go index 3bfe22dd0..57abd9c44 100644 --- a/syncapi/storage/sqlite3/stream_id_table.go +++ b/syncapi/storage/sqlite3/stream_id_table.go @@ -4,7 +4,7 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -46,8 +46,8 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) { } func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { - increaseStmt := internal.TxStmt(txn, s.increaseStreamIDStmt) - selectStmt := internal.TxStmt(txn, s.selectStreamIDStmt) + increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) + selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil { return } diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 38ce5bcf0..51cdbe325 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -18,13 +18,11 @@ package sqlite3 import ( "database/sql" - "github.com/matrix-org/dendrite/internal/sqlutil" - // Import the sqlite3 package _ "github.com/mattn/go-sqlite3" "github.com/matrix-org/dendrite/eduserver/cache" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/shared" ) @@ -33,7 +31,7 @@ import ( type SyncServerDatasource struct { shared.Database db *sql.DB - internal.PartitionOffsetStatements + sqlutil.PartitionOffsetStatements streamID streamIDStatements } @@ -45,7 +43,7 @@ func NewDatabase(dataSourceName string) (*SyncServerDatasource, error) { if err != nil { return nil, err } - if d.db, err = sqlutil.Open(internal.SQLiteDriverName(), cs, nil); err != nil { + if d.db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil); err != nil { return nil, err } if err = d.prepare(); err != nil { @@ -98,7 +96,7 @@ func (d *SyncServerDatasource) prepare() (err error) { CurrentRoomState: roomState, Topology: topology, SendToDevice: sendToDevice, - SendToDeviceWriter: internal.NewTransactionWriter(), + SendToDeviceWriter: sqlutil.NewTransactionWriter(), EDUCache: cache.New(), } return nil diff --git a/syncapi/storage/storage.go b/syncapi/storage/storage.go index a4275da0c..ea69da3bc 100644 --- a/syncapi/storage/storage.go +++ b/syncapi/storage/storage.go @@ -19,13 +19,13 @@ package storage import ( "net/url" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/postgres" "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" ) // NewSyncServerDatasource opens a database connection. -func NewSyncServerDatasource(dataSourceName string, dbProperties internal.DbProperties) (Database, error) { +func NewSyncServerDatasource(dataSourceName string, dbProperties sqlutil.DbProperties) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { return postgres.NewDatabase(dataSourceName, dbProperties) diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 4661ede4d..85084facb 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -8,10 +8,10 @@ import ( "testing" "time" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" ) @@ -22,7 +22,7 @@ var ( testRoomID = fmt.Sprintf("!hallownest:%s", testOrigin) testUserIDA = fmt.Sprintf("@hornet:%s", testOrigin) testUserIDB = fmt.Sprintf("@paleking:%s", testOrigin) - testUserDeviceA = authtypes.Device{ + testUserDeviceA = userapi.Device{ UserID: testUserIDA, ID: "device_id_A", DisplayName: "Device A", diff --git a/syncapi/storage/storage_wasm.go b/syncapi/storage/storage_wasm.go index a2b269a14..0886b8c21 100644 --- a/syncapi/storage/storage_wasm.go +++ b/syncapi/storage/storage_wasm.go @@ -18,14 +18,14 @@ import ( "fmt" "net/url" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" ) // NewPublicRoomsServerDatabase opens a database connection. func NewSyncServerDatasource( dataSourceName string, - dbProperties internal.DbProperties, // nolint:unparam + dbProperties sqlutil.DbProperties, // nolint:unparam ) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { diff --git a/syncapi/sync/notifier_test.go b/syncapi/sync/notifier_test.go index 132315573..ecc4fcbfc 100644 --- a/syncapi/sync/notifier_test.go +++ b/syncapi/sync/notifier_test.go @@ -22,9 +22,8 @@ import ( "testing" "time" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -357,7 +356,7 @@ func lockedFetchUserStream(n *Notifier, userID, deviceID string) *UserDeviceStre func newTestSyncRequest(userID, deviceID string, since types.StreamingToken) syncRequest { return syncRequest{ - device: authtypes.Device{ + device: userapi.Device{ UserID: userID, ID: deviceID, }, diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go index c7796b561..beeaa40f7 100644 --- a/syncapi/sync/request.go +++ b/syncapi/sync/request.go @@ -21,9 +21,8 @@ import ( "strconv" "time" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" ) @@ -42,7 +41,7 @@ type filter struct { // syncRequest represents a /sync request, with sensible defaults/sanity checks applied. type syncRequest struct { ctx context.Context - device authtypes.Device + device userapi.Device limit int timeout time.Duration since *types.StreamingToken // nil means that no since token was supplied @@ -50,7 +49,7 @@ type syncRequest struct { log *log.Entry } -func newSyncRequest(req *http.Request, device authtypes.Device) (*syncRequest, error) { +func newSyncRequest(req *http.Request, device userapi.Device) (*syncRequest, error) { timeout := getTimeout(req.URL.Query().Get("timeout")) fullState := req.URL.Query().Get("full_state") wantFullState := fullState != "" && fullState != "false" diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 8b93cad45..26b925eac 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -21,11 +21,10 @@ import ( "net/http" "time" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" @@ -33,20 +32,20 @@ import ( // RequestPool manages HTTP long-poll connections for /sync type RequestPool struct { - db storage.Database - accountDB accounts.Database - notifier *Notifier + db storage.Database + userAPI userapi.UserInternalAPI + notifier *Notifier } // NewRequestPool makes a new RequestPool -func NewRequestPool(db storage.Database, n *Notifier, adb accounts.Database) *RequestPool { - return &RequestPool{db, adb, n} +func NewRequestPool(db storage.Database, n *Notifier, userAPI userapi.UserInternalAPI) *RequestPool { + return &RequestPool{db, userAPI, n} } // OnIncomingSyncRequest is called when a client makes a /sync request. This function MUST be // called in a dedicated goroutine for this request. This function will block the goroutine // until a response is ready, or it times out. -func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtypes.Device) util.JSONResponse { +func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.Device) util.JSONResponse { var syncData *types.Response // Extract values from request @@ -193,6 +192,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea return } +// nolint:gocyclo func (rp *RequestPool) appendAccountData( data *types.Response, userID string, req syncRequest, currentPos types.StreamPosition, accountDataFilter *gomatrixserverlib.EventFilter, @@ -202,25 +202,21 @@ func (rp *RequestPool) appendAccountData( // data keys were set between two message. This isn't a huge issue since the // duplicate data doesn't represent a huge quantity of data, but an optimisation // here would be making sure each data is sent only once to the client. - localpart, _, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - return nil, err - } - if req.since == nil { // If this is the initial sync, we don't need to check if a data has // already been sent. Instead, we send the whole batch. - var global []gomatrixserverlib.ClientEvent - var rooms map[string][]gomatrixserverlib.ClientEvent - global, rooms, err = rp.accountDB.GetAccountData(req.ctx, localpart) + var res userapi.QueryAccountDataResponse + err := rp.userAPI.QueryAccountData(req.ctx, &userapi.QueryAccountDataRequest{ + UserID: userID, + }, &res) if err != nil { return nil, err } - data.AccountData.Events = global + data.AccountData.Events = res.GlobalAccountData for r, j := range data.Rooms.Join { - if len(rooms[r]) > 0 { - j.AccountData.Events = rooms[r] + if len(res.RoomAccountData[r]) > 0 { + j.AccountData.Events = res.RoomAccountData[r] data.Rooms.Join[r] = j } } @@ -256,13 +252,20 @@ func (rp *RequestPool) appendAccountData( events := []gomatrixserverlib.ClientEvent{} // Request the missing data from the database for _, dataType := range dataTypes { - event, err := rp.accountDB.GetAccountDataByType( - req.ctx, localpart, roomID, dataType, - ) + var res userapi.QueryAccountDataResponse + err = rp.userAPI.QueryAccountData(req.ctx, &userapi.QueryAccountDataRequest{ + UserID: userID, + RoomID: roomID, + DataType: dataType, + }, &res) if err != nil { return nil, err } - events = append(events, *event) + if len(res.RoomAccountData[roomID]) > 0 { + events = append(events, res.RoomAccountData[roomID]...) + } else if len(res.GlobalAccountData) > 0 { + events = append(events, res.GlobalAccountData...) + } } // Append the data to the response diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 40e652af4..caf91e27e 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -21,12 +21,11 @@ import ( "github.com/gorilla/mux" "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/syncapi/consumers" "github.com/matrix-org/dendrite/syncapi/routing" "github.com/matrix-org/dendrite/syncapi/storage" @@ -38,8 +37,7 @@ import ( func AddPublicRoutes( router *mux.Router, consumer sarama.Consumer, - deviceDB devices.Database, - accountsDB accounts.Database, + userAPI userapi.UserInternalAPI, rsAPI api.RoomserverInternalAPI, federation *gomatrixserverlib.FederationClient, cfg *config.Dendrite, @@ -60,7 +58,7 @@ func AddPublicRoutes( logrus.WithError(err).Panicf("failed to start notifier") } - requestPool := sync.NewRequestPool(syncDB, notifier, accountsDB) + requestPool := sync.NewRequestPool(syncDB, notifier, userAPI) roomConsumer := consumers.NewOutputRoomEventConsumer( cfg, consumer, notifier, syncDB, rsAPI, @@ -90,5 +88,5 @@ func AddPublicRoutes( logrus.WithError(err).Panicf("failed to start send-to-device consumer") } - routing.Setup(router, requestPool, syncDB, deviceDB, federation, rsAPI, cfg) + routing.Setup(router, requestPool, syncDB, userAPI, federation, rsAPI, cfg) } diff --git a/sytest-whitelist b/sytest-whitelist index c1c0c13ac..971f8cf0d 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -53,6 +53,8 @@ PUT /rooms/:room_id/send/:event_type/:txn_id deduplicates the same txn id GET /rooms/:room_id/state/m.room.power_levels can fetch levels PUT /rooms/:room_id/state/m.room.power_levels can set levels PUT power_levels should not explode if the old power levels were empty +GET /rooms/:room_id/state/m.room.member/:user_id?format=event fetches my membership event +GET /rooms/:room_id/joined_members fetches my membership Both GET and PUT work POST /rooms/:room_id/read_markers can create read marker User signups are forbidden from starting with '_' @@ -113,6 +115,8 @@ User can invite local user to room with version 1 Should reject keys claiming to belong to a different user Can add account data Can add account data to room +Can get account data without syncing +Can get room account data without syncing #Latest account data appears in v2 /sync New account data appears in incremental v2 /sync Checking local federation server @@ -313,3 +317,33 @@ Invalid JSON integers Invalid JSON special values Invalid JSON floats Outbound federation will ignore a missing event with bad JSON for room version 6 +Can download without a file name over federation +POST /media/r0/upload can create an upload +GET /media/r0/download can fetch the value again +Remote users can join room by alias +Alias creators can delete alias with no ops +Alias creators can delete canonical alias with no ops +Room members can override their displayname on a room-specific basis +displayname updates affect room member events +avatar_url updates affect room member events +Real non-joined users can get individual state for world_readable rooms after leaving +Can upload with Unicode file name +POSTed media can be thumbnailed +Remote media can be thumbnailed +Can download with Unicode file name locally +Can download file 'ascii' +Can download file 'name with spaces' +Can download file 'name;with;semicolons' +Can download specifying a different ASCII file name +Can download with Unicode file name over federation +Can download specifying a different Unicode file name +Inbound /v1/send_join rejects joins from other servers +Outbound federation can query v1 /send_join +Inbound /v1/send_join rejects incorrectly-signed joins +POST /rooms/:room_id/state/m.room.name sets name +GET /rooms/:room_id/state/m.room.name gets name +POST /rooms/:room_id/state/m.room.topic sets topic +GET /rooms/:room_id/state/m.room.topic gets topic +GET /rooms/:room_id/state fetches entire room state +Setting room topic reports m.room.topic to myself +setting 'm.room.name' respects room powerlevel diff --git a/userapi/api/api.go b/userapi/api/api.go new file mode 100644 index 000000000..c953a5bac --- /dev/null +++ b/userapi/api/api.go @@ -0,0 +1,182 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + + "github.com/matrix-org/gomatrixserverlib" +) + +// UserInternalAPI is the internal API for information about users and devices. +type UserInternalAPI interface { + PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error + PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error + QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error + QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error + QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error + QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error +} + +// QueryAccessTokenRequest is the request for QueryAccessToken +type QueryAccessTokenRequest struct { + AccessToken string + // optional user ID, valid only if the token is an appservice. + // https://matrix.org/docs/spec/application_service/r0.1.2#using-sync-and-events + AppServiceUserID string +} + +// QueryAccessTokenResponse is the response for QueryAccessToken +type QueryAccessTokenResponse struct { + Device *Device + Err error // e.g ErrorForbidden +} + +// QueryAccountDataRequest is the request for QueryAccountData +type QueryAccountDataRequest struct { + UserID string // required: the user to get account data for. + // TODO: This is a terribly confusing API shape :/ + DataType string // optional: if specified returns only a single event matching this data type. + // optional: Only used if DataType is set. If blank returns global account data matching the data type. + // If set, returns only room account data matching this data type. + RoomID string +} + +// QueryAccountDataResponse is the response for QueryAccountData +type QueryAccountDataResponse struct { + GlobalAccountData []gomatrixserverlib.ClientEvent + RoomAccountData map[string][]gomatrixserverlib.ClientEvent +} + +// QueryDevicesRequest is the request for QueryDevices +type QueryDevicesRequest struct { + UserID string +} + +// QueryDevicesResponse is the response for QueryDevices +type QueryDevicesResponse struct { + UserExists bool + Devices []Device +} + +// QueryProfileRequest is the request for QueryProfile +type QueryProfileRequest struct { + // The user ID to query + UserID string +} + +// QueryProfileResponse is the response for QueryProfile +type QueryProfileResponse struct { + // True if the user exists. Querying for a profile does not create them. + UserExists bool + // The current display name if set. + DisplayName string + // The current avatar URL if set. + AvatarURL string +} + +// PerformAccountCreationRequest is the request for PerformAccountCreation +type PerformAccountCreationRequest struct { + AccountType AccountType // Required: whether this is a guest or user account + Localpart string // Required: The localpart for this account. Ignored if account type is guest. + + AppServiceID string // optional: the application service ID (not user ID) creating this account, if any. + Password string // optional: if missing then this account will be a passwordless account + OnConflict Conflict +} + +// PerformAccountCreationResponse is the response for PerformAccountCreation +type PerformAccountCreationResponse struct { + AccountCreated bool + Account *Account +} + +// PerformDeviceCreationRequest is the request for PerformDeviceCreation +type PerformDeviceCreationRequest struct { + Localpart string + AccessToken string // optional: if blank one will be made on your behalf + // optional: if nil an ID is generated for you. If set, replaces any existing device session, + // which will generate a new access token and invalidate the old one. + DeviceID *string + // optional: if nil no display name will be associated with this device. + DeviceDisplayName *string +} + +// PerformDeviceCreationResponse is the response for PerformDeviceCreation +type PerformDeviceCreationResponse struct { + DeviceCreated bool + Device *Device +} + +// Device represents a client's device (mobile, web, etc) +type Device struct { + ID string + UserID string + // The access_token granted to this device. + // This uniquely identifies the device from all other devices and clients. + AccessToken string + // The unique ID of the session identified by the access token. + // Can be used as a secure substitution in places where data needs to be + // associated with access tokens. + SessionID int64 + // TODO: display name, last used timestamp, keys, etc + DisplayName string +} + +// Account represents a Matrix account on this home server. +type Account struct { + UserID string + Localpart string + ServerName gomatrixserverlib.ServerName + AppServiceID string + // TODO: Other flags like IsAdmin, IsGuest + // TODO: Associations (e.g. with application services) +} + +// ErrorForbidden is an error indicating that the supplied access token is forbidden +type ErrorForbidden struct { + Message string +} + +func (e *ErrorForbidden) Error() string { + return "Forbidden: " + e.Message +} + +// ErrorConflict is an error indicating that there was a conflict which resulted in the request being aborted. +type ErrorConflict struct { + Message string +} + +func (e *ErrorConflict) Error() string { + return "Conflict: " + e.Message +} + +// Conflict is an enum representing what to do when encountering conflicting when creating profiles/devices +type Conflict int + +// AccountType is an enum representing the kind of account +type AccountType int + +const ( + // ConflictUpdate will update matching records returning no error + ConflictUpdate Conflict = 1 + // ConflictAbort will reject the request with ErrorConflict + ConflictAbort Conflict = 2 + + // AccountTypeUser indicates this is a user account + AccountTypeUser AccountType = 1 + // AccountTypeGuest indicates this is a guest account + AccountTypeGuest AccountType = 2 +) diff --git a/userapi/internal/api.go b/userapi/internal/api.go new file mode 100644 index 000000000..ae021f575 --- /dev/null +++ b/userapi/internal/api.go @@ -0,0 +1,218 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/matrix-org/dendrite/appservice/types" + "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" + "github.com/matrix-org/gomatrixserverlib" +) + +type UserInternalAPI struct { + AccountDB accounts.Database + DeviceDB devices.Database + ServerName gomatrixserverlib.ServerName + // AppServices is the list of all registered AS + AppServices []config.ApplicationService +} + +func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { + if req.AccountType == api.AccountTypeGuest { + acc, err := a.AccountDB.CreateGuestAccount(ctx) + if err != nil { + return err + } + res.AccountCreated = true + res.Account = acc + return nil + } + acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID) + if err != nil { + if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists + switch req.OnConflict { + case api.ConflictUpdate: + break + case api.ConflictAbort: + return &api.ErrorConflict{ + Message: err.Error(), + } + } + } + // account already exists + res.AccountCreated = false + res.Account = &api.Account{ + AppServiceID: req.AppServiceID, + Localpart: req.Localpart, + ServerName: a.ServerName, + UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), + } + return nil + } + res.AccountCreated = true + res.Account = acc + return nil +} +func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error { + dev, err := a.DeviceDB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName) + if err != nil { + return err + } + res.DeviceCreated = true + res.Device = dev + return nil +} + +func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfileRequest, res *api.QueryProfileResponse) error { + local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return err + } + if domain != a.ServerName { + return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) + } + prof, err := a.AccountDB.GetProfileByLocalpart(ctx, local) + if err != nil { + if err == sql.ErrNoRows { + return nil + } + return err + } + res.UserExists = true + res.AvatarURL = prof.AvatarURL + res.DisplayName = prof.DisplayName + return nil +} + +func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error { + local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return err + } + if domain != a.ServerName { + return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName) + } + devs, err := a.DeviceDB.GetDevicesByLocalpart(ctx, local) + if err != nil { + return err + } + res.Devices = devs + return nil +} + +func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAccountDataRequest, res *api.QueryAccountDataResponse) error { + local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return err + } + if domain != a.ServerName { + return fmt.Errorf("cannot query account data of remote users: got %s want %s", domain, a.ServerName) + } + if req.DataType != "" { + var event *gomatrixserverlib.ClientEvent + event, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) + if err != nil { + return err + } + if event != nil { + if req.RoomID != "" { + res.RoomAccountData = make(map[string][]gomatrixserverlib.ClientEvent) + res.RoomAccountData[req.RoomID] = []gomatrixserverlib.ClientEvent{*event} + } else { + res.GlobalAccountData = append(res.GlobalAccountData, *event) + } + } + return nil + } + global, rooms, err := a.AccountDB.GetAccountData(ctx, local) + if err != nil { + return err + } + res.RoomAccountData = rooms + res.GlobalAccountData = global + return nil +} + +func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAccessTokenRequest, res *api.QueryAccessTokenResponse) error { + if req.AppServiceUserID != "" { + appServiceDevice, err := a.queryAppServiceToken(ctx, req.AccessToken, req.AppServiceUserID) + res.Device = appServiceDevice + res.Err = err + return nil + } + device, err := a.DeviceDB.GetDeviceByAccessToken(ctx, req.AccessToken) + if err != nil { + if err == sql.ErrNoRows { + return nil + } + return err + } + res.Device = device + return nil +} + +// Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem +// creating a 'device'. +func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) { + // Search for app service with given access_token + var appService *config.ApplicationService + for _, as := range a.AppServices { + if as.ASToken == token { + appService = &as + break + } + } + if appService == nil { + return nil, nil + } + + // Create a dummy device for AS user + dev := api.Device{ + // Use AS dummy device ID + ID: types.AppServiceDeviceID, + // AS dummy device has AS's token. + AccessToken: token, + } + + localpart, err := userutil.ParseUsernameParam(appServiceUserID, &a.ServerName) + if err != nil { + return nil, err + } + + if localpart != "" { // AS is masquerading as another user + // Verify that the user is registered + account, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart) + // Verify that account exists & appServiceID matches + if err == nil && account.AppServiceID == appService.ID { + // Set the userID of dummy device + dev.UserID = appServiceUserID + return &dev, nil + } + return nil, &api.ErrorForbidden{Message: "appservice has not registered this user"} + } + + // AS is not masquerading as any user, so use AS's sender_localpart + dev.UserID = appService.SenderLocalpart + return &dev, nil +} diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go new file mode 100644 index 000000000..0e9628c58 --- /dev/null +++ b/userapi/inthttp/client.go @@ -0,0 +1,120 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inthttp + +import ( + "context" + "errors" + "net/http" + + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/opentracing/opentracing-go" +) + +// HTTP paths for the internal HTTP APIs +const ( + PerformDeviceCreationPath = "/userapi/performDeviceCreation" + PerformAccountCreationPath = "/userapi/performAccountCreation" + + QueryProfilePath = "/userapi/queryProfile" + QueryAccessTokenPath = "/userapi/queryAccessToken" + QueryDevicesPath = "/userapi/queryDevices" + QueryAccountDataPath = "/userapi/queryAccountData" +) + +// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. +// If httpClient is nil an error is returned +func NewUserAPIClient( + apiURL string, + httpClient *http.Client, +) (api.UserInternalAPI, error) { + if httpClient == nil { + return nil, errors.New("NewUserAPIClient: httpClient is ") + } + return &httpUserInternalAPI{ + apiURL: apiURL, + httpClient: httpClient, + }, nil +} + +type httpUserInternalAPI struct { + apiURL string + httpClient *http.Client +} + +func (h *httpUserInternalAPI) PerformAccountCreation( + ctx context.Context, + request *api.PerformAccountCreationRequest, + response *api.PerformAccountCreationResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAccountCreation") + defer span.Finish() + + apiURL := h.apiURL + PerformAccountCreationPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) PerformDeviceCreation( + ctx context.Context, + request *api.PerformDeviceCreationRequest, + response *api.PerformDeviceCreationResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceCreation") + defer span.Finish() + + apiURL := h.apiURL + PerformDeviceCreationPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) QueryProfile( + ctx context.Context, + request *api.QueryProfileRequest, + response *api.QueryProfileResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryProfile") + defer span.Finish() + + apiURL := h.apiURL + QueryProfilePath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) QueryAccessToken( + ctx context.Context, + request *api.QueryAccessTokenRequest, + response *api.QueryAccessTokenResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAccessToken") + defer span.Finish() + + apiURL := h.apiURL + QueryAccessTokenPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDevices") + defer span.Finish() + + apiURL := h.apiURL + QueryDevicesPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpUserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAccountDataRequest, res *api.QueryAccountDataResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAccountData") + defer span.Finish() + + apiURL := h.apiURL + QueryAccountDataPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go new file mode 100644 index 000000000..8f3be7738 --- /dev/null +++ b/userapi/inthttp/server.go @@ -0,0 +1,106 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inthttp + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { + internalAPIMux.Handle(PerformAccountCreationPath, + httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse { + request := api.PerformAccountCreationRequest{} + response := api.PerformAccountCreationResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformAccountCreation(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(PerformDeviceCreationPath, + httputil.MakeInternalAPI("performDeviceCreation", func(req *http.Request) util.JSONResponse { + request := api.PerformDeviceCreationRequest{} + response := api.PerformDeviceCreationResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformDeviceCreation(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(QueryProfilePath, + httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse { + request := api.QueryProfileRequest{} + response := api.QueryProfileResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryProfile(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(QueryAccessTokenPath, + httputil.MakeInternalAPI("queryAccessToken", func(req *http.Request) util.JSONResponse { + request := api.QueryAccessTokenRequest{} + response := api.QueryAccessTokenResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryAccessToken(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(QueryDevicesPath, + httputil.MakeInternalAPI("queryDevices", func(req *http.Request) util.JSONResponse { + request := api.QueryDevicesRequest{} + response := api.QueryDevicesResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryDevices(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(QueryAccountDataPath, + httputil.MakeInternalAPI("queryAccountData", func(req *http.Request) util.JSONResponse { + request := api.QueryAccountDataRequest{} + response := api.QueryAccountDataResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryAccountData(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) +} diff --git a/clientapi/auth/storage/accounts/interface.go b/userapi/storage/accounts/interface.go similarity index 88% rename from clientapi/auth/storage/accounts/interface.go rename to userapi/storage/accounts/interface.go index 4d1941a23..13e3e2895 100644 --- a/clientapi/auth/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -20,26 +20,31 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" ) type Database interface { internal.PartitionStorer - GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*authtypes.Account, error) + GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error SetDisplayName(ctx context.Context, localpart string, displayName string) error // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, ErrUserExists. - CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*authtypes.Account, error) - CreateGuestAccount(ctx context.Context) (*authtypes.Account, error) + CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*api.Account, error) + CreateGuestAccount(ctx context.Context) (*api.Account, error) UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error) GetRoomIDsByLocalPart(ctx context.Context, localpart string) ([]string, error) GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error) SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error) + // GetAccountDataByType returns account data matching a given + // localpart, room ID and type. + // If no account data could be found, returns nil + // Returns an error if there was an issue with the retrieval GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data *gomatrixserverlib.ClientEvent, err error) GetNewNumericLocalpart(ctx context.Context) (int64, error) SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error) @@ -49,7 +54,7 @@ type Database interface { GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error) PutFilter(ctx context.Context, localpart string, filter *gomatrixserverlib.Filter) (string, error) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) - GetAccountByLocalpart(ctx context.Context, localpart string) (*authtypes.Account, error) + GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/clientapi/auth/storage/accounts/postgres/account_data_table.go b/userapi/storage/accounts/postgres/account_data_table.go similarity index 99% rename from clientapi/auth/storage/accounts/postgres/account_data_table.go rename to userapi/storage/accounts/postgres/account_data_table.go index 2e044b367..2f16c5c02 100644 --- a/clientapi/auth/storage/accounts/postgres/account_data_table.go +++ b/userapi/storage/accounts/postgres/account_data_table.go @@ -19,7 +19,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/gomatrixserverlib" ) diff --git a/clientapi/auth/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go similarity index 96% rename from clientapi/auth/storage/accounts/postgres/accounts_table.go rename to userapi/storage/accounts/postgres/accounts_table.go index 85c1938a1..931ffb73d 100644 --- a/clientapi/auth/storage/accounts/postgres/accounts_table.go +++ b/userapi/storage/accounts/postgres/accounts_table.go @@ -19,8 +19,8 @@ import ( "database/sql" "time" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" @@ -92,7 +92,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // on success. func (s *accountsStatements) insertAccount( ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, -) (*authtypes.Account, error) { +) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := txn.Stmt(s.insertAccountStmt) @@ -106,7 +106,7 @@ func (s *accountsStatements) insertAccount( return nil, err } - return &authtypes.Account{ + return &api.Account{ Localpart: localpart, UserID: userutil.MakeUserID(localpart, s.serverName), ServerName: s.serverName, @@ -123,9 +123,9 @@ func (s *accountsStatements) selectPasswordHash( func (s *accountsStatements) selectAccountByLocalpart( ctx context.Context, localpart string, -) (*authtypes.Account, error) { +) (*api.Account, error) { var appserviceIDPtr sql.NullString - var acc authtypes.Account + var acc api.Account stmt := s.selectAccountByLocalpartStmt err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) diff --git a/clientapi/auth/storage/accounts/postgres/filter_table.go b/userapi/storage/accounts/postgres/filter_table.go similarity index 100% rename from clientapi/auth/storage/accounts/postgres/filter_table.go rename to userapi/storage/accounts/postgres/filter_table.go diff --git a/clientapi/auth/storage/accounts/postgres/membership_table.go b/userapi/storage/accounts/postgres/membership_table.go similarity index 99% rename from clientapi/auth/storage/accounts/postgres/membership_table.go rename to userapi/storage/accounts/postgres/membership_table.go index d006e916a..623530acc 100644 --- a/clientapi/auth/storage/accounts/postgres/membership_table.go +++ b/userapi/storage/accounts/postgres/membership_table.go @@ -18,10 +18,9 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/internal" - "github.com/lib/pq" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal" ) const membershipSchema = ` diff --git a/clientapi/auth/storage/accounts/postgres/profile_table.go b/userapi/storage/accounts/postgres/profile_table.go similarity index 100% rename from clientapi/auth/storage/accounts/postgres/profile_table.go rename to userapi/storage/accounts/postgres/profile_table.go diff --git a/clientapi/auth/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go similarity index 94% rename from clientapi/auth/storage/accounts/postgres/storage.go rename to userapi/storage/accounts/postgres/storage.go index 4be5dca94..2b88cb70a 100644 --- a/clientapi/auth/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -21,8 +21,8 @@ import ( "strconv" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "golang.org/x/crypto/bcrypt" @@ -33,7 +33,7 @@ import ( // Database represents an account database type Database struct { db *sql.DB - internal.PartitionOffsetStatements + sqlutil.PartitionOffsetStatements accounts accountsStatements profiles profilesStatements memberships membershipStatements @@ -44,13 +44,13 @@ type Database struct { } // NewDatabase creates a new accounts and profiles database -func NewDatabase(dataSourceName string, dbProperties internal.DbProperties, serverName gomatrixserverlib.ServerName) (*Database, error) { +func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, serverName gomatrixserverlib.ServerName) (*Database, error) { var db *sql.DB var err error if db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { return nil, err } - partitions := internal.PartitionOffsetStatements{} + partitions := sqlutil.PartitionOffsetStatements{} if err = partitions.Prepare(db, "account"); err != nil { return nil, err } @@ -85,7 +85,7 @@ func NewDatabase(dataSourceName string, dbProperties internal.DbProperties, serv // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByPassword( ctx context.Context, localpart, plaintextPassword string, -) (*authtypes.Account, error) { +) (*api.Account, error) { hash, err := d.accounts.selectPasswordHash(ctx, localpart) if err != nil { return nil, err @@ -122,8 +122,8 @@ func (d *Database) SetDisplayName( // CreateGuestAccount makes a new guest account and creates an empty profile // for this account. -func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Account, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { +func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { var numLocalpart int64 numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) if err != nil { @@ -138,11 +138,11 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Accou // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the -// account already exists, it will return nil, ErrUserExists. +// account already exists, it will return nil, sqlutil.ErrUserExists. func (d *Database) CreateAccount( ctx context.Context, localpart, plaintextPassword, appserviceID string, -) (acc *authtypes.Account, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { +) (acc *api.Account, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) return err }) @@ -151,7 +151,7 @@ func (d *Database) CreateAccount( func (d *Database) createAccount( ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, -) (*authtypes.Account, error) { +) (*api.Account, error) { var err error // Generate a password hash if this is not a password-less user @@ -163,8 +163,8 @@ func (d *Database) createAccount( } } if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil { - if internal.IsUniqueConstraintViolationErr(err) { - return nil, internal.ErrUserExists + if sqlutil.IsUniqueConstraintViolationErr(err) { + return nil, sqlutil.ErrUserExists } return nil, err } @@ -210,7 +210,7 @@ func (d *Database) removeMembershipsByEventIDs( func (d *Database) UpdateMemberships( ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string, ) error { - return internal.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil { return err } @@ -297,7 +297,7 @@ func (d *Database) newMembership( func (d *Database) SaveAccountData( ctx context.Context, localpart, roomID, dataType, content string, ) error { - return internal.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) }) } @@ -348,7 +348,7 @@ var Err3PIDInUse = errors.New("This third-party identifier is already in use") func (d *Database) SaveThreePIDAssociation( ctx context.Context, threepid, localpart, medium string, ) (err error) { - return internal.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { user, err := d.threepids.selectLocalpartForThreePID( ctx, txn, threepid, medium, ) @@ -428,6 +428,6 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin // This function assumes the request is authenticated or the account data is used only internally. // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, -) (*authtypes.Account, error) { +) (*api.Account, error) { return d.accounts.selectAccountByLocalpart(ctx, localpart) } diff --git a/clientapi/auth/storage/accounts/postgres/threepid_table.go b/userapi/storage/accounts/postgres/threepid_table.go similarity index 95% rename from clientapi/auth/storage/accounts/postgres/threepid_table.go rename to userapi/storage/accounts/postgres/threepid_table.go index 0b12b5c5b..7de96350c 100644 --- a/clientapi/auth/storage/accounts/postgres/threepid_table.go +++ b/userapi/storage/accounts/postgres/threepid_table.go @@ -18,7 +18,7 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -82,7 +82,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) { func (s *threepidStatements) selectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, ) (localpart string, err error) { - stmt := internal.TxStmt(txn, s.selectLocalpartForThreePIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) if err == sql.ErrNoRows { return "", nil @@ -117,7 +117,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( func (s *threepidStatements) insertThreePID( ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ) (err error) { - stmt := internal.TxStmt(txn, s.insertThreePIDStmt) + stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) _, err = stmt.ExecContext(ctx, threepid, medium, localpart) return } diff --git a/clientapi/auth/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go similarity index 100% rename from clientapi/auth/storage/accounts/sqlite3/account_data_table.go rename to userapi/storage/accounts/sqlite3/account_data_table.go diff --git a/clientapi/auth/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go similarity index 96% rename from clientapi/auth/storage/accounts/sqlite3/accounts_table.go rename to userapi/storage/accounts/sqlite3/accounts_table.go index fd6a09cde..768f536dd 100644 --- a/clientapi/auth/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -19,8 +19,8 @@ import ( "database/sql" "time" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" @@ -90,7 +90,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // on success. func (s *accountsStatements) insertAccount( ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, -) (*authtypes.Account, error) { +) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := s.insertAccountStmt @@ -104,7 +104,7 @@ func (s *accountsStatements) insertAccount( return nil, err } - return &authtypes.Account{ + return &api.Account{ Localpart: localpart, UserID: userutil.MakeUserID(localpart, s.serverName), ServerName: s.serverName, @@ -121,9 +121,9 @@ func (s *accountsStatements) selectPasswordHash( func (s *accountsStatements) selectAccountByLocalpart( ctx context.Context, localpart string, -) (*authtypes.Account, error) { +) (*api.Account, error) { var appserviceIDPtr sql.NullString - var acc authtypes.Account + var acc api.Account stmt := s.selectAccountByLocalpartStmt err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) diff --git a/clientapi/auth/storage/accounts/sqlite3/constraint.go b/userapi/storage/accounts/sqlite3/constraint.go similarity index 100% rename from clientapi/auth/storage/accounts/sqlite3/constraint.go rename to userapi/storage/accounts/sqlite3/constraint.go diff --git a/clientapi/auth/storage/accounts/sqlite3/constraint_wasm.go b/userapi/storage/accounts/sqlite3/constraint_wasm.go similarity index 100% rename from clientapi/auth/storage/accounts/sqlite3/constraint_wasm.go rename to userapi/storage/accounts/sqlite3/constraint_wasm.go diff --git a/clientapi/auth/storage/accounts/sqlite3/filter_table.go b/userapi/storage/accounts/sqlite3/filter_table.go similarity index 100% rename from clientapi/auth/storage/accounts/sqlite3/filter_table.go rename to userapi/storage/accounts/sqlite3/filter_table.go diff --git a/clientapi/auth/storage/accounts/sqlite3/membership_table.go b/userapi/storage/accounts/sqlite3/membership_table.go similarity index 98% rename from clientapi/auth/storage/accounts/sqlite3/membership_table.go rename to userapi/storage/accounts/sqlite3/membership_table.go index 90e16fb01..67958f27d 100644 --- a/clientapi/auth/storage/accounts/sqlite3/membership_table.go +++ b/userapi/storage/accounts/sqlite3/membership_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" ) const membershipSchema = ` @@ -95,7 +96,7 @@ func (s *membershipStatements) insertMembership( func (s *membershipStatements) deleteMembershipsByEventIDs( ctx context.Context, txn *sql.Tx, eventIDs []string, ) (err error) { - sqlStr := strings.Replace(deleteMembershipsByEventIDsSQL, "($1)", internal.QueryVariadic(len(eventIDs)), 1) + sqlStr := strings.Replace(deleteMembershipsByEventIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1) iEventIDs := make([]interface{}, len(eventIDs)) for i, e := range eventIDs { iEventIDs[i] = e diff --git a/clientapi/auth/storage/accounts/sqlite3/profile_table.go b/userapi/storage/accounts/sqlite3/profile_table.go similarity index 100% rename from clientapi/auth/storage/accounts/sqlite3/profile_table.go rename to userapi/storage/accounts/sqlite3/profile_table.go diff --git a/clientapi/auth/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go similarity index 95% rename from clientapi/auth/storage/accounts/sqlite3/storage.go rename to userapi/storage/accounts/sqlite3/storage.go index 31426e471..4dd755a70 100644 --- a/clientapi/auth/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -22,8 +22,8 @@ import ( "sync" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "golang.org/x/crypto/bcrypt" // Import the sqlite3 database driver. @@ -32,7 +32,7 @@ import ( // Database represents an account database type Database struct { db *sql.DB - internal.PartitionOffsetStatements + sqlutil.PartitionOffsetStatements accounts accountsStatements profiles profilesStatements memberships membershipStatements @@ -52,10 +52,10 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) if err != nil { return nil, err } - if db, err = sqlutil.Open(internal.SQLiteDriverName(), cs, nil); err != nil { + if db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil); err != nil { return nil, err } - partitions := internal.PartitionOffsetStatements{} + partitions := sqlutil.PartitionOffsetStatements{} if err = partitions.Prepare(db, "account"); err != nil { return nil, err } @@ -90,7 +90,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByPassword( ctx context.Context, localpart, plaintextPassword string, -) (*authtypes.Account, error) { +) (*api.Account, error) { hash, err := d.accounts.selectPasswordHash(ctx, localpart) if err != nil { return nil, err @@ -127,8 +127,8 @@ func (d *Database) SetDisplayName( // CreateGuestAccount makes a new guest account and creates an empty profile // for this account. -func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Account, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { +func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { // We need to lock so we sequentially create numeric localparts. If we don't, two calls to // this function will cause the same number to be selected and one will fail with 'database is locked' // when the first txn upgrades to a write txn. @@ -153,8 +153,8 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Accou // account already exists, it will return nil, ErrUserExists. func (d *Database) CreateAccount( ctx context.Context, localpart, plaintextPassword, appserviceID string, -) (acc *authtypes.Account, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { +) (acc *api.Account, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) return err }) @@ -163,7 +163,7 @@ func (d *Database) CreateAccount( func (d *Database) createAccount( ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, -) (*authtypes.Account, error) { +) (*api.Account, error) { var err error // Generate a password hash if this is not a password-less user hash := "" @@ -175,7 +175,7 @@ func (d *Database) createAccount( } if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil { if isConstraintError(err) { - return nil, internal.ErrUserExists + return nil, sqlutil.ErrUserExists } return nil, err } @@ -221,7 +221,7 @@ func (d *Database) removeMembershipsByEventIDs( func (d *Database) UpdateMemberships( ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string, ) error { - return internal.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil { return err } @@ -308,7 +308,7 @@ func (d *Database) newMembership( func (d *Database) SaveAccountData( ctx context.Context, localpart, roomID, dataType, content string, ) error { - return internal.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) }) } @@ -359,7 +359,7 @@ var Err3PIDInUse = errors.New("This third-party identifier is already in use") func (d *Database) SaveThreePIDAssociation( ctx context.Context, threepid, localpart, medium string, ) (err error) { - return internal.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { user, err := d.threepids.selectLocalpartForThreePID( ctx, txn, threepid, medium, ) @@ -439,6 +439,6 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin // This function assumes the request is authenticated or the account data is used only internally. // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, -) (*authtypes.Account, error) { +) (*api.Account, error) { return d.accounts.selectAccountByLocalpart(ctx, localpart) } diff --git a/clientapi/auth/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/accounts/sqlite3/threepid_table.go similarity index 95% rename from clientapi/auth/storage/accounts/sqlite3/threepid_table.go rename to userapi/storage/accounts/sqlite3/threepid_table.go index f78a5ffcd..0200dee7f 100644 --- a/clientapi/auth/storage/accounts/sqlite3/threepid_table.go +++ b/userapi/storage/accounts/sqlite3/threepid_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -82,7 +83,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) { func (s *threepidStatements) selectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, ) (localpart string, err error) { - stmt := internal.TxStmt(txn, s.selectLocalpartForThreePIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) if err == sql.ErrNoRows { return "", nil @@ -117,7 +118,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( func (s *threepidStatements) insertThreePID( ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ) (err error) { - stmt := internal.TxStmt(txn, s.insertThreePIDStmt) + stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) _, err = stmt.ExecContext(ctx, threepid, medium, localpart) return } diff --git a/clientapi/auth/storage/accounts/storage.go b/userapi/storage/accounts/storage.go similarity index 79% rename from clientapi/auth/storage/accounts/storage.go rename to userapi/storage/accounts/storage.go index 37126b30b..87f626bf9 100644 --- a/clientapi/auth/storage/accounts/storage.go +++ b/userapi/storage/accounts/storage.go @@ -19,15 +19,15 @@ package accounts import ( "net/url" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/postgres" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/sqlite3" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres" + "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3" "github.com/matrix-org/gomatrixserverlib" ) // NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // and sets postgres connection parameters -func NewDatabase(dataSourceName string, dbProperties internal.DbProperties, serverName gomatrixserverlib.ServerName) (Database, error) { +func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, serverName gomatrixserverlib.ServerName) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { return postgres.NewDatabase(dataSourceName, dbProperties, serverName) diff --git a/clientapi/auth/storage/accounts/storage_wasm.go b/userapi/storage/accounts/storage_wasm.go similarity index 87% rename from clientapi/auth/storage/accounts/storage_wasm.go rename to userapi/storage/accounts/storage_wasm.go index 81e77cf79..692567059 100644 --- a/clientapi/auth/storage/accounts/storage_wasm.go +++ b/userapi/storage/accounts/storage_wasm.go @@ -18,14 +18,14 @@ import ( "fmt" "net/url" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/sqlite3" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3" "github.com/matrix-org/gomatrixserverlib" ) func NewDatabase( dataSourceName string, - dbProperties internal.DbProperties, // nolint:unparam + dbProperties sqlutil.DbProperties, // nolint:unparam serverName gomatrixserverlib.ServerName, ) (Database, error) { uri, err := url.Parse(dataSourceName) diff --git a/clientapi/auth/storage/devices/interface.go b/userapi/storage/devices/interface.go similarity index 64% rename from clientapi/auth/storage/devices/interface.go rename to userapi/storage/devices/interface.go index 95291e4a7..4bdb57850 100644 --- a/clientapi/auth/storage/devices/interface.go +++ b/userapi/storage/devices/interface.go @@ -17,14 +17,20 @@ package devices import ( "context" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/userapi/api" ) type Database interface { - GetDeviceByAccessToken(ctx context.Context, token string) (*authtypes.Device, error) - GetDeviceByID(ctx context.Context, localpart, deviceID string) (*authtypes.Device, error) - GetDevicesByLocalpart(ctx context.Context, localpart string) ([]authtypes.Device, error) - CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string) (dev *authtypes.Device, returnErr error) + GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) + GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) + GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error) + // CreateDevice makes a new device associated with the given user ID localpart. + // If there is already a device with the same device ID for this user, that access token will be revoked + // and replaced with the given accessToken. If the given accessToken is already in use for another device, + // an error will be returned. + // If no device ID is given one is generated. + // Returns the device on success. + CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string) (dev *api.Device, returnErr error) UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error RemoveDevice(ctx context.Context, deviceID, localpart string) error RemoveDevices(ctx context.Context, localpart string, devices []string) error diff --git a/clientapi/auth/storage/devices/postgres/devices_table.go b/userapi/storage/devices/postgres/devices_table.go similarity index 93% rename from clientapi/auth/storage/devices/postgres/devices_table.go rename to userapi/storage/devices/postgres/devices_table.go index c84e83d3d..1d036d1b3 100644 --- a/clientapi/auth/storage/devices/postgres/devices_table.go +++ b/userapi/storage/devices/postgres/devices_table.go @@ -20,9 +20,10 @@ import ( "time" "github.com/lib/pq" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" ) @@ -134,14 +135,14 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN func (s *devicesStatements) insertDevice( ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, -) (*authtypes.Device, error) { +) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 var sessionID int64 - stmt := internal.TxStmt(txn, s.insertDeviceStmt) + stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName).Scan(&sessionID); err != nil { return nil, err } - return &authtypes.Device{ + return &api.Device{ ID: id, UserID: userutil.MakeUserID(localpart, s.serverName), AccessToken: accessToken, @@ -153,7 +154,7 @@ func (s *devicesStatements) insertDevice( func (s *devicesStatements) deleteDevice( ctx context.Context, txn *sql.Tx, id, localpart string, ) error { - stmt := internal.TxStmt(txn, s.deleteDeviceStmt) + stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) _, err := stmt.ExecContext(ctx, id, localpart) return err } @@ -163,7 +164,7 @@ func (s *devicesStatements) deleteDevice( func (s *devicesStatements) deleteDevices( ctx context.Context, txn *sql.Tx, localpart string, devices []string, ) error { - stmt := internal.TxStmt(txn, s.deleteDevicesStmt) + stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt) _, err := stmt.ExecContext(ctx, localpart, pq.Array(devices)) return err } @@ -173,7 +174,7 @@ func (s *devicesStatements) deleteDevices( func (s *devicesStatements) deleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart string, ) error { - stmt := internal.TxStmt(txn, s.deleteDevicesByLocalpartStmt) + stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) _, err := stmt.ExecContext(ctx, localpart) return err } @@ -181,15 +182,15 @@ func (s *devicesStatements) deleteDevicesByLocalpart( func (s *devicesStatements) updateDeviceName( ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ) error { - stmt := internal.TxStmt(txn, s.updateDeviceNameStmt) + stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) return err } func (s *devicesStatements) selectDeviceByToken( ctx context.Context, accessToken string, -) (*authtypes.Device, error) { - var dev authtypes.Device +) (*api.Device, error) { + var dev api.Device var localpart string stmt := s.selectDeviceByTokenStmt err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart) @@ -204,8 +205,8 @@ func (s *devicesStatements) selectDeviceByToken( // localpart and deviceID func (s *devicesStatements) selectDeviceByID( ctx context.Context, localpart, deviceID string, -) (*authtypes.Device, error) { - var dev authtypes.Device +) (*api.Device, error) { + var dev api.Device stmt := s.selectDeviceByIDStmt err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName) if err == nil { @@ -217,8 +218,8 @@ func (s *devicesStatements) selectDeviceByID( func (s *devicesStatements) selectDevicesByLocalpart( ctx context.Context, localpart string, -) ([]authtypes.Device, error) { - devices := []authtypes.Device{} +) ([]api.Device, error) { + devices := []api.Device{} rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart) @@ -228,7 +229,7 @@ func (s *devicesStatements) selectDevicesByLocalpart( defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByLocalpart: rows.close() failed") for rows.Next() { - var dev authtypes.Device + var dev api.Device var id, displayname sql.NullString err = rows.Scan(&id, &displayname) if err != nil { diff --git a/clientapi/auth/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go similarity index 87% rename from clientapi/auth/storage/devices/postgres/storage.go rename to userapi/storage/devices/postgres/storage.go index e54e7c5d7..801657bd5 100644 --- a/clientapi/auth/storage/devices/postgres/storage.go +++ b/userapi/storage/devices/postgres/storage.go @@ -20,9 +20,8 @@ import ( "database/sql" "encoding/base64" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" ) @@ -36,7 +35,7 @@ type Database struct { } // NewDatabase creates a new device database -func NewDatabase(dataSourceName string, dbProperties internal.DbProperties, serverName gomatrixserverlib.ServerName) (*Database, error) { +func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, serverName gomatrixserverlib.ServerName) (*Database, error) { var db *sql.DB var err error if db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { @@ -53,7 +52,7 @@ func NewDatabase(dataSourceName string, dbProperties internal.DbProperties, serv // Returns sql.ErrNoRows if no matching device was found. func (d *Database) GetDeviceByAccessToken( ctx context.Context, token string, -) (*authtypes.Device, error) { +) (*api.Device, error) { return d.devices.selectDeviceByToken(ctx, token) } @@ -61,14 +60,14 @@ func (d *Database) GetDeviceByAccessToken( // Returns sql.ErrNoRows if no matching device was found. func (d *Database) GetDeviceByID( ctx context.Context, localpart, deviceID string, -) (*authtypes.Device, error) { +) (*api.Device, error) { return d.devices.selectDeviceByID(ctx, localpart, deviceID) } // GetDevicesByLocalpart returns the devices matching the given localpart. func (d *Database) GetDevicesByLocalpart( ctx context.Context, localpart string, -) ([]authtypes.Device, error) { +) ([]api.Device, error) { return d.devices.selectDevicesByLocalpart(ctx, localpart) } @@ -81,9 +80,9 @@ func (d *Database) GetDevicesByLocalpart( func (d *Database) CreateDevice( ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, -) (dev *authtypes.Device, returnErr error) { +) (dev *api.Device, returnErr error) { if deviceID != nil { - returnErr = internal.WithTransaction(d.db, func(txn *sql.Tx) error { + returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { var err error // Revoke existing tokens for this device if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { @@ -103,7 +102,7 @@ func (d *Database) CreateDevice( return } - returnErr = internal.WithTransaction(d.db, func(txn *sql.Tx) error { + returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { var err error dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName) return err @@ -133,7 +132,7 @@ func generateDeviceID() (string, error) { func (d *Database) UpdateDevice( ctx context.Context, localpart, deviceID string, displayName *string, ) error { - return internal.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) }) } @@ -145,7 +144,7 @@ func (d *Database) UpdateDevice( func (d *Database) RemoveDevice( ctx context.Context, deviceID, localpart string, ) error { - return internal.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { return err } @@ -160,7 +159,7 @@ func (d *Database) RemoveDevice( func (d *Database) RemoveDevices( ctx context.Context, localpart string, devices []string, ) error { - return internal.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { return err } @@ -174,7 +173,7 @@ func (d *Database) RemoveDevices( func (d *Database) RemoveAllDevices( ctx context.Context, localpart string, ) error { - return internal.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { return err } diff --git a/clientapi/auth/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go similarity index 89% rename from clientapi/auth/storage/devices/sqlite3/devices_table.go rename to userapi/storage/devices/sqlite3/devices_table.go index 49f3eaed7..07ea5dca3 100644 --- a/clientapi/auth/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -20,9 +20,9 @@ import ( "strings" "time" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/gomatrixserverlib" ) @@ -125,11 +125,11 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN func (s *devicesStatements) insertDevice( ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, -) (*authtypes.Device, error) { +) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 var sessionID int64 - countStmt := internal.TxStmt(txn, s.selectDevicesCountStmt) - insertStmt := internal.TxStmt(txn, s.insertDeviceStmt) + countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt) + insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil { return nil, err } @@ -137,7 +137,7 @@ func (s *devicesStatements) insertDevice( if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil { return nil, err } - return &authtypes.Device{ + return &api.Device{ ID: id, UserID: userutil.MakeUserID(localpart, s.serverName), AccessToken: accessToken, @@ -148,7 +148,7 @@ func (s *devicesStatements) insertDevice( func (s *devicesStatements) deleteDevice( ctx context.Context, txn *sql.Tx, id, localpart string, ) error { - stmt := internal.TxStmt(txn, s.deleteDeviceStmt) + stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) _, err := stmt.ExecContext(ctx, id, localpart) return err } @@ -156,12 +156,12 @@ func (s *devicesStatements) deleteDevice( func (s *devicesStatements) deleteDevices( ctx context.Context, txn *sql.Tx, localpart string, devices []string, ) error { - orig := strings.Replace(deleteDevicesSQL, "($1)", internal.QueryVariadic(len(devices)), 1) + orig := strings.Replace(deleteDevicesSQL, "($1)", sqlutil.QueryVariadic(len(devices)), 1) prep, err := s.db.Prepare(orig) if err != nil { return err } - stmt := internal.TxStmt(txn, prep) + stmt := sqlutil.TxStmt(txn, prep) params := make([]interface{}, len(devices)+1) params[0] = localpart for i, v := range devices { @@ -175,7 +175,7 @@ func (s *devicesStatements) deleteDevices( func (s *devicesStatements) deleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart string, ) error { - stmt := internal.TxStmt(txn, s.deleteDevicesByLocalpartStmt) + stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) _, err := stmt.ExecContext(ctx, localpart) return err } @@ -183,15 +183,15 @@ func (s *devicesStatements) deleteDevicesByLocalpart( func (s *devicesStatements) updateDeviceName( ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ) error { - stmt := internal.TxStmt(txn, s.updateDeviceNameStmt) + stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) return err } func (s *devicesStatements) selectDeviceByToken( ctx context.Context, accessToken string, -) (*authtypes.Device, error) { - var dev authtypes.Device +) (*api.Device, error) { + var dev api.Device var localpart string stmt := s.selectDeviceByTokenStmt err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart) @@ -206,8 +206,8 @@ func (s *devicesStatements) selectDeviceByToken( // localpart and deviceID func (s *devicesStatements) selectDeviceByID( ctx context.Context, localpart, deviceID string, -) (*authtypes.Device, error) { - var dev authtypes.Device +) (*api.Device, error) { + var dev api.Device stmt := s.selectDeviceByIDStmt err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName) if err == nil { @@ -219,8 +219,8 @@ func (s *devicesStatements) selectDeviceByID( func (s *devicesStatements) selectDevicesByLocalpart( ctx context.Context, localpart string, -) ([]authtypes.Device, error) { - devices := []authtypes.Device{} +) ([]api.Device, error) { + devices := []api.Device{} rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart) @@ -229,7 +229,7 @@ func (s *devicesStatements) selectDevicesByLocalpart( } for rows.Next() { - var dev authtypes.Device + var dev api.Device var id, displayname sql.NullString err = rows.Scan(&id, &displayname) if err != nil { diff --git a/clientapi/auth/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go similarity index 88% rename from clientapi/auth/storage/devices/sqlite3/storage.go rename to userapi/storage/devices/sqlite3/storage.go index e05a53b46..f248abda4 100644 --- a/clientapi/auth/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -20,9 +20,8 @@ import ( "database/sql" "encoding/base64" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" _ "github.com/mattn/go-sqlite3" @@ -45,7 +44,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) if err != nil { return nil, err } - if db, err = sqlutil.Open(internal.SQLiteDriverName(), cs, nil); err != nil { + if db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil); err != nil { return nil, err } d := devicesStatements{} @@ -59,7 +58,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) // Returns sql.ErrNoRows if no matching device was found. func (d *Database) GetDeviceByAccessToken( ctx context.Context, token string, -) (*authtypes.Device, error) { +) (*api.Device, error) { return d.devices.selectDeviceByToken(ctx, token) } @@ -67,14 +66,14 @@ func (d *Database) GetDeviceByAccessToken( // Returns sql.ErrNoRows if no matching device was found. func (d *Database) GetDeviceByID( ctx context.Context, localpart, deviceID string, -) (*authtypes.Device, error) { +) (*api.Device, error) { return d.devices.selectDeviceByID(ctx, localpart, deviceID) } // GetDevicesByLocalpart returns the devices matching the given localpart. func (d *Database) GetDevicesByLocalpart( ctx context.Context, localpart string, -) ([]authtypes.Device, error) { +) ([]api.Device, error) { return d.devices.selectDevicesByLocalpart(ctx, localpart) } @@ -87,9 +86,9 @@ func (d *Database) GetDevicesByLocalpart( func (d *Database) CreateDevice( ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, -) (dev *authtypes.Device, returnErr error) { +) (dev *api.Device, returnErr error) { if deviceID != nil { - returnErr = internal.WithTransaction(d.db, func(txn *sql.Tx) error { + returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { var err error // Revoke existing tokens for this device if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { @@ -109,7 +108,7 @@ func (d *Database) CreateDevice( return } - returnErr = internal.WithTransaction(d.db, func(txn *sql.Tx) error { + returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { var err error dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName) return err @@ -139,7 +138,7 @@ func generateDeviceID() (string, error) { func (d *Database) UpdateDevice( ctx context.Context, localpart, deviceID string, displayName *string, ) error { - return internal.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) }) } @@ -151,7 +150,7 @@ func (d *Database) UpdateDevice( func (d *Database) RemoveDevice( ctx context.Context, deviceID, localpart string, ) error { - return internal.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { return err } @@ -166,7 +165,7 @@ func (d *Database) RemoveDevice( func (d *Database) RemoveDevices( ctx context.Context, localpart string, devices []string, ) error { - return internal.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { return err } @@ -180,7 +179,7 @@ func (d *Database) RemoveDevices( func (d *Database) RemoveAllDevices( ctx context.Context, localpart string, ) error { - return internal.WithTransaction(d.db, func(txn *sql.Tx) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { return err } diff --git a/clientapi/auth/storage/devices/storage.go b/userapi/storage/devices/storage.go similarity index 79% rename from clientapi/auth/storage/devices/storage.go rename to userapi/storage/devices/storage.go index c132598bd..e094d202a 100644 --- a/clientapi/auth/storage/devices/storage.go +++ b/userapi/storage/devices/storage.go @@ -19,15 +19,15 @@ package devices import ( "net/url" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices/postgres" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices/sqlite3" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/devices/postgres" + "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3" "github.com/matrix-org/gomatrixserverlib" ) // NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // and sets postgres connection parameters -func NewDatabase(dataSourceName string, dbProperties internal.DbProperties, serverName gomatrixserverlib.ServerName) (Database, error) { +func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, serverName gomatrixserverlib.ServerName) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { return postgres.NewDatabase(dataSourceName, dbProperties, serverName) diff --git a/clientapi/auth/storage/devices/storage_wasm.go b/userapi/storage/devices/storage_wasm.go similarity index 87% rename from clientapi/auth/storage/devices/storage_wasm.go rename to userapi/storage/devices/storage_wasm.go index 14c19e74b..a5a515eff 100644 --- a/clientapi/auth/storage/devices/storage_wasm.go +++ b/userapi/storage/devices/storage_wasm.go @@ -18,14 +18,14 @@ import ( "fmt" "net/url" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices/sqlite3" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3" "github.com/matrix-org/gomatrixserverlib" ) func NewDatabase( dataSourceName string, - dbProperties internal.DbProperties, // nolint:unparam + dbProperties sqlutil.DbProperties, // nolint:unparam serverName gomatrixserverlib.ServerName, ) (Database, error) { uri, err := url.Parse(dataSourceName) diff --git a/userapi/userapi.go b/userapi/userapi.go new file mode 100644 index 000000000..7aadec06a --- /dev/null +++ b/userapi/userapi.go @@ -0,0 +1,45 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package userapi + +import ( + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/internal" + "github.com/matrix-org/dendrite/userapi/inthttp" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" + "github.com/matrix-org/gomatrixserverlib" +) + +// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions +// on the given input API. +func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { + inthttp.AddRoutes(router, intAPI) +} + +// NewInternalAPI returns a concerete implementation of the internal API. Callers +// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. +func NewInternalAPI(accountDB accounts.Database, deviceDB devices.Database, + serverName gomatrixserverlib.ServerName, appServices []config.ApplicationService) api.UserInternalAPI { + + return &internal.UserInternalAPI{ + AccountDB: accountDB, + DeviceDB: deviceDB, + ServerName: serverName, + AppServices: appServices, + } +} diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go new file mode 100644 index 000000000..163b10ec7 --- /dev/null +++ b/userapi/userapi_test.go @@ -0,0 +1,112 @@ +package userapi_test + +import ( + "context" + "fmt" + "net/http" + "reflect" + "testing" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/test" + "github.com/matrix-org/dendrite/userapi" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/inthttp" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" + "github.com/matrix-org/gomatrixserverlib" +) + +const ( + serverName = gomatrixserverlib.ServerName("example.com") +) + +func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database, devices.Database) { + accountDB, err := accounts.NewDatabase("file::memory:", nil, serverName) + if err != nil { + t.Fatalf("failed to create account DB: %s", err) + } + deviceDB, err := devices.NewDatabase("file::memory:", nil, serverName) + if err != nil { + t.Fatalf("failed to create device DB: %s", err) + } + + return userapi.NewInternalAPI(accountDB, deviceDB, serverName, nil), accountDB, deviceDB +} + +func TestQueryProfile(t *testing.T) { + aliceAvatarURL := "mxc://example.com/alice" + aliceDisplayName := "Alice" + userAPI, accountDB, _ := MustMakeInternalAPI(t) + _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "") + if err != nil { + t.Fatalf("failed to make account: %s", err) + } + if err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil { + t.Fatalf("failed to set avatar url: %s", err) + } + if err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil { + t.Fatalf("failed to set display name: %s", err) + } + + testCases := []struct { + req api.QueryProfileRequest + wantRes api.QueryProfileResponse + wantErr error + }{ + { + req: api.QueryProfileRequest{ + UserID: fmt.Sprintf("@alice:%s", serverName), + }, + wantRes: api.QueryProfileResponse{ + UserExists: true, + AvatarURL: aliceAvatarURL, + DisplayName: aliceDisplayName, + }, + }, + { + req: api.QueryProfileRequest{ + UserID: fmt.Sprintf("@bob:%s", serverName), + }, + wantRes: api.QueryProfileResponse{ + UserExists: false, + }, + }, + { + req: api.QueryProfileRequest{ + UserID: "@alice:wrongdomain.com", + }, + wantErr: fmt.Errorf("wrong domain"), + }, + } + + runCases := func(testAPI api.UserInternalAPI) { + for _, tc := range testCases { + var gotRes api.QueryProfileResponse + gotErr := testAPI.QueryProfile(context.TODO(), &tc.req, &gotRes) + if tc.wantErr == nil && gotErr != nil || tc.wantErr != nil && gotErr == nil { + t.Errorf("QueryProfile error, got %s want %s", gotErr, tc.wantErr) + continue + } + if !reflect.DeepEqual(tc.wantRes, gotRes) { + t.Errorf("QueryProfile response got %+v want %+v", gotRes, tc.wantRes) + } + } + } + + t.Run("HTTP API", func(t *testing.T) { + router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() + userapi.AddInternalRoutes(router, userAPI) + apiURL, cancel := test.ListenAndServe(t, router, false) + defer cancel() + httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{}) + if err != nil { + t.Fatalf("failed to create HTTP client") + } + runCases(httpAPI) + }) + t.Run("Monolith", func(t *testing.T) { + runCases(userAPI) + }) +}