addressed review comments

This commit is contained in:
santhoshivan23 2023-06-14 22:51:56 +05:30
parent 44beddc287
commit 6ea96a0909
8 changed files with 152 additions and 477 deletions

View file

@ -30,13 +30,14 @@ import (
userapi "github.com/matrix-org/dendrite/userapi/api"
)
var validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$")
func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
if !cfg.RegistrationRequiresToken {
return util.MatrixErrorResponse(
http.StatusForbidden,
string(spec.ErrorForbidden),
"Registration via tokens is not enabled on this homeserver",
)
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("Registration via tokens is not enabled on this homeserver"),
}
}
request := struct {
Token string `json:"token"`
@ -46,11 +47,10 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u
}{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MatrixErrorResponse(
http.StatusBadRequest,
string(spec.ErrorBadJSON),
"Failed to decode request body:",
)
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON(fmt.Sprintf("Failed to decode request body: %s", err)),
}
}
token := request.Token
@ -65,43 +65,43 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u
}
// token not present in request body. Hence, generate a random token.
if !(length > 0 && length <= 64) {
return util.MatrixErrorResponse(
http.StatusBadRequest,
string(spec.ErrorInvalidParam),
"length must be greater than zero and not greater than 64")
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("length must be greater than zero and not greater than 64"),
}
}
token = generateRandomToken(int(length))
}
if len(token) > 64 {
//Token present in request body, but is too long.
return util.MatrixErrorResponse(
http.StatusBadRequest,
string(spec.ErrorInvalidParam),
"token must not be longer than 64")
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("token must not be longer than 64"),
}
}
isTokenValid, _ := regexp.MatchString("^[[:ascii:][:digit:]_]*$", token)
isTokenValid := validRegistrationTokenRegex.Match([]byte(token))
if !isTokenValid {
return util.MatrixErrorResponse(
http.StatusBadRequest,
string(spec.ErrorInvalidParam),
"token must consist only of characters matched by the regex [A-Za-z0-9-_]")
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("token must consist only of characters matched by the regex [A-Za-z0-9-_]"),
}
}
// At this point, we have a valid token, either through request body or through random generation.
if usesAllowed < 0 {
return util.MatrixErrorResponse(
http.StatusBadRequest,
string(spec.ErrorInvalidParam),
"uses_allowed must be a non-negative integer or null")
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"),
}
}
if expiryTime != 0 && expiryTime < time.Now().UnixNano()/int64(time.Millisecond) {
return util.MatrixErrorResponse(
http.StatusBadRequest,
string(spec.ErrorInvalidParam),
"expiry_time must not be in the past")
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("expiry_time must not be in the past"),
}
}
pending := int32(0)
completed := int32(0)
@ -115,17 +115,16 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u
}
created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), registrationToken)
if err != nil {
return util.MatrixErrorResponse(
http.StatusBadRequest,
string(spec.ErrorUnknown),
err.Error(),
)
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: err,
}
}
if !created {
return util.MatrixErrorResponse(
http.StatusBadRequest,
string(spec.ErrorInvalidParam),
fmt.Sprintf("Token alreaady exists: %s", token))
return util.JSONResponse{
Code: http.StatusConflict,
JSON: fmt.Sprintf("Token already exists: %s", token),
}
}
return util.JSONResponse{
Code: 200,
@ -166,21 +165,19 @@ func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userA
returnAll = false
validValue, err := strconv.ParseBool(validQuery[0])
if err != nil {
return util.MatrixErrorResponse(
http.StatusBadRequest,
string(spec.ErrorInvalidParam),
"invalid 'valid' query parameter",
)
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("invalid 'valid' query parameter"),
}
}
valid = validValue
}
tokens, err := userAPI.PerformAdminListRegistrationTokens(req.Context(), returnAll, valid)
if err != nil {
return util.MatrixErrorResponse(
http.StatusInternalServerError,
string(spec.ErrorUnknown),
"error fetching registration tokens",
)
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.ErrorUnknown,
}
}
return util.JSONResponse{
Code: 200,
@ -205,11 +202,10 @@ func AdminGetRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI
tokenText := vars["token"]
token, err := userAPI.PerformAdminGetRegistrationToken(req.Context(), tokenText)
if err != nil {
return util.MatrixErrorResponse(
http.StatusNotFound,
string(spec.ErrorUnknown),
fmt.Sprintf("token: %s not found", tokenText),
)
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)),
}
}
return util.JSONResponse{
Code: 200,
@ -225,11 +221,10 @@ func AdminDeleteRegistrationToken(req *http.Request, cfg *config.ClientAPI, user
tokenText := vars["token"]
err = userAPI.PerformAdminDeleteRegistrationToken(req.Context(), tokenText)
if err != nil {
return util.MatrixErrorResponse(
http.StatusNotFound,
string(spec.ErrorUnknown),
fmt.Sprintf("token: %s not found", tokenText),
)
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)),
}
}
return util.JSONResponse{
Code: 200,
@ -244,12 +239,11 @@ func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, user
}
tokenText := vars["token"]
request := make(map[string]interface{})
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MatrixErrorResponse(
http.StatusBadRequest,
string(spec.ErrorBadJSON),
"Failed to decode request body:",
)
if err = json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON(fmt.Sprintf("Failed to decode request body: %s", err)),
}
}
newAttributes := make(map[string]interface{})
usesAllowed, ok := request["uses_allowed"]
@ -258,11 +252,10 @@ func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, user
// Non numeric values in payload will cause panic during type conversion. But this is the best way to mimic
// Synapse's behaviour of updating the field if and only if it is present in request body.
if !(usesAllowed == nil || int32(usesAllowed.(float64)) >= 0) {
return util.MatrixErrorResponse(
http.StatusBadRequest,
string(spec.ErrorInvalidParam),
"uses_allowed must be a non-negative integer or null",
)
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"),
}
}
newAttributes["usesAllowed"] = usesAllowed
}
@ -272,11 +265,10 @@ func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, user
// Non numeric values in payload will cause panic during type conversion. But this is the best way to mimic
// Synapse's behaviour of updating the field if and only if it is present in request body.
if !(expiryTime == nil || int64(expiryTime.(float64)) > time.Now().UnixNano()/int64(time.Millisecond)) {
return util.MatrixErrorResponse(
http.StatusBadRequest,
string(spec.ErrorInvalidParam),
"expiry_time must be in the future",
)
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("expiry_time must not be in the past"),
}
}
newAttributes["expiryTime"] = expiryTime
}
@ -286,11 +278,10 @@ func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, user
}
updatedToken, err := userAPI.PerformAdminUpdateRegistrationToken(req.Context(), tokenText, newAttributes)
if err != nil {
return util.MatrixErrorResponse(
http.StatusNotFound,
string(spec.ErrorUnknown),
fmt.Sprintf("token: %s not found", tokenText),
)
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)),
}
}
return util.JSONResponse{
Code: 200,

View file

@ -176,19 +176,20 @@ func Setup(
dendriteAdminRouter.Handle("/admin/registrationTokens/{token}",
httputil.MakeAdminAPI("admin_get_registration_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if req.Method == http.MethodGet {
switch req.Method {
case http.MethodGet:
return AdminGetRegistrationToken(req, cfg, userAPI)
} else if req.Method == http.MethodPut {
case http.MethodPut:
return AdminUpdateRegistrationToken(req, cfg, userAPI)
} else if req.Method == http.MethodDelete {
case http.MethodDelete:
return AdminDeleteRegistrationToken(req, cfg, userAPI)
default:
return util.MatrixErrorResponse(
404,
string(spec.ErrorNotFound),
"unknown method",
)
}
return util.MatrixErrorResponse(
404,
string(spec.ErrorNotFound),
"unknown method",
)
}),
).Methods(http.MethodGet, http.MethodPut, http.MethodDelete, http.MethodOptions)
@ -196,7 +197,7 @@ func Setup(
httputil.MakeAdminAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminEvacuateRoom(req, rsAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
).Methods(http.MethodPost, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/evacuateUser/{userID}",
httputil.MakeAdminAPI("admin_evacuate_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {

View file

@ -1,203 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"flag"
"time"
"github.com/getsentry/sentry-go"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/appservice"
"github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup"
basepkg "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/mscs"
"github.com/matrix-org/dendrite/userapi"
)
var (
unixSocket = flag.String("unix-socket", "",
"EXPERIMENTAL(unstable): The HTTP listening unix socket for the server (disables http[s]-bind-address feature)",
)
unixSocketPermission = flag.String("unix-socket-permission", "755",
"EXPERIMENTAL(unstable): The HTTP listening unix socket permission for the server (in chmod format like 755)",
)
httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server")
httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server")
certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS")
keyFile = flag.String("tls-key", "", "The PEM private key to use for TLS")
)
func main() {
cfg := setup.ParseFlags(true)
httpAddr := config.ServerAddress{}
httpsAddr := config.ServerAddress{}
if *unixSocket == "" {
http, err := config.HTTPAddress("http://" + *httpBindAddr)
if err != nil {
logrus.WithError(err).Fatalf("Failed to parse http address")
}
httpAddr = http
https, err := config.HTTPAddress("https://" + *httpsBindAddr)
if err != nil {
logrus.WithError(err).Fatalf("Failed to parse https address")
}
httpsAddr = https
} else {
socket, err := config.UnixSocketAddress(*unixSocket, *unixSocketPermission)
if err != nil {
logrus.WithError(err).Fatalf("Failed to parse unix socket")
}
httpAddr = socket
}
configErrors := &config.ConfigErrors{}
cfg.Verify(configErrors)
if len(*configErrors) > 0 {
for _, err := range *configErrors {
logrus.Errorf("Configuration error: %s", err)
}
logrus.Fatalf("Failed to start due to configuration errors")
}
processCtx := process.NewProcessContext()
internal.SetupStdLogging()
internal.SetupHookLogging(cfg.Logging)
internal.SetupPprof()
basepkg.PlatformSanityChecks()
logrus.Infof("Dendrite version %s", internal.VersionString())
if !cfg.ClientAPI.RegistrationDisabled && cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled {
logrus.Warn("Open registration is enabled")
}
// create DNS cache
var dnsCache *fclient.DNSCache
if cfg.Global.DNSCache.Enabled {
dnsCache = fclient.NewDNSCache(
cfg.Global.DNSCache.CacheSize,
cfg.Global.DNSCache.CacheLifetime,
)
logrus.Infof(
"DNS cache enabled (size %d, lifetime %s)",
cfg.Global.DNSCache.CacheSize,
cfg.Global.DNSCache.CacheLifetime,
)
}
// setup tracing
closer, err := cfg.SetupTracing()
if err != nil {
logrus.WithError(err).Panicf("failed to start opentracing")
}
defer closer.Close() // nolint: errcheck
// setup sentry
if cfg.Global.Sentry.Enabled {
logrus.Info("Setting up Sentry for debugging...")
err = sentry.Init(sentry.ClientOptions{
Dsn: cfg.Global.Sentry.DSN,
Environment: cfg.Global.Sentry.Environment,
Debug: true,
ServerName: string(cfg.Global.ServerName),
Release: "dendrite@" + internal.VersionString(),
AttachStacktrace: true,
})
if err != nil {
logrus.WithError(err).Panic("failed to start Sentry")
}
go func() {
processCtx.ComponentStarted()
<-processCtx.WaitForShutdown()
if !sentry.Flush(time.Second * 5) {
logrus.Warnf("failed to flush all Sentry events!")
}
processCtx.ComponentFinished()
}()
}
federationClient := basepkg.CreateFederationClient(cfg, dnsCache)
httpClient := basepkg.CreateClient(cfg, dnsCache)
// prepare required dependencies
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
routers := httputil.NewRouters()
caches := caching.NewRistrettoCache(cfg.Global.Cache.EstimatedMaxSize, cfg.Global.Cache.MaxAge, caching.EnableMetrics)
natsInstance := jetstream.NATSInstance{}
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.EnableMetrics)
fsAPI := federationapi.NewInternalAPI(
processCtx, cfg, cm, &natsInstance, federationClient, rsAPI, caches, nil, false,
)
keyRing := fsAPI.KeyRing()
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, federationClient)
asAPI := appservice.NewInternalAPI(processCtx, cfg, &natsInstance, userAPI, rsAPI)
// The underlying roomserver implementation needs to be able to call the fedsender.
// This is different to rsAPI which can be the http client which doesn't need this
// dependency. Other components also need updating after their dependencies are up.
rsAPI.SetFederationAPI(fsAPI, keyRing)
rsAPI.SetAppserviceAPI(asAPI)
rsAPI.SetUserAPI(userAPI)
monolith := setup.Monolith{
Config: cfg,
Client: httpClient,
FedClient: federationClient,
KeyRing: keyRing,
AppserviceAPI: asAPI,
// always use the concrete impl here even in -http mode because adding public routes
// must be done on the concrete impl not an HTTP client else fedapi will call itself
FederationAPI: fsAPI,
RoomserverAPI: rsAPI,
UserAPI: userAPI,
}
monolith.AddAllPublicRoutes(processCtx, cfg, routers, cm, &natsInstance, caches, caching.EnableMetrics)
if len(cfg.MSCs.MSCs) > 0 {
if err := mscs.Enable(cfg, cm, routers, &monolith, caches); err != nil {
logrus.WithError(err).Fatalf("Failed to enable MSCs")
}
}
// Expose the matrix APIs directly rather than putting them under a /api path.
go func() {
basepkg.SetupAndServeHTTP(processCtx, cfg, routers, httpAddr, nil, nil)
}()
// Handle HTTPS if certificate and key are provided
if *unixSocket == "" && *certFile != "" && *keyFile != "" {
go func() {
basepkg.SetupAndServeHTTP(processCtx, cfg, routers, httpsAddr, certFile, keyFile)
}()
}
// We want to block forever to let the HTTP and HTTPS handler serve the APIs
basepkg.WaitForShutdown(processCtx)
}

View file

@ -1,50 +0,0 @@
package main
import (
"os"
"os/signal"
"strings"
"syscall"
"testing"
)
// This is an instrumented main, used when running integration tests (sytest) with code coverage.
// Compile: go test -c -race -cover -covermode=atomic -o monolith.debug -coverpkg "github.com/matrix-org/..." ./cmd/dendrite
// Run the monolith: ./monolith.debug -test.coverprofile=/somewhere/to/dump/integrationcover.out DEVEL --config dendrite.yaml
// Generate HTML with coverage: go tool cover -html=/somewhere/where/there/is/integrationcover.out -o cover.html
// Source: https://dzone.com/articles/measuring-integration-test-coverage-rate-in-pouchc
func TestMain(_ *testing.T) {
var (
args []string
)
for _, arg := range os.Args {
switch {
case strings.HasPrefix(arg, "DEVEL"):
case strings.HasPrefix(arg, "-test"):
default:
args = append(args, arg)
}
}
// only run the tests if there are args to be passed
if len(args) <= 1 {
return
}
waitCh := make(chan int, 1)
os.Args = args
go func() {
main()
close(waitCh)
}()
signalCh := make(chan os.Signal, 1)
signal.Notify(signalCh, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGHUP)
select {
case <-signalCh:
return
case <-waitCh:
return
}
}

View file

@ -80,19 +80,11 @@ func (a *UserInternalAPI) PerformAdminCreateRegistrationToken(ctx context.Contex
}
func (a *UserInternalAPI) PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) {
tokens, err := a.DB.ListRegistrationTokens(ctx, returnAll, valid)
if err != nil {
return nil, err
}
return tokens, nil
return a.DB.ListRegistrationTokens(ctx, returnAll, valid)
}
func (a *UserInternalAPI) PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) {
token, err := a.DB.GetRegistrationToken(ctx, tokenString)
if err != nil {
return nil, err
}
return token, nil
return a.DB.GetRegistrationToken(ctx, tokenString)
}
func (a *UserInternalAPI) PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error {
@ -100,11 +92,7 @@ func (a *UserInternalAPI) PerformAdminDeleteRegistrationToken(ctx context.Contex
}
func (a *UserInternalAPI) PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) {
token, err := a.DB.UpdateRegistrationToken(ctx, tokenString, newAttributes)
if err != nil {
return nil, err
}
return token, nil
return a.DB.UpdateRegistrationToken(ctx, tokenString, newAttributes)
}
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {

View file

@ -3,12 +3,13 @@ package postgres
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/matrix-org/dendrite/clientapi/api"
internal "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"golang.org/x/exp/constraints"
)
const registrationTokensSchema = `
@ -89,7 +90,7 @@ func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTa
func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) {
var existingToken string
stmt := s.selectTokenStatement
stmt := sqlutil.TxStmt(tx, s.selectTokenStatement)
err := stmt.QueryRowContext(ctx, token).Scan(&existingToken)
if err != nil {
if err == sql.ErrNoRows {
@ -105,7 +106,7 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex
_, err := stmt.ExecContext(
ctx,
*registrationToken.Token,
nullIfZeroInt32(*registrationToken.UsesAllowed),
nullIfZero(*registrationToken.UsesAllowed),
nullIfZero(*registrationToken.ExpiryTime),
*registrationToken.Pending,
*registrationToken.Completed)
@ -115,111 +116,82 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex
return true, nil
}
func nullIfZero(value int64) interface{} {
if value == 0 {
func nullIfZero[t constraints.Integer](in t) any {
if in == 0 {
return nil
}
return value
}
func nullIfZeroInt32(value int32) interface{} {
if value == 0 {
return nil
}
return value
return in
}
func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) {
var stmt *sql.Stmt
var tokens []api.RegistrationToken
var tokenString sql.NullString
var pending, completed, usesAllowed sql.NullInt32
var expiryTime sql.NullInt64
var tokenString string
var pending, completed, usesAllowed *int32
var expiryTime *int64
var rows *sql.Rows
var err error
if returnAll {
stmt = s.listAllTokensStatement
stmt = sqlutil.TxStmt(tx, s.listAllTokensStatement)
rows, err = stmt.QueryContext(ctx)
} else if valid {
stmt = s.listValidTokensStatement
stmt = sqlutil.TxStmt(tx, s.listValidTokensStatement)
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
} else {
stmt = s.listInvalidTokenStatement
stmt = sqlutil.TxStmt(tx, s.listInvalidTokenStatement)
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
}
if err != nil {
return tokens, err
}
defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed")
for rows.Next() {
err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime)
if err != nil {
return tokens, err
}
tokenString := tokenString.String
pending := pending.Int32
completed := completed.Int32
usesAllowed := getReturnValueForInt32(usesAllowed)
expiryTime := getReturnValueForInt64(expiryTime)
tokenString := tokenString
pending := pending
completed := completed
usesAllowed := usesAllowed
expiryTime := expiryTime
tokenMap := api.RegistrationToken{
Token: &tokenString,
Pending: &pending,
Completed: &completed,
Pending: pending,
Completed: completed,
UsesAllowed: usesAllowed,
ExpiryTime: expiryTime,
}
tokens = append(tokens, tokenMap)
}
return tokens, nil
}
func getReturnValueForInt32(value sql.NullInt32) *int32 {
if value.Valid {
returnValue := value.Int32
return &returnValue
}
return nil
}
func getReturnValueForInt64(value sql.NullInt64) *int64 {
if value.Valid {
returnValue := value.Int64
return &returnValue
}
return nil
return tokens, rows.Err()
}
func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) {
stmt := sqlutil.TxStmt(tx, s.getTokenStatement)
var pending, completed, usesAllowed sql.NullInt32
var expiryTime sql.NullInt64
var pending, completed, usesAllowed *int32
var expiryTime *int64
err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime)
if err != nil {
return nil, err
}
token := api.RegistrationToken{
Token: &tokenString,
Pending: &pending.Int32,
Completed: &completed.Int32,
UsesAllowed: getReturnValueForInt32(usesAllowed),
ExpiryTime: getReturnValueForInt64(expiryTime),
Pending: pending,
Completed: completed,
UsesAllowed: usesAllowed,
ExpiryTime: expiryTime,
}
return &token, nil
}
func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error {
stmt := s.deleteTokenStatement
res, err := stmt.ExecContext(ctx, tokenString)
stmt := sqlutil.TxStmt(tx, s.deleteTokenStatement)
_, err := stmt.ExecContext(ctx, tokenString)
if err != nil {
return err
}
count, err := res.RowsAffected()
if err != nil {
return err
}
if count == 0 {
return fmt.Errorf("token: %s does not exists", tokenString)
}
return nil
}

View file

@ -100,8 +100,12 @@ func (d *Database) GetRegistrationToken(ctx context.Context, tokenString string)
return d.RegistrationTokens.GetRegistrationToken(ctx, nil, tokenString)
}
func (d *Database) DeleteRegistrationToken(ctx context.Context, tokenString string) error {
return d.RegistrationTokens.DeleteRegistrationToken(ctx, nil, tokenString)
func (d *Database) DeleteRegistrationToken(ctx context.Context, tokenString string) (err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err = d.RegistrationTokens.DeleteRegistrationToken(ctx, nil, tokenString)
return err
})
return
}
func (d *Database) UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (updatedToken *clientapi.RegistrationToken, err error) {

View file

@ -3,12 +3,13 @@ package sqlite3
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/matrix-org/dendrite/clientapi/api"
internal "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"golang.org/x/exp/constraints"
)
const registrationTokensSchema = `
@ -89,7 +90,7 @@ func NewSQLiteRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTabl
func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) {
var existingToken string
stmt := s.selectTokenStatement
stmt := sqlutil.TxStmt(tx, s.selectTokenStatement)
err := stmt.QueryRowContext(ctx, token).Scan(&existingToken)
if err != nil {
if err == sql.ErrNoRows {
@ -105,7 +106,7 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex
_, err := stmt.ExecContext(
ctx,
*registrationToken.Token,
nullIfZeroInt32(*registrationToken.UsesAllowed),
nullIfZero(*registrationToken.UsesAllowed),
nullIfZero(*registrationToken.ExpiryTime),
*registrationToken.Pending,
*registrationToken.Completed)
@ -115,111 +116,82 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex
return true, nil
}
func nullIfZero(value int64) interface{} {
if value == 0 {
func nullIfZero[t constraints.Integer](in t) any {
if in == 0 {
return nil
}
return value
}
func nullIfZeroInt32(value int32) interface{} {
if value == 0 {
return nil
}
return value
return in
}
func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) {
var stmt *sql.Stmt
var tokens []api.RegistrationToken
var tokenString sql.NullString
var pending, completed, usesAllowed sql.NullInt32
var expiryTime sql.NullInt64
var tokenString string
var pending, completed, usesAllowed *int32
var expiryTime *int64
var rows *sql.Rows
var err error
if returnAll {
stmt = s.listAllTokensStatement
stmt = sqlutil.TxStmt(tx, s.listAllTokensStatement)
rows, err = stmt.QueryContext(ctx)
} else if valid {
stmt = s.listValidTokensStatement
stmt = sqlutil.TxStmt(tx, s.listValidTokensStatement)
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
} else {
stmt = s.listInvalidTokenStatement
stmt = sqlutil.TxStmt(tx, s.listInvalidTokenStatement)
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
}
if err != nil {
return tokens, err
}
defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed")
for rows.Next() {
err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime)
if err != nil {
return tokens, err
}
tokenString := tokenString.String
pending := pending.Int32
completed := completed.Int32
usesAllowed := getReturnValueForInt32(usesAllowed)
expiryTime := getReturnValueForInt64(expiryTime)
tokenString := tokenString
pending := pending
completed := completed
usesAllowed := usesAllowed
expiryTime := expiryTime
tokenMap := api.RegistrationToken{
Token: &tokenString,
Pending: &pending,
Completed: &completed,
Pending: pending,
Completed: completed,
UsesAllowed: usesAllowed,
ExpiryTime: expiryTime,
}
tokens = append(tokens, tokenMap)
}
return tokens, nil
}
func getReturnValueForInt32(value sql.NullInt32) *int32 {
if value.Valid {
returnValue := value.Int32
return &returnValue
}
return nil
}
func getReturnValueForInt64(value sql.NullInt64) *int64 {
if value.Valid {
returnValue := value.Int64
return &returnValue
}
return nil
return tokens, rows.Err()
}
func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) {
stmt := sqlutil.TxStmt(tx, s.getTokenStatement)
var pending, completed, usesAllowed sql.NullInt32
var expiryTime sql.NullInt64
var pending, completed, usesAllowed *int32
var expiryTime *int64
err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime)
if err != nil {
return nil, err
}
token := api.RegistrationToken{
Token: &tokenString,
Pending: &pending.Int32,
Completed: &completed.Int32,
UsesAllowed: getReturnValueForInt32(usesAllowed),
ExpiryTime: getReturnValueForInt64(expiryTime),
Pending: pending,
Completed: completed,
UsesAllowed: usesAllowed,
ExpiryTime: expiryTime,
}
return &token, nil
}
func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error {
stmt := s.deleteTokenStatement
res, err := stmt.ExecContext(ctx, tokenString)
stmt := sqlutil.TxStmt(tx, s.deleteTokenStatement)
_, err := stmt.ExecContext(ctx, tokenString)
if err != nil {
return err
}
count, err := res.RowsAffected()
if err != nil {
return err
}
if count == 0 {
return fmt.Errorf("token: %s does not exists", tokenString)
}
return nil
}