mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-11 16:13:10 -06:00
addressed review comments
This commit is contained in:
parent
44beddc287
commit
6ea96a0909
|
|
@ -30,13 +30,14 @@ import (
|
|||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
||||
var validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$")
|
||||
|
||||
func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
||||
if !cfg.RegistrationRequiresToken {
|
||||
return util.MatrixErrorResponse(
|
||||
http.StatusForbidden,
|
||||
string(spec.ErrorForbidden),
|
||||
"Registration via tokens is not enabled on this homeserver",
|
||||
)
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("Registration via tokens is not enabled on this homeserver"),
|
||||
}
|
||||
}
|
||||
request := struct {
|
||||
Token string `json:"token"`
|
||||
|
|
@ -46,11 +47,10 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u
|
|||
}{}
|
||||
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MatrixErrorResponse(
|
||||
http.StatusBadRequest,
|
||||
string(spec.ErrorBadJSON),
|
||||
"Failed to decode request body:",
|
||||
)
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.BadJSON(fmt.Sprintf("Failed to decode request body: %s", err)),
|
||||
}
|
||||
}
|
||||
|
||||
token := request.Token
|
||||
|
|
@ -65,43 +65,43 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u
|
|||
}
|
||||
// token not present in request body. Hence, generate a random token.
|
||||
if !(length > 0 && length <= 64) {
|
||||
return util.MatrixErrorResponse(
|
||||
http.StatusBadRequest,
|
||||
string(spec.ErrorInvalidParam),
|
||||
"length must be greater than zero and not greater than 64")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.BadJSON("length must be greater than zero and not greater than 64"),
|
||||
}
|
||||
}
|
||||
token = generateRandomToken(int(length))
|
||||
}
|
||||
|
||||
if len(token) > 64 {
|
||||
//Token present in request body, but is too long.
|
||||
return util.MatrixErrorResponse(
|
||||
http.StatusBadRequest,
|
||||
string(spec.ErrorInvalidParam),
|
||||
"token must not be longer than 64")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.BadJSON("token must not be longer than 64"),
|
||||
}
|
||||
}
|
||||
|
||||
isTokenValid, _ := regexp.MatchString("^[[:ascii:][:digit:]_]*$", token)
|
||||
isTokenValid := validRegistrationTokenRegex.Match([]byte(token))
|
||||
if !isTokenValid {
|
||||
return util.MatrixErrorResponse(
|
||||
http.StatusBadRequest,
|
||||
string(spec.ErrorInvalidParam),
|
||||
"token must consist only of characters matched by the regex [A-Za-z0-9-_]")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.BadJSON("token must consist only of characters matched by the regex [A-Za-z0-9-_]"),
|
||||
}
|
||||
}
|
||||
// At this point, we have a valid token, either through request body or through random generation.
|
||||
|
||||
if usesAllowed < 0 {
|
||||
return util.MatrixErrorResponse(
|
||||
http.StatusBadRequest,
|
||||
string(spec.ErrorInvalidParam),
|
||||
"uses_allowed must be a non-negative integer or null")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"),
|
||||
}
|
||||
}
|
||||
|
||||
if expiryTime != 0 && expiryTime < time.Now().UnixNano()/int64(time.Millisecond) {
|
||||
return util.MatrixErrorResponse(
|
||||
http.StatusBadRequest,
|
||||
string(spec.ErrorInvalidParam),
|
||||
"expiry_time must not be in the past")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.BadJSON("expiry_time must not be in the past"),
|
||||
}
|
||||
}
|
||||
pending := int32(0)
|
||||
completed := int32(0)
|
||||
|
|
@ -115,17 +115,16 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u
|
|||
}
|
||||
created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), registrationToken)
|
||||
if err != nil {
|
||||
return util.MatrixErrorResponse(
|
||||
http.StatusBadRequest,
|
||||
string(spec.ErrorUnknown),
|
||||
err.Error(),
|
||||
)
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: err,
|
||||
}
|
||||
}
|
||||
if !created {
|
||||
return util.MatrixErrorResponse(
|
||||
http.StatusBadRequest,
|
||||
string(spec.ErrorInvalidParam),
|
||||
fmt.Sprintf("Token alreaady exists: %s", token))
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusConflict,
|
||||
JSON: fmt.Sprintf("Token already exists: %s", token),
|
||||
}
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
|
|
@ -166,21 +165,19 @@ func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userA
|
|||
returnAll = false
|
||||
validValue, err := strconv.ParseBool(validQuery[0])
|
||||
if err != nil {
|
||||
return util.MatrixErrorResponse(
|
||||
http.StatusBadRequest,
|
||||
string(spec.ErrorInvalidParam),
|
||||
"invalid 'valid' query parameter",
|
||||
)
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.BadJSON("invalid 'valid' query parameter"),
|
||||
}
|
||||
}
|
||||
valid = validValue
|
||||
}
|
||||
tokens, err := userAPI.PerformAdminListRegistrationTokens(req.Context(), returnAll, valid)
|
||||
if err != nil {
|
||||
return util.MatrixErrorResponse(
|
||||
http.StatusInternalServerError,
|
||||
string(spec.ErrorUnknown),
|
||||
"error fetching registration tokens",
|
||||
)
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.ErrorUnknown,
|
||||
}
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
|
|
@ -205,11 +202,10 @@ func AdminGetRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI
|
|||
tokenText := vars["token"]
|
||||
token, err := userAPI.PerformAdminGetRegistrationToken(req.Context(), tokenText)
|
||||
if err != nil {
|
||||
return util.MatrixErrorResponse(
|
||||
http.StatusNotFound,
|
||||
string(spec.ErrorUnknown),
|
||||
fmt.Sprintf("token: %s not found", tokenText),
|
||||
)
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusNotFound,
|
||||
JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)),
|
||||
}
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
|
|
@ -225,11 +221,10 @@ func AdminDeleteRegistrationToken(req *http.Request, cfg *config.ClientAPI, user
|
|||
tokenText := vars["token"]
|
||||
err = userAPI.PerformAdminDeleteRegistrationToken(req.Context(), tokenText)
|
||||
if err != nil {
|
||||
return util.MatrixErrorResponse(
|
||||
http.StatusNotFound,
|
||||
string(spec.ErrorUnknown),
|
||||
fmt.Sprintf("token: %s not found", tokenText),
|
||||
)
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusNotFound,
|
||||
JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)),
|
||||
}
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
|
|
@ -244,12 +239,11 @@ func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, user
|
|||
}
|
||||
tokenText := vars["token"]
|
||||
request := make(map[string]interface{})
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MatrixErrorResponse(
|
||||
http.StatusBadRequest,
|
||||
string(spec.ErrorBadJSON),
|
||||
"Failed to decode request body:",
|
||||
)
|
||||
if err = json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.BadJSON(fmt.Sprintf("Failed to decode request body: %s", err)),
|
||||
}
|
||||
}
|
||||
newAttributes := make(map[string]interface{})
|
||||
usesAllowed, ok := request["uses_allowed"]
|
||||
|
|
@ -258,11 +252,10 @@ func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, user
|
|||
// Non numeric values in payload will cause panic during type conversion. But this is the best way to mimic
|
||||
// Synapse's behaviour of updating the field if and only if it is present in request body.
|
||||
if !(usesAllowed == nil || int32(usesAllowed.(float64)) >= 0) {
|
||||
return util.MatrixErrorResponse(
|
||||
http.StatusBadRequest,
|
||||
string(spec.ErrorInvalidParam),
|
||||
"uses_allowed must be a non-negative integer or null",
|
||||
)
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"),
|
||||
}
|
||||
}
|
||||
newAttributes["usesAllowed"] = usesAllowed
|
||||
}
|
||||
|
|
@ -272,11 +265,10 @@ func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, user
|
|||
// Non numeric values in payload will cause panic during type conversion. But this is the best way to mimic
|
||||
// Synapse's behaviour of updating the field if and only if it is present in request body.
|
||||
if !(expiryTime == nil || int64(expiryTime.(float64)) > time.Now().UnixNano()/int64(time.Millisecond)) {
|
||||
return util.MatrixErrorResponse(
|
||||
http.StatusBadRequest,
|
||||
string(spec.ErrorInvalidParam),
|
||||
"expiry_time must be in the future",
|
||||
)
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.BadJSON("expiry_time must not be in the past"),
|
||||
}
|
||||
}
|
||||
newAttributes["expiryTime"] = expiryTime
|
||||
}
|
||||
|
|
@ -286,11 +278,10 @@ func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, user
|
|||
}
|
||||
updatedToken, err := userAPI.PerformAdminUpdateRegistrationToken(req.Context(), tokenText, newAttributes)
|
||||
if err != nil {
|
||||
return util.MatrixErrorResponse(
|
||||
http.StatusNotFound,
|
||||
string(spec.ErrorUnknown),
|
||||
fmt.Sprintf("token: %s not found", tokenText),
|
||||
)
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusNotFound,
|
||||
JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)),
|
||||
}
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
|
|
|
|||
|
|
@ -176,19 +176,20 @@ func Setup(
|
|||
|
||||
dendriteAdminRouter.Handle("/admin/registrationTokens/{token}",
|
||||
httputil.MakeAdminAPI("admin_get_registration_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if req.Method == http.MethodGet {
|
||||
switch req.Method {
|
||||
case http.MethodGet:
|
||||
return AdminGetRegistrationToken(req, cfg, userAPI)
|
||||
} else if req.Method == http.MethodPut {
|
||||
case http.MethodPut:
|
||||
return AdminUpdateRegistrationToken(req, cfg, userAPI)
|
||||
} else if req.Method == http.MethodDelete {
|
||||
case http.MethodDelete:
|
||||
return AdminDeleteRegistrationToken(req, cfg, userAPI)
|
||||
default:
|
||||
return util.MatrixErrorResponse(
|
||||
404,
|
||||
string(spec.ErrorNotFound),
|
||||
"unknown method",
|
||||
)
|
||||
}
|
||||
return util.MatrixErrorResponse(
|
||||
404,
|
||||
string(spec.ErrorNotFound),
|
||||
"unknown method",
|
||||
)
|
||||
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodPut, http.MethodDelete, http.MethodOptions)
|
||||
|
||||
|
|
@ -196,7 +197,7 @@ func Setup(
|
|||
httputil.MakeAdminAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return AdminEvacuateRoom(req, rsAPI)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
dendriteAdminRouter.Handle("/admin/evacuateUser/{userID}",
|
||||
httputil.MakeAdminAPI("admin_evacuate_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
|
|
|||
|
|
@ -1,203 +0,0 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"time"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/internal/httputil"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/gomatrixserverlib/fclient"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/matrix-org/dendrite/appservice"
|
||||
"github.com/matrix-org/dendrite/federationapi"
|
||||
"github.com/matrix-org/dendrite/roomserver"
|
||||
"github.com/matrix-org/dendrite/setup"
|
||||
basepkg "github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/mscs"
|
||||
"github.com/matrix-org/dendrite/userapi"
|
||||
)
|
||||
|
||||
var (
|
||||
unixSocket = flag.String("unix-socket", "",
|
||||
"EXPERIMENTAL(unstable): The HTTP listening unix socket for the server (disables http[s]-bind-address feature)",
|
||||
)
|
||||
unixSocketPermission = flag.String("unix-socket-permission", "755",
|
||||
"EXPERIMENTAL(unstable): The HTTP listening unix socket permission for the server (in chmod format like 755)",
|
||||
)
|
||||
httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server")
|
||||
httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server")
|
||||
certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS")
|
||||
keyFile = flag.String("tls-key", "", "The PEM private key to use for TLS")
|
||||
)
|
||||
|
||||
func main() {
|
||||
cfg := setup.ParseFlags(true)
|
||||
httpAddr := config.ServerAddress{}
|
||||
httpsAddr := config.ServerAddress{}
|
||||
if *unixSocket == "" {
|
||||
http, err := config.HTTPAddress("http://" + *httpBindAddr)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Fatalf("Failed to parse http address")
|
||||
}
|
||||
httpAddr = http
|
||||
https, err := config.HTTPAddress("https://" + *httpsBindAddr)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Fatalf("Failed to parse https address")
|
||||
}
|
||||
httpsAddr = https
|
||||
} else {
|
||||
socket, err := config.UnixSocketAddress(*unixSocket, *unixSocketPermission)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Fatalf("Failed to parse unix socket")
|
||||
}
|
||||
httpAddr = socket
|
||||
}
|
||||
|
||||
configErrors := &config.ConfigErrors{}
|
||||
cfg.Verify(configErrors)
|
||||
if len(*configErrors) > 0 {
|
||||
for _, err := range *configErrors {
|
||||
logrus.Errorf("Configuration error: %s", err)
|
||||
}
|
||||
logrus.Fatalf("Failed to start due to configuration errors")
|
||||
}
|
||||
processCtx := process.NewProcessContext()
|
||||
|
||||
internal.SetupStdLogging()
|
||||
internal.SetupHookLogging(cfg.Logging)
|
||||
internal.SetupPprof()
|
||||
|
||||
basepkg.PlatformSanityChecks()
|
||||
|
||||
logrus.Infof("Dendrite version %s", internal.VersionString())
|
||||
if !cfg.ClientAPI.RegistrationDisabled && cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled {
|
||||
logrus.Warn("Open registration is enabled")
|
||||
}
|
||||
|
||||
// create DNS cache
|
||||
var dnsCache *fclient.DNSCache
|
||||
if cfg.Global.DNSCache.Enabled {
|
||||
dnsCache = fclient.NewDNSCache(
|
||||
cfg.Global.DNSCache.CacheSize,
|
||||
cfg.Global.DNSCache.CacheLifetime,
|
||||
)
|
||||
logrus.Infof(
|
||||
"DNS cache enabled (size %d, lifetime %s)",
|
||||
cfg.Global.DNSCache.CacheSize,
|
||||
cfg.Global.DNSCache.CacheLifetime,
|
||||
)
|
||||
}
|
||||
|
||||
// setup tracing
|
||||
closer, err := cfg.SetupTracing()
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panicf("failed to start opentracing")
|
||||
}
|
||||
defer closer.Close() // nolint: errcheck
|
||||
|
||||
// setup sentry
|
||||
if cfg.Global.Sentry.Enabled {
|
||||
logrus.Info("Setting up Sentry for debugging...")
|
||||
err = sentry.Init(sentry.ClientOptions{
|
||||
Dsn: cfg.Global.Sentry.DSN,
|
||||
Environment: cfg.Global.Sentry.Environment,
|
||||
Debug: true,
|
||||
ServerName: string(cfg.Global.ServerName),
|
||||
Release: "dendrite@" + internal.VersionString(),
|
||||
AttachStacktrace: true,
|
||||
})
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panic("failed to start Sentry")
|
||||
}
|
||||
go func() {
|
||||
processCtx.ComponentStarted()
|
||||
<-processCtx.WaitForShutdown()
|
||||
if !sentry.Flush(time.Second * 5) {
|
||||
logrus.Warnf("failed to flush all Sentry events!")
|
||||
}
|
||||
processCtx.ComponentFinished()
|
||||
}()
|
||||
}
|
||||
|
||||
federationClient := basepkg.CreateFederationClient(cfg, dnsCache)
|
||||
httpClient := basepkg.CreateClient(cfg, dnsCache)
|
||||
|
||||
// prepare required dependencies
|
||||
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||
routers := httputil.NewRouters()
|
||||
|
||||
caches := caching.NewRistrettoCache(cfg.Global.Cache.EstimatedMaxSize, cfg.Global.Cache.MaxAge, caching.EnableMetrics)
|
||||
natsInstance := jetstream.NATSInstance{}
|
||||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.EnableMetrics)
|
||||
fsAPI := federationapi.NewInternalAPI(
|
||||
processCtx, cfg, cm, &natsInstance, federationClient, rsAPI, caches, nil, false,
|
||||
)
|
||||
|
||||
keyRing := fsAPI.KeyRing()
|
||||
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, federationClient)
|
||||
asAPI := appservice.NewInternalAPI(processCtx, cfg, &natsInstance, userAPI, rsAPI)
|
||||
|
||||
// The underlying roomserver implementation needs to be able to call the fedsender.
|
||||
// This is different to rsAPI which can be the http client which doesn't need this
|
||||
// dependency. Other components also need updating after their dependencies are up.
|
||||
rsAPI.SetFederationAPI(fsAPI, keyRing)
|
||||
rsAPI.SetAppserviceAPI(asAPI)
|
||||
rsAPI.SetUserAPI(userAPI)
|
||||
|
||||
monolith := setup.Monolith{
|
||||
Config: cfg,
|
||||
Client: httpClient,
|
||||
FedClient: federationClient,
|
||||
KeyRing: keyRing,
|
||||
|
||||
AppserviceAPI: asAPI,
|
||||
// always use the concrete impl here even in -http mode because adding public routes
|
||||
// must be done on the concrete impl not an HTTP client else fedapi will call itself
|
||||
FederationAPI: fsAPI,
|
||||
RoomserverAPI: rsAPI,
|
||||
UserAPI: userAPI,
|
||||
}
|
||||
monolith.AddAllPublicRoutes(processCtx, cfg, routers, cm, &natsInstance, caches, caching.EnableMetrics)
|
||||
|
||||
if len(cfg.MSCs.MSCs) > 0 {
|
||||
if err := mscs.Enable(cfg, cm, routers, &monolith, caches); err != nil {
|
||||
logrus.WithError(err).Fatalf("Failed to enable MSCs")
|
||||
}
|
||||
}
|
||||
|
||||
// Expose the matrix APIs directly rather than putting them under a /api path.
|
||||
go func() {
|
||||
basepkg.SetupAndServeHTTP(processCtx, cfg, routers, httpAddr, nil, nil)
|
||||
}()
|
||||
// Handle HTTPS if certificate and key are provided
|
||||
if *unixSocket == "" && *certFile != "" && *keyFile != "" {
|
||||
go func() {
|
||||
basepkg.SetupAndServeHTTP(processCtx, cfg, routers, httpsAddr, certFile, keyFile)
|
||||
}()
|
||||
}
|
||||
|
||||
// We want to block forever to let the HTTP and HTTPS handler serve the APIs
|
||||
basepkg.WaitForShutdown(processCtx)
|
||||
}
|
||||
|
|
@ -1,50 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// This is an instrumented main, used when running integration tests (sytest) with code coverage.
|
||||
// Compile: go test -c -race -cover -covermode=atomic -o monolith.debug -coverpkg "github.com/matrix-org/..." ./cmd/dendrite
|
||||
// Run the monolith: ./monolith.debug -test.coverprofile=/somewhere/to/dump/integrationcover.out DEVEL --config dendrite.yaml
|
||||
// Generate HTML with coverage: go tool cover -html=/somewhere/where/there/is/integrationcover.out -o cover.html
|
||||
// Source: https://dzone.com/articles/measuring-integration-test-coverage-rate-in-pouchc
|
||||
func TestMain(_ *testing.T) {
|
||||
var (
|
||||
args []string
|
||||
)
|
||||
|
||||
for _, arg := range os.Args {
|
||||
switch {
|
||||
case strings.HasPrefix(arg, "DEVEL"):
|
||||
case strings.HasPrefix(arg, "-test"):
|
||||
default:
|
||||
args = append(args, arg)
|
||||
}
|
||||
}
|
||||
// only run the tests if there are args to be passed
|
||||
if len(args) <= 1 {
|
||||
return
|
||||
}
|
||||
|
||||
waitCh := make(chan int, 1)
|
||||
os.Args = args
|
||||
go func() {
|
||||
main()
|
||||
close(waitCh)
|
||||
}()
|
||||
|
||||
signalCh := make(chan os.Signal, 1)
|
||||
signal.Notify(signalCh, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGHUP)
|
||||
|
||||
select {
|
||||
case <-signalCh:
|
||||
return
|
||||
case <-waitCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
@ -80,19 +80,11 @@ func (a *UserInternalAPI) PerformAdminCreateRegistrationToken(ctx context.Contex
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) {
|
||||
tokens, err := a.DB.ListRegistrationTokens(ctx, returnAll, valid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tokens, nil
|
||||
return a.DB.ListRegistrationTokens(ctx, returnAll, valid)
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) {
|
||||
token, err := a.DB.GetRegistrationToken(ctx, tokenString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return token, nil
|
||||
return a.DB.GetRegistrationToken(ctx, tokenString)
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error {
|
||||
|
|
@ -100,11 +92,7 @@ func (a *UserInternalAPI) PerformAdminDeleteRegistrationToken(ctx context.Contex
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) {
|
||||
token, err := a.DB.UpdateRegistrationToken(ctx, tokenString, newAttributes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return token, nil
|
||||
return a.DB.UpdateRegistrationToken(ctx, tokenString, newAttributes)
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
|
||||
|
|
|
|||
|
|
@ -3,12 +3,13 @@ package postgres
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/api"
|
||||
internal "github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
const registrationTokensSchema = `
|
||||
|
|
@ -89,7 +90,7 @@ func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTa
|
|||
|
||||
func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) {
|
||||
var existingToken string
|
||||
stmt := s.selectTokenStatement
|
||||
stmt := sqlutil.TxStmt(tx, s.selectTokenStatement)
|
||||
err := stmt.QueryRowContext(ctx, token).Scan(&existingToken)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
|
|
@ -105,7 +106,7 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex
|
|||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
*registrationToken.Token,
|
||||
nullIfZeroInt32(*registrationToken.UsesAllowed),
|
||||
nullIfZero(*registrationToken.UsesAllowed),
|
||||
nullIfZero(*registrationToken.ExpiryTime),
|
||||
*registrationToken.Pending,
|
||||
*registrationToken.Completed)
|
||||
|
|
@ -115,111 +116,82 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex
|
|||
return true, nil
|
||||
}
|
||||
|
||||
func nullIfZero(value int64) interface{} {
|
||||
if value == 0 {
|
||||
func nullIfZero[t constraints.Integer](in t) any {
|
||||
if in == 0 {
|
||||
return nil
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func nullIfZeroInt32(value int32) interface{} {
|
||||
if value == 0 {
|
||||
return nil
|
||||
}
|
||||
return value
|
||||
return in
|
||||
}
|
||||
|
||||
func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) {
|
||||
var stmt *sql.Stmt
|
||||
var tokens []api.RegistrationToken
|
||||
var tokenString sql.NullString
|
||||
var pending, completed, usesAllowed sql.NullInt32
|
||||
var expiryTime sql.NullInt64
|
||||
var tokenString string
|
||||
var pending, completed, usesAllowed *int32
|
||||
var expiryTime *int64
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if returnAll {
|
||||
stmt = s.listAllTokensStatement
|
||||
stmt = sqlutil.TxStmt(tx, s.listAllTokensStatement)
|
||||
rows, err = stmt.QueryContext(ctx)
|
||||
} else if valid {
|
||||
stmt = s.listValidTokensStatement
|
||||
stmt = sqlutil.TxStmt(tx, s.listValidTokensStatement)
|
||||
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
|
||||
} else {
|
||||
stmt = s.listInvalidTokenStatement
|
||||
stmt = sqlutil.TxStmt(tx, s.listInvalidTokenStatement)
|
||||
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
|
||||
}
|
||||
if err != nil {
|
||||
return tokens, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed")
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime)
|
||||
if err != nil {
|
||||
return tokens, err
|
||||
}
|
||||
tokenString := tokenString.String
|
||||
pending := pending.Int32
|
||||
completed := completed.Int32
|
||||
usesAllowed := getReturnValueForInt32(usesAllowed)
|
||||
expiryTime := getReturnValueForInt64(expiryTime)
|
||||
tokenString := tokenString
|
||||
pending := pending
|
||||
completed := completed
|
||||
usesAllowed := usesAllowed
|
||||
expiryTime := expiryTime
|
||||
|
||||
tokenMap := api.RegistrationToken{
|
||||
Token: &tokenString,
|
||||
Pending: &pending,
|
||||
Completed: &completed,
|
||||
Pending: pending,
|
||||
Completed: completed,
|
||||
UsesAllowed: usesAllowed,
|
||||
ExpiryTime: expiryTime,
|
||||
}
|
||||
tokens = append(tokens, tokenMap)
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func getReturnValueForInt32(value sql.NullInt32) *int32 {
|
||||
if value.Valid {
|
||||
returnValue := value.Int32
|
||||
return &returnValue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getReturnValueForInt64(value sql.NullInt64) *int64 {
|
||||
if value.Valid {
|
||||
returnValue := value.Int64
|
||||
return &returnValue
|
||||
}
|
||||
return nil
|
||||
return tokens, rows.Err()
|
||||
}
|
||||
|
||||
func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) {
|
||||
stmt := sqlutil.TxStmt(tx, s.getTokenStatement)
|
||||
var pending, completed, usesAllowed sql.NullInt32
|
||||
var expiryTime sql.NullInt64
|
||||
var pending, completed, usesAllowed *int32
|
||||
var expiryTime *int64
|
||||
err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token := api.RegistrationToken{
|
||||
Token: &tokenString,
|
||||
Pending: &pending.Int32,
|
||||
Completed: &completed.Int32,
|
||||
UsesAllowed: getReturnValueForInt32(usesAllowed),
|
||||
ExpiryTime: getReturnValueForInt64(expiryTime),
|
||||
Pending: pending,
|
||||
Completed: completed,
|
||||
UsesAllowed: usesAllowed,
|
||||
ExpiryTime: expiryTime,
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error {
|
||||
stmt := s.deleteTokenStatement
|
||||
res, err := stmt.ExecContext(ctx, tokenString)
|
||||
stmt := sqlutil.TxStmt(tx, s.deleteTokenStatement)
|
||||
_, err := stmt.ExecContext(ctx, tokenString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
count, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count == 0 {
|
||||
return fmt.Errorf("token: %s does not exists", tokenString)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -100,8 +100,12 @@ func (d *Database) GetRegistrationToken(ctx context.Context, tokenString string)
|
|||
return d.RegistrationTokens.GetRegistrationToken(ctx, nil, tokenString)
|
||||
}
|
||||
|
||||
func (d *Database) DeleteRegistrationToken(ctx context.Context, tokenString string) error {
|
||||
return d.RegistrationTokens.DeleteRegistrationToken(ctx, nil, tokenString)
|
||||
func (d *Database) DeleteRegistrationToken(ctx context.Context, tokenString string) (err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
err = d.RegistrationTokens.DeleteRegistrationToken(ctx, nil, tokenString)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (updatedToken *clientapi.RegistrationToken, err error) {
|
||||
|
|
|
|||
|
|
@ -3,12 +3,13 @@ package sqlite3
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/api"
|
||||
internal "github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
const registrationTokensSchema = `
|
||||
|
|
@ -89,7 +90,7 @@ func NewSQLiteRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTabl
|
|||
|
||||
func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) {
|
||||
var existingToken string
|
||||
stmt := s.selectTokenStatement
|
||||
stmt := sqlutil.TxStmt(tx, s.selectTokenStatement)
|
||||
err := stmt.QueryRowContext(ctx, token).Scan(&existingToken)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
|
|
@ -105,7 +106,7 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex
|
|||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
*registrationToken.Token,
|
||||
nullIfZeroInt32(*registrationToken.UsesAllowed),
|
||||
nullIfZero(*registrationToken.UsesAllowed),
|
||||
nullIfZero(*registrationToken.ExpiryTime),
|
||||
*registrationToken.Pending,
|
||||
*registrationToken.Completed)
|
||||
|
|
@ -115,111 +116,82 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex
|
|||
return true, nil
|
||||
}
|
||||
|
||||
func nullIfZero(value int64) interface{} {
|
||||
if value == 0 {
|
||||
func nullIfZero[t constraints.Integer](in t) any {
|
||||
if in == 0 {
|
||||
return nil
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func nullIfZeroInt32(value int32) interface{} {
|
||||
if value == 0 {
|
||||
return nil
|
||||
}
|
||||
return value
|
||||
return in
|
||||
}
|
||||
|
||||
func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) {
|
||||
var stmt *sql.Stmt
|
||||
var tokens []api.RegistrationToken
|
||||
var tokenString sql.NullString
|
||||
var pending, completed, usesAllowed sql.NullInt32
|
||||
var expiryTime sql.NullInt64
|
||||
var tokenString string
|
||||
var pending, completed, usesAllowed *int32
|
||||
var expiryTime *int64
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if returnAll {
|
||||
stmt = s.listAllTokensStatement
|
||||
stmt = sqlutil.TxStmt(tx, s.listAllTokensStatement)
|
||||
rows, err = stmt.QueryContext(ctx)
|
||||
} else if valid {
|
||||
stmt = s.listValidTokensStatement
|
||||
stmt = sqlutil.TxStmt(tx, s.listValidTokensStatement)
|
||||
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
|
||||
} else {
|
||||
stmt = s.listInvalidTokenStatement
|
||||
stmt = sqlutil.TxStmt(tx, s.listInvalidTokenStatement)
|
||||
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
|
||||
}
|
||||
if err != nil {
|
||||
return tokens, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed")
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime)
|
||||
if err != nil {
|
||||
return tokens, err
|
||||
}
|
||||
tokenString := tokenString.String
|
||||
pending := pending.Int32
|
||||
completed := completed.Int32
|
||||
usesAllowed := getReturnValueForInt32(usesAllowed)
|
||||
expiryTime := getReturnValueForInt64(expiryTime)
|
||||
tokenString := tokenString
|
||||
pending := pending
|
||||
completed := completed
|
||||
usesAllowed := usesAllowed
|
||||
expiryTime := expiryTime
|
||||
|
||||
tokenMap := api.RegistrationToken{
|
||||
Token: &tokenString,
|
||||
Pending: &pending,
|
||||
Completed: &completed,
|
||||
Pending: pending,
|
||||
Completed: completed,
|
||||
UsesAllowed: usesAllowed,
|
||||
ExpiryTime: expiryTime,
|
||||
}
|
||||
tokens = append(tokens, tokenMap)
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func getReturnValueForInt32(value sql.NullInt32) *int32 {
|
||||
if value.Valid {
|
||||
returnValue := value.Int32
|
||||
return &returnValue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getReturnValueForInt64(value sql.NullInt64) *int64 {
|
||||
if value.Valid {
|
||||
returnValue := value.Int64
|
||||
return &returnValue
|
||||
}
|
||||
return nil
|
||||
return tokens, rows.Err()
|
||||
}
|
||||
|
||||
func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) {
|
||||
stmt := sqlutil.TxStmt(tx, s.getTokenStatement)
|
||||
var pending, completed, usesAllowed sql.NullInt32
|
||||
var expiryTime sql.NullInt64
|
||||
var pending, completed, usesAllowed *int32
|
||||
var expiryTime *int64
|
||||
err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token := api.RegistrationToken{
|
||||
Token: &tokenString,
|
||||
Pending: &pending.Int32,
|
||||
Completed: &completed.Int32,
|
||||
UsesAllowed: getReturnValueForInt32(usesAllowed),
|
||||
ExpiryTime: getReturnValueForInt64(expiryTime),
|
||||
Pending: pending,
|
||||
Completed: completed,
|
||||
UsesAllowed: usesAllowed,
|
||||
ExpiryTime: expiryTime,
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error {
|
||||
stmt := s.deleteTokenStatement
|
||||
res, err := stmt.ExecContext(ctx, tokenString)
|
||||
stmt := sqlutil.TxStmt(tx, s.deleteTokenStatement)
|
||||
_, err := stmt.ExecContext(ctx, tokenString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
count, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count == 0 {
|
||||
return fmt.Errorf("token: %s does not exists", tokenString)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue