diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 53b1be3cb..4a7afc58e 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -30,13 +30,14 @@ import ( userapi "github.com/matrix-org/dendrite/userapi/api" ) +var validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$") + func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { if !cfg.RegistrationRequiresToken { - return util.MatrixErrorResponse( - http.StatusForbidden, - string(spec.ErrorForbidden), - "Registration via tokens is not enabled on this homeserver", - ) + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("Registration via tokens is not enabled on this homeserver"), + } } request := struct { Token string `json:"token"` @@ -46,11 +47,10 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u }{} if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorBadJSON), - "Failed to decode request body:", - ) + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON(fmt.Sprintf("Failed to decode request body: %s", err)), + } } token := request.Token @@ -65,43 +65,43 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u } // token not present in request body. Hence, generate a random token. if !(length > 0 && length <= 64) { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorInvalidParam), - "length must be greater than zero and not greater than 64") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("length must be greater than zero and not greater than 64"), + } } token = generateRandomToken(int(length)) } if len(token) > 64 { //Token present in request body, but is too long. - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorInvalidParam), - "token must not be longer than 64") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("token must not be longer than 64"), + } } - isTokenValid, _ := regexp.MatchString("^[[:ascii:][:digit:]_]*$", token) + isTokenValid := validRegistrationTokenRegex.Match([]byte(token)) if !isTokenValid { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorInvalidParam), - "token must consist only of characters matched by the regex [A-Za-z0-9-_]") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("token must consist only of characters matched by the regex [A-Za-z0-9-_]"), + } } // At this point, we have a valid token, either through request body or through random generation. if usesAllowed < 0 { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorInvalidParam), - "uses_allowed must be a non-negative integer or null") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"), + } } if expiryTime != 0 && expiryTime < time.Now().UnixNano()/int64(time.Millisecond) { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorInvalidParam), - "expiry_time must not be in the past") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("expiry_time must not be in the past"), + } } pending := int32(0) completed := int32(0) @@ -115,17 +115,16 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u } created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), registrationToken) if err != nil { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorUnknown), - err.Error(), - ) + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: err, + } } if !created { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorInvalidParam), - fmt.Sprintf("Token alreaady exists: %s", token)) + return util.JSONResponse{ + Code: http.StatusConflict, + JSON: fmt.Sprintf("Token already exists: %s", token), + } } return util.JSONResponse{ Code: 200, @@ -166,21 +165,19 @@ func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userA returnAll = false validValue, err := strconv.ParseBool(validQuery[0]) if err != nil { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorInvalidParam), - "invalid 'valid' query parameter", - ) + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("invalid 'valid' query parameter"), + } } valid = validValue } tokens, err := userAPI.PerformAdminListRegistrationTokens(req.Context(), returnAll, valid) if err != nil { - return util.MatrixErrorResponse( - http.StatusInternalServerError, - string(spec.ErrorUnknown), - "error fetching registration tokens", - ) + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.ErrorUnknown, + } } return util.JSONResponse{ Code: 200, @@ -205,11 +202,10 @@ func AdminGetRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI tokenText := vars["token"] token, err := userAPI.PerformAdminGetRegistrationToken(req.Context(), tokenText) if err != nil { - return util.MatrixErrorResponse( - http.StatusNotFound, - string(spec.ErrorUnknown), - fmt.Sprintf("token: %s not found", tokenText), - ) + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)), + } } return util.JSONResponse{ Code: 200, @@ -225,11 +221,10 @@ func AdminDeleteRegistrationToken(req *http.Request, cfg *config.ClientAPI, user tokenText := vars["token"] err = userAPI.PerformAdminDeleteRegistrationToken(req.Context(), tokenText) if err != nil { - return util.MatrixErrorResponse( - http.StatusNotFound, - string(spec.ErrorUnknown), - fmt.Sprintf("token: %s not found", tokenText), - ) + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)), + } } return util.JSONResponse{ Code: 200, @@ -244,12 +239,11 @@ func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, user } tokenText := vars["token"] request := make(map[string]interface{}) - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorBadJSON), - "Failed to decode request body:", - ) + if err = json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON(fmt.Sprintf("Failed to decode request body: %s", err)), + } } newAttributes := make(map[string]interface{}) usesAllowed, ok := request["uses_allowed"] @@ -258,11 +252,10 @@ func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, user // Non numeric values in payload will cause panic during type conversion. But this is the best way to mimic // Synapse's behaviour of updating the field if and only if it is present in request body. if !(usesAllowed == nil || int32(usesAllowed.(float64)) >= 0) { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorInvalidParam), - "uses_allowed must be a non-negative integer or null", - ) + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"), + } } newAttributes["usesAllowed"] = usesAllowed } @@ -272,11 +265,10 @@ func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, user // Non numeric values in payload will cause panic during type conversion. But this is the best way to mimic // Synapse's behaviour of updating the field if and only if it is present in request body. if !(expiryTime == nil || int64(expiryTime.(float64)) > time.Now().UnixNano()/int64(time.Millisecond)) { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorInvalidParam), - "expiry_time must be in the future", - ) + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("expiry_time must not be in the past"), + } } newAttributes["expiryTime"] = expiryTime } @@ -286,11 +278,10 @@ func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, user } updatedToken, err := userAPI.PerformAdminUpdateRegistrationToken(req.Context(), tokenText, newAttributes) if err != nil { - return util.MatrixErrorResponse( - http.StatusNotFound, - string(spec.ErrorUnknown), - fmt.Sprintf("token: %s not found", tokenText), - ) + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)), + } } return util.JSONResponse{ Code: 200, diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 79e628f3f..ab4aefddd 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -176,19 +176,20 @@ func Setup( dendriteAdminRouter.Handle("/admin/registrationTokens/{token}", httputil.MakeAdminAPI("admin_get_registration_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - if req.Method == http.MethodGet { + switch req.Method { + case http.MethodGet: return AdminGetRegistrationToken(req, cfg, userAPI) - } else if req.Method == http.MethodPut { + case http.MethodPut: return AdminUpdateRegistrationToken(req, cfg, userAPI) - } else if req.Method == http.MethodDelete { + case http.MethodDelete: return AdminDeleteRegistrationToken(req, cfg, userAPI) + default: + return util.MatrixErrorResponse( + 404, + string(spec.ErrorNotFound), + "unknown method", + ) } - return util.MatrixErrorResponse( - 404, - string(spec.ErrorNotFound), - "unknown method", - ) - }), ).Methods(http.MethodGet, http.MethodPut, http.MethodDelete, http.MethodOptions) @@ -196,7 +197,7 @@ func Setup( httputil.MakeAdminAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminEvacuateRoom(req, rsAPI) }), - ).Methods(http.MethodGet, http.MethodOptions) + ).Methods(http.MethodPost, http.MethodOptions) dendriteAdminRouter.Handle("/admin/evacuateUser/{userID}", httputil.MakeAdminAPI("admin_evacuate_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { diff --git a/cmd/dendrite/main.go b/cmd/dendrite/main.go deleted file mode 100644 index 66eb88f87..000000000 --- a/cmd/dendrite/main.go +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "flag" - "time" - - "github.com/getsentry/sentry-go" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/caching" - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/gomatrixserverlib/fclient" - "github.com/sirupsen/logrus" - - "github.com/matrix-org/dendrite/appservice" - "github.com/matrix-org/dendrite/federationapi" - "github.com/matrix-org/dendrite/roomserver" - "github.com/matrix-org/dendrite/setup" - basepkg "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/mscs" - "github.com/matrix-org/dendrite/userapi" -) - -var ( - unixSocket = flag.String("unix-socket", "", - "EXPERIMENTAL(unstable): The HTTP listening unix socket for the server (disables http[s]-bind-address feature)", - ) - unixSocketPermission = flag.String("unix-socket-permission", "755", - "EXPERIMENTAL(unstable): The HTTP listening unix socket permission for the server (in chmod format like 755)", - ) - httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server") - httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server") - certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS") - keyFile = flag.String("tls-key", "", "The PEM private key to use for TLS") -) - -func main() { - cfg := setup.ParseFlags(true) - httpAddr := config.ServerAddress{} - httpsAddr := config.ServerAddress{} - if *unixSocket == "" { - http, err := config.HTTPAddress("http://" + *httpBindAddr) - if err != nil { - logrus.WithError(err).Fatalf("Failed to parse http address") - } - httpAddr = http - https, err := config.HTTPAddress("https://" + *httpsBindAddr) - if err != nil { - logrus.WithError(err).Fatalf("Failed to parse https address") - } - httpsAddr = https - } else { - socket, err := config.UnixSocketAddress(*unixSocket, *unixSocketPermission) - if err != nil { - logrus.WithError(err).Fatalf("Failed to parse unix socket") - } - httpAddr = socket - } - - configErrors := &config.ConfigErrors{} - cfg.Verify(configErrors) - if len(*configErrors) > 0 { - for _, err := range *configErrors { - logrus.Errorf("Configuration error: %s", err) - } - logrus.Fatalf("Failed to start due to configuration errors") - } - processCtx := process.NewProcessContext() - - internal.SetupStdLogging() - internal.SetupHookLogging(cfg.Logging) - internal.SetupPprof() - - basepkg.PlatformSanityChecks() - - logrus.Infof("Dendrite version %s", internal.VersionString()) - if !cfg.ClientAPI.RegistrationDisabled && cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled { - logrus.Warn("Open registration is enabled") - } - - // create DNS cache - var dnsCache *fclient.DNSCache - if cfg.Global.DNSCache.Enabled { - dnsCache = fclient.NewDNSCache( - cfg.Global.DNSCache.CacheSize, - cfg.Global.DNSCache.CacheLifetime, - ) - logrus.Infof( - "DNS cache enabled (size %d, lifetime %s)", - cfg.Global.DNSCache.CacheSize, - cfg.Global.DNSCache.CacheLifetime, - ) - } - - // setup tracing - closer, err := cfg.SetupTracing() - if err != nil { - logrus.WithError(err).Panicf("failed to start opentracing") - } - defer closer.Close() // nolint: errcheck - - // setup sentry - if cfg.Global.Sentry.Enabled { - logrus.Info("Setting up Sentry for debugging...") - err = sentry.Init(sentry.ClientOptions{ - Dsn: cfg.Global.Sentry.DSN, - Environment: cfg.Global.Sentry.Environment, - Debug: true, - ServerName: string(cfg.Global.ServerName), - Release: "dendrite@" + internal.VersionString(), - AttachStacktrace: true, - }) - if err != nil { - logrus.WithError(err).Panic("failed to start Sentry") - } - go func() { - processCtx.ComponentStarted() - <-processCtx.WaitForShutdown() - if !sentry.Flush(time.Second * 5) { - logrus.Warnf("failed to flush all Sentry events!") - } - processCtx.ComponentFinished() - }() - } - - federationClient := basepkg.CreateFederationClient(cfg, dnsCache) - httpClient := basepkg.CreateClient(cfg, dnsCache) - - // prepare required dependencies - cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) - routers := httputil.NewRouters() - - caches := caching.NewRistrettoCache(cfg.Global.Cache.EstimatedMaxSize, cfg.Global.Cache.MaxAge, caching.EnableMetrics) - natsInstance := jetstream.NATSInstance{} - rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.EnableMetrics) - fsAPI := federationapi.NewInternalAPI( - processCtx, cfg, cm, &natsInstance, federationClient, rsAPI, caches, nil, false, - ) - - keyRing := fsAPI.KeyRing() - - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, federationClient) - asAPI := appservice.NewInternalAPI(processCtx, cfg, &natsInstance, userAPI, rsAPI) - - // The underlying roomserver implementation needs to be able to call the fedsender. - // This is different to rsAPI which can be the http client which doesn't need this - // dependency. Other components also need updating after their dependencies are up. - rsAPI.SetFederationAPI(fsAPI, keyRing) - rsAPI.SetAppserviceAPI(asAPI) - rsAPI.SetUserAPI(userAPI) - - monolith := setup.Monolith{ - Config: cfg, - Client: httpClient, - FedClient: federationClient, - KeyRing: keyRing, - - AppserviceAPI: asAPI, - // always use the concrete impl here even in -http mode because adding public routes - // must be done on the concrete impl not an HTTP client else fedapi will call itself - FederationAPI: fsAPI, - RoomserverAPI: rsAPI, - UserAPI: userAPI, - } - monolith.AddAllPublicRoutes(processCtx, cfg, routers, cm, &natsInstance, caches, caching.EnableMetrics) - - if len(cfg.MSCs.MSCs) > 0 { - if err := mscs.Enable(cfg, cm, routers, &monolith, caches); err != nil { - logrus.WithError(err).Fatalf("Failed to enable MSCs") - } - } - - // Expose the matrix APIs directly rather than putting them under a /api path. - go func() { - basepkg.SetupAndServeHTTP(processCtx, cfg, routers, httpAddr, nil, nil) - }() - // Handle HTTPS if certificate and key are provided - if *unixSocket == "" && *certFile != "" && *keyFile != "" { - go func() { - basepkg.SetupAndServeHTTP(processCtx, cfg, routers, httpsAddr, certFile, keyFile) - }() - } - - // We want to block forever to let the HTTP and HTTPS handler serve the APIs - basepkg.WaitForShutdown(processCtx) -} diff --git a/cmd/dendrite/main_test.go b/cmd/dendrite/main_test.go deleted file mode 100644 index d51bc7434..000000000 --- a/cmd/dendrite/main_test.go +++ /dev/null @@ -1,50 +0,0 @@ -package main - -import ( - "os" - "os/signal" - "strings" - "syscall" - "testing" -) - -// This is an instrumented main, used when running integration tests (sytest) with code coverage. -// Compile: go test -c -race -cover -covermode=atomic -o monolith.debug -coverpkg "github.com/matrix-org/..." ./cmd/dendrite -// Run the monolith: ./monolith.debug -test.coverprofile=/somewhere/to/dump/integrationcover.out DEVEL --config dendrite.yaml -// Generate HTML with coverage: go tool cover -html=/somewhere/where/there/is/integrationcover.out -o cover.html -// Source: https://dzone.com/articles/measuring-integration-test-coverage-rate-in-pouchc -func TestMain(_ *testing.T) { - var ( - args []string - ) - - for _, arg := range os.Args { - switch { - case strings.HasPrefix(arg, "DEVEL"): - case strings.HasPrefix(arg, "-test"): - default: - args = append(args, arg) - } - } - // only run the tests if there are args to be passed - if len(args) <= 1 { - return - } - - waitCh := make(chan int, 1) - os.Args = args - go func() { - main() - close(waitCh) - }() - - signalCh := make(chan os.Signal, 1) - signal.Notify(signalCh, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGHUP) - - select { - case <-signalCh: - return - case <-waitCh: - return - } -} diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index 2cfd649a8..4305c13a9 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -80,19 +80,11 @@ func (a *UserInternalAPI) PerformAdminCreateRegistrationToken(ctx context.Contex } func (a *UserInternalAPI) PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) { - tokens, err := a.DB.ListRegistrationTokens(ctx, returnAll, valid) - if err != nil { - return nil, err - } - return tokens, nil + return a.DB.ListRegistrationTokens(ctx, returnAll, valid) } func (a *UserInternalAPI) PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) { - token, err := a.DB.GetRegistrationToken(ctx, tokenString) - if err != nil { - return nil, err - } - return token, nil + return a.DB.GetRegistrationToken(ctx, tokenString) } func (a *UserInternalAPI) PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error { @@ -100,11 +92,7 @@ func (a *UserInternalAPI) PerformAdminDeleteRegistrationToken(ctx context.Contex } func (a *UserInternalAPI) PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) { - token, err := a.DB.UpdateRegistrationToken(ctx, tokenString, newAttributes) - if err != nil { - return nil, err - } - return token, nil + return a.DB.UpdateRegistrationToken(ctx, tokenString, newAttributes) } func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { diff --git a/userapi/storage/postgres/registration_tokens_table.go b/userapi/storage/postgres/registration_tokens_table.go index 3f85f2093..45b39c892 100644 --- a/userapi/storage/postgres/registration_tokens_table.go +++ b/userapi/storage/postgres/registration_tokens_table.go @@ -3,12 +3,13 @@ package postgres import ( "context" "database/sql" - "fmt" "time" "github.com/matrix-org/dendrite/clientapi/api" + internal "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "golang.org/x/exp/constraints" ) const registrationTokensSchema = ` @@ -89,7 +90,7 @@ func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTa func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) { var existingToken string - stmt := s.selectTokenStatement + stmt := sqlutil.TxStmt(tx, s.selectTokenStatement) err := stmt.QueryRowContext(ctx, token).Scan(&existingToken) if err != nil { if err == sql.ErrNoRows { @@ -105,7 +106,7 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex _, err := stmt.ExecContext( ctx, *registrationToken.Token, - nullIfZeroInt32(*registrationToken.UsesAllowed), + nullIfZero(*registrationToken.UsesAllowed), nullIfZero(*registrationToken.ExpiryTime), *registrationToken.Pending, *registrationToken.Completed) @@ -115,111 +116,82 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex return true, nil } -func nullIfZero(value int64) interface{} { - if value == 0 { +func nullIfZero[t constraints.Integer](in t) any { + if in == 0 { return nil } - return value -} - -func nullIfZeroInt32(value int32) interface{} { - if value == 0 { - return nil - } - return value + return in } func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) { var stmt *sql.Stmt var tokens []api.RegistrationToken - var tokenString sql.NullString - var pending, completed, usesAllowed sql.NullInt32 - var expiryTime sql.NullInt64 + var tokenString string + var pending, completed, usesAllowed *int32 + var expiryTime *int64 var rows *sql.Rows var err error if returnAll { - stmt = s.listAllTokensStatement + stmt = sqlutil.TxStmt(tx, s.listAllTokensStatement) rows, err = stmt.QueryContext(ctx) } else if valid { - stmt = s.listValidTokensStatement + stmt = sqlutil.TxStmt(tx, s.listValidTokensStatement) rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) } else { - stmt = s.listInvalidTokenStatement + stmt = sqlutil.TxStmt(tx, s.listInvalidTokenStatement) rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) } if err != nil { return tokens, err } + defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed") for rows.Next() { err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime) if err != nil { return tokens, err } - tokenString := tokenString.String - pending := pending.Int32 - completed := completed.Int32 - usesAllowed := getReturnValueForInt32(usesAllowed) - expiryTime := getReturnValueForInt64(expiryTime) + tokenString := tokenString + pending := pending + completed := completed + usesAllowed := usesAllowed + expiryTime := expiryTime tokenMap := api.RegistrationToken{ Token: &tokenString, - Pending: &pending, - Completed: &completed, + Pending: pending, + Completed: completed, UsesAllowed: usesAllowed, ExpiryTime: expiryTime, } tokens = append(tokens, tokenMap) } - return tokens, nil -} - -func getReturnValueForInt32(value sql.NullInt32) *int32 { - if value.Valid { - returnValue := value.Int32 - return &returnValue - } - return nil -} - -func getReturnValueForInt64(value sql.NullInt64) *int64 { - if value.Valid { - returnValue := value.Int64 - return &returnValue - } - return nil + return tokens, rows.Err() } func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) { stmt := sqlutil.TxStmt(tx, s.getTokenStatement) - var pending, completed, usesAllowed sql.NullInt32 - var expiryTime sql.NullInt64 + var pending, completed, usesAllowed *int32 + var expiryTime *int64 err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime) if err != nil { return nil, err } token := api.RegistrationToken{ Token: &tokenString, - Pending: &pending.Int32, - Completed: &completed.Int32, - UsesAllowed: getReturnValueForInt32(usesAllowed), - ExpiryTime: getReturnValueForInt64(expiryTime), + Pending: pending, + Completed: completed, + UsesAllowed: usesAllowed, + ExpiryTime: expiryTime, } return &token, nil } func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error { - stmt := s.deleteTokenStatement - res, err := stmt.ExecContext(ctx, tokenString) + stmt := sqlutil.TxStmt(tx, s.deleteTokenStatement) + _, err := stmt.ExecContext(ctx, tokenString) if err != nil { return err } - count, err := res.RowsAffected() - if err != nil { - return err - } - if count == 0 { - return fmt.Errorf("token: %s does not exists", tokenString) - } return nil } diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 481256db1..b7acb2035 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -100,8 +100,12 @@ func (d *Database) GetRegistrationToken(ctx context.Context, tokenString string) return d.RegistrationTokens.GetRegistrationToken(ctx, nil, tokenString) } -func (d *Database) DeleteRegistrationToken(ctx context.Context, tokenString string) error { - return d.RegistrationTokens.DeleteRegistrationToken(ctx, nil, tokenString) +func (d *Database) DeleteRegistrationToken(ctx context.Context, tokenString string) (err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + err = d.RegistrationTokens.DeleteRegistrationToken(ctx, nil, tokenString) + return err + }) + return } func (d *Database) UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (updatedToken *clientapi.RegistrationToken, err error) { diff --git a/userapi/storage/sqlite3/registration_tokens_table.go b/userapi/storage/sqlite3/registration_tokens_table.go index 47b70d2e1..99c18c557 100644 --- a/userapi/storage/sqlite3/registration_tokens_table.go +++ b/userapi/storage/sqlite3/registration_tokens_table.go @@ -3,12 +3,13 @@ package sqlite3 import ( "context" "database/sql" - "fmt" "time" "github.com/matrix-org/dendrite/clientapi/api" + internal "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "golang.org/x/exp/constraints" ) const registrationTokensSchema = ` @@ -89,7 +90,7 @@ func NewSQLiteRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTabl func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) { var existingToken string - stmt := s.selectTokenStatement + stmt := sqlutil.TxStmt(tx, s.selectTokenStatement) err := stmt.QueryRowContext(ctx, token).Scan(&existingToken) if err != nil { if err == sql.ErrNoRows { @@ -105,7 +106,7 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex _, err := stmt.ExecContext( ctx, *registrationToken.Token, - nullIfZeroInt32(*registrationToken.UsesAllowed), + nullIfZero(*registrationToken.UsesAllowed), nullIfZero(*registrationToken.ExpiryTime), *registrationToken.Pending, *registrationToken.Completed) @@ -115,111 +116,82 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex return true, nil } -func nullIfZero(value int64) interface{} { - if value == 0 { +func nullIfZero[t constraints.Integer](in t) any { + if in == 0 { return nil } - return value -} - -func nullIfZeroInt32(value int32) interface{} { - if value == 0 { - return nil - } - return value + return in } func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) { var stmt *sql.Stmt var tokens []api.RegistrationToken - var tokenString sql.NullString - var pending, completed, usesAllowed sql.NullInt32 - var expiryTime sql.NullInt64 + var tokenString string + var pending, completed, usesAllowed *int32 + var expiryTime *int64 var rows *sql.Rows var err error if returnAll { - stmt = s.listAllTokensStatement + stmt = sqlutil.TxStmt(tx, s.listAllTokensStatement) rows, err = stmt.QueryContext(ctx) } else if valid { - stmt = s.listValidTokensStatement + stmt = sqlutil.TxStmt(tx, s.listValidTokensStatement) rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) } else { - stmt = s.listInvalidTokenStatement + stmt = sqlutil.TxStmt(tx, s.listInvalidTokenStatement) rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) } if err != nil { return tokens, err } + defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed") for rows.Next() { err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime) if err != nil { return tokens, err } - tokenString := tokenString.String - pending := pending.Int32 - completed := completed.Int32 - usesAllowed := getReturnValueForInt32(usesAllowed) - expiryTime := getReturnValueForInt64(expiryTime) + tokenString := tokenString + pending := pending + completed := completed + usesAllowed := usesAllowed + expiryTime := expiryTime tokenMap := api.RegistrationToken{ Token: &tokenString, - Pending: &pending, - Completed: &completed, + Pending: pending, + Completed: completed, UsesAllowed: usesAllowed, ExpiryTime: expiryTime, } tokens = append(tokens, tokenMap) } - return tokens, nil -} - -func getReturnValueForInt32(value sql.NullInt32) *int32 { - if value.Valid { - returnValue := value.Int32 - return &returnValue - } - return nil -} - -func getReturnValueForInt64(value sql.NullInt64) *int64 { - if value.Valid { - returnValue := value.Int64 - return &returnValue - } - return nil + return tokens, rows.Err() } func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) { stmt := sqlutil.TxStmt(tx, s.getTokenStatement) - var pending, completed, usesAllowed sql.NullInt32 - var expiryTime sql.NullInt64 + var pending, completed, usesAllowed *int32 + var expiryTime *int64 err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime) if err != nil { return nil, err } token := api.RegistrationToken{ Token: &tokenString, - Pending: &pending.Int32, - Completed: &completed.Int32, - UsesAllowed: getReturnValueForInt32(usesAllowed), - ExpiryTime: getReturnValueForInt64(expiryTime), + Pending: pending, + Completed: completed, + UsesAllowed: usesAllowed, + ExpiryTime: expiryTime, } return &token, nil } func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error { - stmt := s.deleteTokenStatement - res, err := stmt.ExecContext(ctx, tokenString) + stmt := sqlutil.TxStmt(tx, s.deleteTokenStatement) + _, err := stmt.ExecContext(ctx, tokenString) if err != nil { return err } - count, err := res.RowsAffected() - if err != nil { - return err - } - if count == 0 { - return fmt.Errorf("token: %s does not exists", tokenString) - } return nil }