Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/mxidmapping

This commit is contained in:
Till Faelligen 2023-06-28 08:56:30 +02:00
commit a33ef5b380
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
18 changed files with 1482 additions and 15 deletions

View file

@ -2,6 +2,7 @@ package clientapi
import ( import (
"context" "context"
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
@ -23,12 +24,649 @@ import (
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
capi "github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi"
uapi "github.com/matrix-org/dendrite/userapi/api" uapi "github.com/matrix-org/dendrite/userapi/api"
) )
func TestAdminCreateToken(t *testing.T) {
aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
cfg.ClientAPI.RegistrationRequiresToken = true
defer close()
natsInstance := jetstream.NATSInstance{}
routers := httputil.NewRouters()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
aliceAdmin: {},
bob: {},
}
createAccessTokens(t, accessTokens, userAPI, ctx, routers)
testCases := []struct {
name string
requestingUser *test.User
requestOpt test.HTTPRequestOpt
wantOK bool
withHeader bool
}{
{
name: "Missing auth",
requestingUser: bob,
wantOK: false,
requestOpt: test.WithJSONBody(t, map[string]interface{}{
"token": "token1",
},
),
},
{
name: "Bob is denied access",
requestingUser: bob,
wantOK: false,
withHeader: true,
requestOpt: test.WithJSONBody(t, map[string]interface{}{
"token": "token2",
},
),
},
{
name: "Alice can create a token without specifyiing any information",
requestingUser: aliceAdmin,
wantOK: true,
withHeader: true,
requestOpt: test.WithJSONBody(t, map[string]interface{}{}),
},
{
name: "Alice can to create a token specifying a name",
requestingUser: aliceAdmin,
wantOK: true,
withHeader: true,
requestOpt: test.WithJSONBody(t, map[string]interface{}{
"token": "token3",
},
),
},
{
name: "Alice cannot to create a token that already exists",
requestingUser: aliceAdmin,
wantOK: false,
withHeader: true,
requestOpt: test.WithJSONBody(t, map[string]interface{}{
"token": "token3",
},
),
},
{
name: "Alice can create a token specifying valid params",
requestingUser: aliceAdmin,
wantOK: true,
withHeader: true,
requestOpt: test.WithJSONBody(t, map[string]interface{}{
"token": "token4",
"uses_allowed": 5,
"expiry_time": time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond),
},
),
},
{
name: "Alice cannot create a token specifying invalid name",
requestingUser: aliceAdmin,
wantOK: false,
withHeader: true,
requestOpt: test.WithJSONBody(t, map[string]interface{}{
"token": "token@",
},
),
},
{
name: "Alice cannot create a token specifying invalid uses_allowed",
requestingUser: aliceAdmin,
wantOK: false,
withHeader: true,
requestOpt: test.WithJSONBody(t, map[string]interface{}{
"token": "token5",
"uses_allowed": -1,
},
),
},
{
name: "Alice cannot create a token specifying invalid expiry_time",
requestingUser: aliceAdmin,
wantOK: false,
withHeader: true,
requestOpt: test.WithJSONBody(t, map[string]interface{}{
"token": "token6",
"expiry_time": time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond),
},
),
},
{
name: "Alice cannot to create a token specifying invalid length",
requestingUser: aliceAdmin,
wantOK: false,
withHeader: true,
requestOpt: test.WithJSONBody(t, map[string]interface{}{
"length": 80,
},
),
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/registrationTokens/new")
if tc.requestOpt != nil {
req = test.NewRequest(t, http.MethodPost, "/_dendrite/admin/registrationTokens/new", tc.requestOpt)
}
if tc.withHeader {
req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken)
}
rec := httptest.NewRecorder()
routers.DendriteAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if tc.wantOK && rec.Code != http.StatusOK {
t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
}
})
}
})
}
func TestAdminListRegistrationTokens(t *testing.T) {
aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
cfg.ClientAPI.RegistrationRequiresToken = true
defer close()
natsInstance := jetstream.NATSInstance{}
routers := httputil.NewRouters()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
aliceAdmin: {},
bob: {},
}
tokens := []capi.RegistrationToken{
{
Token: getPointer("valid"),
UsesAllowed: getPointer(int32(10)),
ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
Pending: getPointer(int32(0)),
Completed: getPointer(int32(0)),
},
{
Token: getPointer("invalid"),
UsesAllowed: getPointer(int32(10)),
ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
Pending: getPointer(int32(0)),
Completed: getPointer(int32(0)),
},
}
for _, tkn := range tokens {
tkn := tkn
userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn)
}
createAccessTokens(t, accessTokens, userAPI, ctx, routers)
testCases := []struct {
name string
requestingUser *test.User
valid string
isValidSpecified bool
wantOK bool
withHeader bool
}{
{
name: "Missing auth",
requestingUser: bob,
wantOK: false,
isValidSpecified: false,
},
{
name: "Bob is denied access",
requestingUser: bob,
wantOK: false,
withHeader: true,
isValidSpecified: false,
},
{
name: "Alice can list all tokens",
requestingUser: aliceAdmin,
wantOK: true,
withHeader: true,
},
{
name: "Alice can list all valid tokens",
requestingUser: aliceAdmin,
wantOK: true,
withHeader: true,
valid: "true",
isValidSpecified: true,
},
{
name: "Alice can list all invalid tokens",
requestingUser: aliceAdmin,
wantOK: true,
withHeader: true,
valid: "false",
isValidSpecified: true,
},
{
name: "No response when valid has a bad value",
requestingUser: aliceAdmin,
wantOK: false,
withHeader: true,
valid: "trueee",
isValidSpecified: true,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
var path string
if tc.isValidSpecified {
path = fmt.Sprintf("/_dendrite/admin/registrationTokens?valid=%v", tc.valid)
} else {
path = "/_dendrite/admin/registrationTokens"
}
req := test.NewRequest(t, http.MethodGet, path)
if tc.withHeader {
req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken)
}
rec := httptest.NewRecorder()
routers.DendriteAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if tc.wantOK && rec.Code != http.StatusOK {
t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
}
})
}
})
}
func TestAdminGetRegistrationToken(t *testing.T) {
aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
cfg.ClientAPI.RegistrationRequiresToken = true
defer close()
natsInstance := jetstream.NATSInstance{}
routers := httputil.NewRouters()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
aliceAdmin: {},
bob: {},
}
tokens := []capi.RegistrationToken{
{
Token: getPointer("alice_token1"),
UsesAllowed: getPointer(int32(10)),
ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
Pending: getPointer(int32(0)),
Completed: getPointer(int32(0)),
},
{
Token: getPointer("alice_token2"),
UsesAllowed: getPointer(int32(10)),
ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
Pending: getPointer(int32(0)),
Completed: getPointer(int32(0)),
},
}
for _, tkn := range tokens {
tkn := tkn
userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn)
}
createAccessTokens(t, accessTokens, userAPI, ctx, routers)
testCases := []struct {
name string
requestingUser *test.User
token string
wantOK bool
withHeader bool
}{
{
name: "Missing auth",
requestingUser: bob,
wantOK: false,
},
{
name: "Bob is denied access",
requestingUser: bob,
wantOK: false,
withHeader: true,
},
{
name: "Alice can GET alice_token1",
token: "alice_token1",
requestingUser: aliceAdmin,
wantOK: true,
withHeader: true,
},
{
name: "Alice can GET alice_token2",
requestingUser: aliceAdmin,
wantOK: true,
withHeader: true,
token: "alice_token2",
},
{
name: "Alice cannot GET a token that does not exists",
requestingUser: aliceAdmin,
wantOK: false,
withHeader: true,
token: "alice_token3",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
path := fmt.Sprintf("/_dendrite/admin/registrationTokens/%s", tc.token)
req := test.NewRequest(t, http.MethodGet, path)
if tc.withHeader {
req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken)
}
rec := httptest.NewRecorder()
routers.DendriteAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if tc.wantOK && rec.Code != http.StatusOK {
t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
}
})
}
})
}
func TestAdminDeleteRegistrationToken(t *testing.T) {
aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
cfg.ClientAPI.RegistrationRequiresToken = true
defer close()
natsInstance := jetstream.NATSInstance{}
routers := httputil.NewRouters()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
aliceAdmin: {},
bob: {},
}
tokens := []capi.RegistrationToken{
{
Token: getPointer("alice_token1"),
UsesAllowed: getPointer(int32(10)),
ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
Pending: getPointer(int32(0)),
Completed: getPointer(int32(0)),
},
{
Token: getPointer("alice_token2"),
UsesAllowed: getPointer(int32(10)),
ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
Pending: getPointer(int32(0)),
Completed: getPointer(int32(0)),
},
}
for _, tkn := range tokens {
tkn := tkn
userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn)
}
createAccessTokens(t, accessTokens, userAPI, ctx, routers)
testCases := []struct {
name string
requestingUser *test.User
token string
wantOK bool
withHeader bool
}{
{
name: "Missing auth",
requestingUser: bob,
wantOK: false,
},
{
name: "Bob is denied access",
requestingUser: bob,
wantOK: false,
withHeader: true,
},
{
name: "Alice can DELETE alice_token1",
token: "alice_token1",
requestingUser: aliceAdmin,
wantOK: true,
withHeader: true,
},
{
name: "Alice can DELETE alice_token2",
requestingUser: aliceAdmin,
wantOK: true,
withHeader: true,
token: "alice_token2",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
path := fmt.Sprintf("/_dendrite/admin/registrationTokens/%s", tc.token)
req := test.NewRequest(t, http.MethodDelete, path)
if tc.withHeader {
req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken)
}
rec := httptest.NewRecorder()
routers.DendriteAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if tc.wantOK && rec.Code != http.StatusOK {
t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
}
})
}
})
}
func TestAdminUpdateRegistrationToken(t *testing.T) {
aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
cfg.ClientAPI.RegistrationRequiresToken = true
defer close()
natsInstance := jetstream.NATSInstance{}
routers := httputil.NewRouters()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
aliceAdmin: {},
bob: {},
}
createAccessTokens(t, accessTokens, userAPI, ctx, routers)
tokens := []capi.RegistrationToken{
{
Token: getPointer("alice_token1"),
UsesAllowed: getPointer(int32(10)),
ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
Pending: getPointer(int32(0)),
Completed: getPointer(int32(0)),
},
{
Token: getPointer("alice_token2"),
UsesAllowed: getPointer(int32(10)),
ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
Pending: getPointer(int32(0)),
Completed: getPointer(int32(0)),
},
}
for _, tkn := range tokens {
tkn := tkn
userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn)
}
testCases := []struct {
name string
requestingUser *test.User
method string
token string
requestOpt test.HTTPRequestOpt
wantOK bool
withHeader bool
}{
{
name: "Missing auth",
requestingUser: bob,
wantOK: false,
token: "alice_token1",
requestOpt: test.WithJSONBody(t, map[string]interface{}{
"uses_allowed": 10,
},
),
},
{
name: "Bob is denied access",
requestingUser: bob,
wantOK: false,
withHeader: true,
token: "alice_token1",
requestOpt: test.WithJSONBody(t, map[string]interface{}{
"uses_allowed": 10,
},
),
},
{
name: "Alice can UPDATE a token's uses_allowed property",
requestingUser: aliceAdmin,
wantOK: true,
withHeader: true,
token: "alice_token1",
requestOpt: test.WithJSONBody(t, map[string]interface{}{
"uses_allowed": 10,
}),
},
{
name: "Alice can UPDATE a token's expiry_time property",
requestingUser: aliceAdmin,
wantOK: true,
withHeader: true,
token: "alice_token2",
requestOpt: test.WithJSONBody(t, map[string]interface{}{
"expiry_time": time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond),
},
),
},
{
name: "Alice can UPDATE a token's uses_allowed and expiry_time property",
requestingUser: aliceAdmin,
wantOK: false,
withHeader: true,
token: "alice_token1",
requestOpt: test.WithJSONBody(t, map[string]interface{}{
"uses_allowed": 20,
"expiry_time": time.Now().Add(10*24*time.Hour).UnixNano() / int64(time.Millisecond),
},
),
},
{
name: "Alice CANNOT update a token with invalid properties",
requestingUser: aliceAdmin,
wantOK: false,
withHeader: true,
token: "alice_token2",
requestOpt: test.WithJSONBody(t, map[string]interface{}{
"uses_allowed": -5,
"expiry_time": time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond),
},
),
},
{
name: "Alice CANNOT UPDATE a token that does not exist",
requestingUser: aliceAdmin,
wantOK: false,
withHeader: true,
token: "alice_token9",
requestOpt: test.WithJSONBody(t, map[string]interface{}{
"uses_allowed": 100,
},
),
},
{
name: "Alice can UPDATE token specifying uses_allowed as null - Valid for infinite uses",
requestingUser: aliceAdmin,
wantOK: false,
withHeader: true,
token: "alice_token1",
requestOpt: test.WithJSONBody(t, map[string]interface{}{
"uses_allowed": nil,
},
),
},
{
name: "Alice can UPDATE token specifying expiry_time AS null - Valid for infinite time",
requestingUser: aliceAdmin,
wantOK: false,
withHeader: true,
token: "alice_token1",
requestOpt: test.WithJSONBody(t, map[string]interface{}{
"expiry_time": nil,
},
),
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
path := fmt.Sprintf("/_dendrite/admin/registrationTokens/%s", tc.token)
req := test.NewRequest(t, http.MethodPut, path)
if tc.requestOpt != nil {
req = test.NewRequest(t, http.MethodPut, path, tc.requestOpt)
}
if tc.withHeader {
req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken)
}
rec := httptest.NewRecorder()
routers.DendriteAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if tc.wantOK && rec.Code != http.StatusOK {
t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
}
})
}
})
}
func getPointer[T any](s T) *T {
return &s
}
func TestAdminResetPassword(t *testing.T) { func TestAdminResetPassword(t *testing.T) {
aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))

View file

@ -21,3 +21,11 @@ type ExtraPublicRoomsProvider interface {
// Rooms returns the extra rooms. This is called on-demand by clients, so cache appropriately. // Rooms returns the extra rooms. This is called on-demand by clients, so cache appropriately.
Rooms() []fclient.PublicRoom Rooms() []fclient.PublicRoom
} }
type RegistrationToken struct {
Token *string `json:"token"`
UsesAllowed *int32 `json:"uses_allowed"`
Pending *int32 `json:"pending"`
Completed *int32 `json:"completed"`
ExpiryTime *int64 `json:"expiry_time"`
}

View file

@ -6,6 +6,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"regexp"
"strconv"
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -16,14 +18,254 @@ import (
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/exp/constraints"
clientapi "github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
) )
var validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$")
func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
if !cfg.RegistrationRequiresToken {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("Registration via tokens is not enabled on this homeserver"),
}
}
request := struct {
Token string `json:"token"`
UsesAllowed *int32 `json:"uses_allowed,omitempty"`
ExpiryTime *int64 `json:"expiry_time,omitempty"`
Length int32 `json:"length"`
}{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON(fmt.Sprintf("Failed to decode request body: %s", err)),
}
}
token := request.Token
usesAllowed := request.UsesAllowed
expiryTime := request.ExpiryTime
length := request.Length
if len(token) == 0 {
if length == 0 {
// length not provided in request. Assign default value of 16.
length = 16
}
// token not present in request body. Hence, generate a random token.
if length <= 0 || length > 64 {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("length must be greater than zero and not greater than 64"),
}
}
token = util.RandomString(int(length))
}
if len(token) > 64 {
//Token present in request body, but is too long.
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("token must not be longer than 64"),
}
}
isTokenValid := validRegistrationTokenRegex.Match([]byte(token))
if !isTokenValid {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("token must consist only of characters matched by the regex [A-Za-z0-9-_]"),
}
}
// At this point, we have a valid token, either through request body or through random generation.
if usesAllowed != nil && *usesAllowed < 0 {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"),
}
}
if expiryTime != nil && spec.Timestamp(*expiryTime).Time().Before(time.Now()) {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("expiry_time must not be in the past"),
}
}
pending := int32(0)
completed := int32(0)
// If usesAllowed or expiryTime is 0, it means they are not present in the request. NULL (indicating unlimited uses / no expiration will be persisted in DB)
registrationToken := &clientapi.RegistrationToken{
Token: &token,
UsesAllowed: usesAllowed,
Pending: &pending,
Completed: &completed,
ExpiryTime: expiryTime,
}
created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), registrationToken)
if !created {
return util.JSONResponse{
Code: http.StatusConflict,
JSON: map[string]string{
"error": fmt.Sprintf("token: %s already exists", token),
},
}
}
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: err,
}
}
return util.JSONResponse{
Code: 200,
JSON: map[string]interface{}{
"token": token,
"uses_allowed": getReturnValue(usesAllowed),
"pending": pending,
"completed": completed,
"expiry_time": getReturnValue(expiryTime),
},
}
}
func getReturnValue[t constraints.Integer](in *t) any {
if in == nil {
return nil
}
return *in
}
func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
queryParams := req.URL.Query()
returnAll := true
valid := true
validQuery, ok := queryParams["valid"]
if ok {
returnAll = false
validValue, err := strconv.ParseBool(validQuery[0])
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("invalid 'valid' query parameter"),
}
}
valid = validValue
}
tokens, err := userAPI.PerformAdminListRegistrationTokens(req.Context(), returnAll, valid)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.ErrorUnknown,
}
}
return util.JSONResponse{
Code: 200,
JSON: map[string]interface{}{
"registration_tokens": tokens,
},
}
}
func AdminGetRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
tokenText := vars["token"]
token, err := userAPI.PerformAdminGetRegistrationToken(req.Context(), tokenText)
if err != nil {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)),
}
}
return util.JSONResponse{
Code: 200,
JSON: token,
}
}
func AdminDeleteRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
tokenText := vars["token"]
err = userAPI.PerformAdminDeleteRegistrationToken(req.Context(), tokenText)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: err,
}
}
return util.JSONResponse{
Code: 200,
JSON: map[string]interface{}{},
}
}
func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
tokenText := vars["token"]
request := make(map[string]*int64)
if err = json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON(fmt.Sprintf("Failed to decode request body: %s", err)),
}
}
newAttributes := make(map[string]interface{})
usesAllowed, ok := request["uses_allowed"]
if ok {
// Only add usesAllowed to newAtrributes if it is present and valid
if usesAllowed != nil && *usesAllowed < 0 {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"),
}
}
newAttributes["usesAllowed"] = usesAllowed
}
expiryTime, ok := request["expiry_time"]
if ok {
// Only add expiryTime to newAtrributes if it is present and valid
if expiryTime != nil && spec.Timestamp(*expiryTime).Time().Before(time.Now()) {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("expiry_time must not be in the past"),
}
}
newAttributes["expiryTime"] = expiryTime
}
if len(newAttributes) == 0 {
// No attributes to update. Return existing token
return AdminGetRegistrationToken(req, cfg, userAPI)
}
updatedToken, err := userAPI.PerformAdminUpdateRegistrationToken(req.Context(), tokenText, newAttributes)
if err != nil {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)),
}
}
return util.JSONResponse{
Code: 200,
JSON: *updatedToken,
}
}
func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {

View file

@ -162,6 +162,36 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
} }
dendriteAdminRouter.Handle("/admin/registrationTokens/new",
httputil.MakeAdminAPI("admin_registration_tokens_new", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminCreateNewRegistrationToken(req, cfg, userAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/registrationTokens",
httputil.MakeAdminAPI("admin_list_registration_tokens", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminListRegistrationTokens(req, cfg, userAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/registrationTokens/{token}",
httputil.MakeAdminAPI("admin_get_registration_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
switch req.Method {
case http.MethodGet:
return AdminGetRegistrationToken(req, cfg, userAPI)
case http.MethodPut:
return AdminUpdateRegistrationToken(req, cfg, userAPI)
case http.MethodDelete:
return AdminDeleteRegistrationToken(req, cfg, userAPI)
default:
return util.MatrixErrorResponse(
404,
string(spec.ErrorNotFound),
"unknown method",
)
}
}),
).Methods(http.MethodGet, http.MethodPut, http.MethodDelete, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/evacuateRoom/{roomID}", dendriteAdminRouter.Handle("/admin/evacuateRoom/{roomID}",
httputil.MakeAdminAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAdminAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {

View file

@ -2,7 +2,7 @@
title: Generating signing keys title: Generating signing keys
parent: Manual parent: Manual
grand_parent: Installation grand_parent: Installation
nav_order: 4 nav_order: 3
permalink: /installation/manual/signingkeys permalink: /installation/manual/signingkeys
--- ---

View file

@ -2,7 +2,7 @@
title: Configuring Dendrite title: Configuring Dendrite
parent: Manual parent: Manual
grand_parent: Installation grand_parent: Installation
nav_order: 3 nav_order: 4
permalink: /installation/manual/configuration permalink: /installation/manual/configuration
--- ---
@ -21,7 +21,7 @@ sections:
First of all, you will need to configure the server name of your Matrix homeserver. First of all, you will need to configure the server name of your Matrix homeserver.
This must match the domain name that you have selected whilst [configuring the domain This must match the domain name that you have selected whilst [configuring the domain
name delegation](domainname#delegation). name delegation](../domainname#delegation).
In the `global` section, set the `server_name` to your delegated domain name: In the `global` section, set the `server_name` to your delegated domain name:

View file

@ -30,10 +30,6 @@ func IsServerAllowed(
serverCurrentlyInRoom bool, serverCurrentlyInRoom bool,
authEvents []gomatrixserverlib.PDU, authEvents []gomatrixserverlib.PDU,
) bool { ) bool {
// In practice should not happen, but avoids unneeded CPU cycles
if serverName == "" || len(authEvents) == 0 {
return false
}
historyVisibility := HistoryVisibilityForRoom(authEvents) historyVisibility := HistoryVisibilityForRoom(authEvents)
// 1. If the history_visibility was set to world_readable, allow. // 1. If the history_visibility was set to world_readable, allow.

View file

@ -114,7 +114,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
if info == nil || info.IsStub() { if info == nil || info.IsStub() {
return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID) return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID)
} }
requester := newBackfillRequester(r.DB, r.FSAPI, r.Querier, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers) requester := newBackfillRequester(r.DB, r.FSAPI, r.Querier, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers, info.RoomVersion)
// Request 100 items regardless of what the query asks for. // Request 100 items regardless of what the query asks for.
// We don't want to go much higher than this. // We don't want to go much higher than this.
// We can't honour exactly the limit as some sytests rely on requesting more for tests to pass // We can't honour exactly the limit as some sytests rely on requesting more for tests to pass
@ -265,7 +265,7 @@ type backfillRequester struct {
eventIDToBeforeStateIDs map[string][]string eventIDToBeforeStateIDs map[string][]string
eventIDMap map[string]gomatrixserverlib.PDU eventIDMap map[string]gomatrixserverlib.PDU
historyVisiblity gomatrixserverlib.HistoryVisibility historyVisiblity gomatrixserverlib.HistoryVisibility
roomInfo types.RoomInfo roomVersion gomatrixserverlib.RoomVersion
} }
func newBackfillRequester( func newBackfillRequester(
@ -274,6 +274,7 @@ func newBackfillRequester(
virtualHost spec.ServerName, virtualHost spec.ServerName,
isLocalServerName func(spec.ServerName) bool, isLocalServerName func(spec.ServerName) bool,
bwExtrems map[string][]string, preferServers []spec.ServerName, bwExtrems map[string][]string, preferServers []spec.ServerName,
roomVersion gomatrixserverlib.RoomVersion,
) *backfillRequester { ) *backfillRequester {
preferServer := make(map[spec.ServerName]bool) preferServer := make(map[spec.ServerName]bool)
for _, p := range preferServers { for _, p := range preferServers {
@ -290,6 +291,7 @@ func newBackfillRequester(
bwExtrems: bwExtrems, bwExtrems: bwExtrems,
preferServer: preferServer, preferServer: preferServer,
historyVisiblity: gomatrixserverlib.HistoryVisibilityShared, historyVisiblity: gomatrixserverlib.HistoryVisibilityShared,
roomVersion: roomVersion,
} }
} }
@ -537,15 +539,11 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
} }
eventNIDs := make([]types.EventNID, len(nidMap)) eventNIDs := make([]types.EventNID, len(nidMap))
i := 0 i := 0
roomNID := b.roomInfo.RoomNID
for _, nid := range nidMap { for _, nid := range nidMap {
eventNIDs[i] = nid.EventNID eventNIDs[i] = nid.EventNID
i++ i++
if roomNID == 0 {
roomNID = nid.RoomNID
}
} }
eventsWithNids, err := b.db.Events(ctx, b.roomInfo.RoomVersion, eventNIDs) eventsWithNids, err := b.db.Events(ctx, b.roomVersion, eventNIDs)
if err != nil { if err != nil {
logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events") logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events")
return nil, err return nil, err

View file

@ -13,6 +13,10 @@ type ClientAPI struct {
// secrets) // secrets)
RegistrationDisabled bool `yaml:"registration_disabled"` RegistrationDisabled bool `yaml:"registration_disabled"`
// If set, requires users to submit a token during registration.
// Tokens can be managed using admin API.
RegistrationRequiresToken bool `yaml:"registration_requires_token"`
// Enable registration without captcha verification or shared secret. // Enable registration without captcha verification or shared secret.
// This option is populated by the -really-enable-open-registration // This option is populated by the -really-enable-open-registration
// command line parameter as it is not recommended. // command line parameter as it is not recommended.
@ -56,6 +60,7 @@ type ClientAPI struct {
func (c *ClientAPI) Defaults(opts DefaultOpts) { func (c *ClientAPI) Defaults(opts DefaultOpts) {
c.RegistrationSharedSecret = "" c.RegistrationSharedSecret = ""
c.RegistrationRequiresToken = false
c.RecaptchaPublicKey = "" c.RecaptchaPublicKey = ""
c.RecaptchaPrivateKey = "" c.RecaptchaPrivateKey = ""
c.RecaptchaEnabled = false c.RecaptchaEnabled = false

View file

@ -27,6 +27,7 @@ import (
"github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
clientapi "github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/pushrules"
) )
@ -94,6 +95,11 @@ type ClientUserAPI interface {
QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error
QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error) QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error)
QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error
PerformAdminCreateRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error)
PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error)
PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error)
PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error
PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error)
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error

View file

@ -33,6 +33,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
clientapi "github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/internal/pushgateway"
@ -63,6 +64,37 @@ type UserInternalAPI struct {
Updater *DeviceListUpdater Updater *DeviceListUpdater
} }
func (a *UserInternalAPI) PerformAdminCreateRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error) {
exists, err := a.DB.RegistrationTokenExists(ctx, *registrationToken.Token)
if err != nil {
return false, err
}
if exists {
return false, fmt.Errorf("token: %s already exists", *registrationToken.Token)
}
_, err = a.DB.InsertRegistrationToken(ctx, registrationToken)
if err != nil {
return false, fmt.Errorf("Error creating token: %s"+err.Error(), *registrationToken.Token)
}
return true, nil
}
func (a *UserInternalAPI) PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) {
return a.DB.ListRegistrationTokens(ctx, returnAll, valid)
}
func (a *UserInternalAPI) PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) {
return a.DB.GetRegistrationToken(ctx, tokenString)
}
func (a *UserInternalAPI) PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error {
return a.DB.DeleteRegistrationToken(ctx, tokenString)
}
func (a *UserInternalAPI) PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) {
return a.DB.UpdateRegistrationToken(ctx, tokenString, newAttributes)
}
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil { if err != nil {

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
clientapi "github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
@ -30,6 +31,15 @@ import (
"github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/dendrite/userapi/types"
) )
type RegistrationTokens interface {
RegistrationTokenExists(ctx context.Context, token string) (bool, error)
InsertRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error)
ListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error)
GetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error)
DeleteRegistrationToken(ctx context.Context, tokenString string) error
UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error)
}
type Profile interface { type Profile interface {
GetProfileByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (*authtypes.Profile, error) GetProfileByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (*authtypes.Profile, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
@ -144,6 +154,7 @@ type UserDatabase interface {
Pusher Pusher
Statistics Statistics
ThreePID ThreePID
RegistrationTokens
} }
type KeyChangeDatabase interface { type KeyChangeDatabase interface {

View file

@ -0,0 +1,222 @@
package postgres
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/clientapi/api"
internal "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"golang.org/x/exp/constraints"
)
const registrationTokensSchema = `
CREATE TABLE IF NOT EXISTS userapi_registration_tokens (
token TEXT PRIMARY KEY,
pending BIGINT,
completed BIGINT,
uses_allowed BIGINT,
expiry_time BIGINT
);
`
const selectTokenSQL = "" +
"SELECT token FROM userapi_registration_tokens WHERE token = $1"
const insertTokenSQL = "" +
"INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)"
const listAllTokensSQL = "" +
"SELECT * FROM userapi_registration_tokens"
const listValidTokensSQL = "" +
"SELECT * FROM userapi_registration_tokens WHERE" +
"(uses_allowed > pending + completed OR uses_allowed IS NULL) AND" +
"(expiry_time > $1 OR expiry_time IS NULL)"
const listInvalidTokensSQL = "" +
"SELECT * FROM userapi_registration_tokens WHERE" +
"(uses_allowed <= pending + completed OR expiry_time <= $1)"
const getTokenSQL = "" +
"SELECT pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens WHERE token = $1"
const deleteTokenSQL = "" +
"DELETE FROM userapi_registration_tokens WHERE token = $1"
const updateTokenUsesAllowedAndExpiryTimeSQL = "" +
"UPDATE userapi_registration_tokens SET uses_allowed = $2, expiry_time = $3 WHERE token = $1"
const updateTokenUsesAllowedSQL = "" +
"UPDATE userapi_registration_tokens SET uses_allowed = $2 WHERE token = $1"
const updateTokenExpiryTimeSQL = "" +
"UPDATE userapi_registration_tokens SET expiry_time = $2 WHERE token = $1"
type registrationTokenStatements struct {
selectTokenStatement *sql.Stmt
insertTokenStatement *sql.Stmt
listAllTokensStatement *sql.Stmt
listValidTokensStatement *sql.Stmt
listInvalidTokenStatement *sql.Stmt
getTokenStatement *sql.Stmt
deleteTokenStatement *sql.Stmt
updateTokenUsesAllowedAndExpiryTimeStatement *sql.Stmt
updateTokenUsesAllowedStatement *sql.Stmt
updateTokenExpiryTimeStatement *sql.Stmt
}
func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) {
s := &registrationTokenStatements{}
_, err := db.Exec(registrationTokensSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.selectTokenStatement, selectTokenSQL},
{&s.insertTokenStatement, insertTokenSQL},
{&s.listAllTokensStatement, listAllTokensSQL},
{&s.listValidTokensStatement, listValidTokensSQL},
{&s.listInvalidTokenStatement, listInvalidTokensSQL},
{&s.getTokenStatement, getTokenSQL},
{&s.deleteTokenStatement, deleteTokenSQL},
{&s.updateTokenUsesAllowedAndExpiryTimeStatement, updateTokenUsesAllowedAndExpiryTimeSQL},
{&s.updateTokenUsesAllowedStatement, updateTokenUsesAllowedSQL},
{&s.updateTokenExpiryTimeStatement, updateTokenExpiryTimeSQL},
}.Prepare(db)
}
func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) {
var existingToken string
stmt := sqlutil.TxStmt(tx, s.selectTokenStatement)
err := stmt.QueryRowContext(ctx, token).Scan(&existingToken)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
}
return false, err
}
return true, nil
}
func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, registrationToken *api.RegistrationToken) (bool, error) {
stmt := sqlutil.TxStmt(tx, s.insertTokenStatement)
_, err := stmt.ExecContext(
ctx,
*registrationToken.Token,
getInsertValue(registrationToken.UsesAllowed),
getInsertValue(registrationToken.ExpiryTime),
*registrationToken.Pending,
*registrationToken.Completed)
if err != nil {
return false, err
}
return true, nil
}
func getInsertValue[t constraints.Integer](in *t) any {
if in == nil {
return nil
}
return *in
}
func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) {
var stmt *sql.Stmt
var tokens []api.RegistrationToken
var tokenString string
var pending, completed, usesAllowed *int32
var expiryTime *int64
var rows *sql.Rows
var err error
if returnAll {
stmt = sqlutil.TxStmt(tx, s.listAllTokensStatement)
rows, err = stmt.QueryContext(ctx)
} else if valid {
stmt = sqlutil.TxStmt(tx, s.listValidTokensStatement)
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
} else {
stmt = sqlutil.TxStmt(tx, s.listInvalidTokenStatement)
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
}
if err != nil {
return tokens, err
}
defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed")
for rows.Next() {
err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime)
if err != nil {
return tokens, err
}
tokenString := tokenString
pending := pending
completed := completed
usesAllowed := usesAllowed
expiryTime := expiryTime
tokenMap := api.RegistrationToken{
Token: &tokenString,
Pending: pending,
Completed: completed,
UsesAllowed: usesAllowed,
ExpiryTime: expiryTime,
}
tokens = append(tokens, tokenMap)
}
return tokens, rows.Err()
}
func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) {
stmt := sqlutil.TxStmt(tx, s.getTokenStatement)
var pending, completed, usesAllowed *int32
var expiryTime *int64
err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime)
if err != nil {
return nil, err
}
token := api.RegistrationToken{
Token: &tokenString,
Pending: pending,
Completed: completed,
UsesAllowed: usesAllowed,
ExpiryTime: expiryTime,
}
return &token, nil
}
func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error {
stmt := sqlutil.TxStmt(tx, s.deleteTokenStatement)
_, err := stmt.ExecContext(ctx, tokenString)
if err != nil {
return err
}
return nil
}
func (s *registrationTokenStatements) UpdateRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*api.RegistrationToken, error) {
var stmt *sql.Stmt
usesAllowed, usesAllowedPresent := newAttributes["usesAllowed"]
expiryTime, expiryTimePresent := newAttributes["expiryTime"]
if usesAllowedPresent && expiryTimePresent {
stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedAndExpiryTimeStatement)
_, err := stmt.ExecContext(ctx, tokenString, usesAllowed, expiryTime)
if err != nil {
return nil, err
}
} else if usesAllowedPresent {
stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedStatement)
_, err := stmt.ExecContext(ctx, tokenString, usesAllowed)
if err != nil {
return nil, err
}
} else if expiryTimePresent {
stmt = sqlutil.TxStmt(tx, s.updateTokenExpiryTimeStatement)
_, err := stmt.ExecContext(ctx, tokenString, expiryTime)
if err != nil {
return nil, err
}
}
return s.GetRegistrationToken(ctx, tx, tokenString)
}

View file

@ -53,6 +53,10 @@ func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *
return nil, err return nil, err
} }
registationTokensTable, err := NewPostgresRegistrationTokensTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresRegistrationsTokenTable: %w", err)
}
accountsTable, err := NewPostgresAccountsTable(db, serverName) accountsTable, err := NewPostgresAccountsTable(db, serverName)
if err != nil { if err != nil {
return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err) return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err)
@ -125,6 +129,7 @@ func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *
ThreePIDs: threePIDTable, ThreePIDs: threePIDTable,
Pushers: pusherTable, Pushers: pusherTable,
Notifications: notificationsTable, Notifications: notificationsTable,
RegistrationTokens: registationTokensTable,
Stats: statsTable, Stats: statsTable,
ServerName: serverName, ServerName: serverName,
DB: db, DB: db,

View file

@ -31,6 +31,7 @@ import (
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
clientapi "github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -43,6 +44,7 @@ import (
type Database struct { type Database struct {
DB *sql.DB DB *sql.DB
Writer sqlutil.Writer Writer sqlutil.Writer
RegistrationTokens tables.RegistrationTokensTable
Accounts tables.AccountsTable Accounts tables.AccountsTable
Profiles tables.ProfileTable Profiles tables.ProfileTable
AccountDatas tables.AccountDataTable AccountDatas tables.AccountDataTable
@ -78,6 +80,42 @@ const (
loginTokenByteLength = 32 loginTokenByteLength = 32
) )
func (d *Database) RegistrationTokenExists(ctx context.Context, token string) (bool, error) {
return d.RegistrationTokens.RegistrationTokenExists(ctx, nil, token)
}
func (d *Database) InsertRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (created bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
created, err = d.RegistrationTokens.InsertRegistrationToken(ctx, txn, registrationToken)
return err
})
return
}
func (d *Database) ListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) {
return d.RegistrationTokens.ListRegistrationTokens(ctx, nil, returnAll, valid)
}
func (d *Database) GetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) {
return d.RegistrationTokens.GetRegistrationToken(ctx, nil, tokenString)
}
func (d *Database) DeleteRegistrationToken(ctx context.Context, tokenString string) (err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err = d.RegistrationTokens.DeleteRegistrationToken(ctx, nil, tokenString)
return err
})
return
}
func (d *Database) UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (updatedToken *clientapi.RegistrationToken, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
updatedToken, err = d.RegistrationTokens.UpdateRegistrationToken(ctx, txn, tokenString, newAttributes)
return err
})
return
}
// GetAccountByPassword returns the account associated with the given localpart and password. // GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart. // Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword( func (d *Database) GetAccountByPassword(

View file

@ -0,0 +1,222 @@
package sqlite3
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/clientapi/api"
internal "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"golang.org/x/exp/constraints"
)
const registrationTokensSchema = `
CREATE TABLE IF NOT EXISTS userapi_registration_tokens (
token TEXT PRIMARY KEY,
pending BIGINT,
completed BIGINT,
uses_allowed BIGINT,
expiry_time BIGINT
);
`
const selectTokenSQL = "" +
"SELECT token FROM userapi_registration_tokens WHERE token = $1"
const insertTokenSQL = "" +
"INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)"
const listAllTokensSQL = "" +
"SELECT * FROM userapi_registration_tokens"
const listValidTokensSQL = "" +
"SELECT * FROM userapi_registration_tokens WHERE" +
"(uses_allowed > pending + completed OR uses_allowed IS NULL) AND" +
"(expiry_time > $1 OR expiry_time IS NULL)"
const listInvalidTokensSQL = "" +
"SELECT * FROM userapi_registration_tokens WHERE" +
"(uses_allowed <= pending + completed OR expiry_time <= $1)"
const getTokenSQL = "" +
"SELECT pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens WHERE token = $1"
const deleteTokenSQL = "" +
"DELETE FROM userapi_registration_tokens WHERE token = $1"
const updateTokenUsesAllowedAndExpiryTimeSQL = "" +
"UPDATE userapi_registration_tokens SET uses_allowed = $2, expiry_time = $3 WHERE token = $1"
const updateTokenUsesAllowedSQL = "" +
"UPDATE userapi_registration_tokens SET uses_allowed = $2 WHERE token = $1"
const updateTokenExpiryTimeSQL = "" +
"UPDATE userapi_registration_tokens SET expiry_time = $2 WHERE token = $1"
type registrationTokenStatements struct {
selectTokenStatement *sql.Stmt
insertTokenStatement *sql.Stmt
listAllTokensStatement *sql.Stmt
listValidTokensStatement *sql.Stmt
listInvalidTokenStatement *sql.Stmt
getTokenStatement *sql.Stmt
deleteTokenStatement *sql.Stmt
updateTokenUsesAllowedAndExpiryTimeStatement *sql.Stmt
updateTokenUsesAllowedStatement *sql.Stmt
updateTokenExpiryTimeStatement *sql.Stmt
}
func NewSQLiteRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) {
s := &registrationTokenStatements{}
_, err := db.Exec(registrationTokensSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.selectTokenStatement, selectTokenSQL},
{&s.insertTokenStatement, insertTokenSQL},
{&s.listAllTokensStatement, listAllTokensSQL},
{&s.listValidTokensStatement, listValidTokensSQL},
{&s.listInvalidTokenStatement, listInvalidTokensSQL},
{&s.getTokenStatement, getTokenSQL},
{&s.deleteTokenStatement, deleteTokenSQL},
{&s.updateTokenUsesAllowedAndExpiryTimeStatement, updateTokenUsesAllowedAndExpiryTimeSQL},
{&s.updateTokenUsesAllowedStatement, updateTokenUsesAllowedSQL},
{&s.updateTokenExpiryTimeStatement, updateTokenExpiryTimeSQL},
}.Prepare(db)
}
func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) {
var existingToken string
stmt := sqlutil.TxStmt(tx, s.selectTokenStatement)
err := stmt.QueryRowContext(ctx, token).Scan(&existingToken)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
}
return false, err
}
return true, nil
}
func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, registrationToken *api.RegistrationToken) (bool, error) {
stmt := sqlutil.TxStmt(tx, s.insertTokenStatement)
_, err := stmt.ExecContext(
ctx,
*registrationToken.Token,
getInsertValue(registrationToken.UsesAllowed),
getInsertValue(registrationToken.ExpiryTime),
*registrationToken.Pending,
*registrationToken.Completed)
if err != nil {
return false, err
}
return true, nil
}
func getInsertValue[t constraints.Integer](in *t) any {
if in == nil {
return nil
}
return *in
}
func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) {
var stmt *sql.Stmt
var tokens []api.RegistrationToken
var tokenString string
var pending, completed, usesAllowed *int32
var expiryTime *int64
var rows *sql.Rows
var err error
if returnAll {
stmt = sqlutil.TxStmt(tx, s.listAllTokensStatement)
rows, err = stmt.QueryContext(ctx)
} else if valid {
stmt = sqlutil.TxStmt(tx, s.listValidTokensStatement)
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
} else {
stmt = sqlutil.TxStmt(tx, s.listInvalidTokenStatement)
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
}
if err != nil {
return tokens, err
}
defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed")
for rows.Next() {
err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime)
if err != nil {
return tokens, err
}
tokenString := tokenString
pending := pending
completed := completed
usesAllowed := usesAllowed
expiryTime := expiryTime
tokenMap := api.RegistrationToken{
Token: &tokenString,
Pending: pending,
Completed: completed,
UsesAllowed: usesAllowed,
ExpiryTime: expiryTime,
}
tokens = append(tokens, tokenMap)
}
return tokens, rows.Err()
}
func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) {
stmt := sqlutil.TxStmt(tx, s.getTokenStatement)
var pending, completed, usesAllowed *int32
var expiryTime *int64
err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime)
if err != nil {
return nil, err
}
token := api.RegistrationToken{
Token: &tokenString,
Pending: pending,
Completed: completed,
UsesAllowed: usesAllowed,
ExpiryTime: expiryTime,
}
return &token, nil
}
func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error {
stmt := sqlutil.TxStmt(tx, s.deleteTokenStatement)
_, err := stmt.ExecContext(ctx, tokenString)
if err != nil {
return err
}
return nil
}
func (s *registrationTokenStatements) UpdateRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*api.RegistrationToken, error) {
var stmt *sql.Stmt
usesAllowed, usesAllowedPresent := newAttributes["usesAllowed"]
expiryTime, expiryTimePresent := newAttributes["expiryTime"]
if usesAllowedPresent && expiryTimePresent {
stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedAndExpiryTimeStatement)
_, err := stmt.ExecContext(ctx, tokenString, usesAllowed, expiryTime)
if err != nil {
return nil, err
}
} else if usesAllowedPresent {
stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedStatement)
_, err := stmt.ExecContext(ctx, tokenString, usesAllowed)
if err != nil {
return nil, err
}
} else if expiryTimePresent {
stmt = sqlutil.TxStmt(tx, s.updateTokenExpiryTimeStatement)
_, err := stmt.ExecContext(ctx, tokenString, expiryTime)
if err != nil {
return nil, err
}
}
return s.GetRegistrationToken(ctx, tx, tokenString)
}

View file

@ -50,7 +50,10 @@ func NewUserDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperti
if err = m.Up(ctx); err != nil { if err = m.Up(ctx); err != nil {
return nil, err return nil, err
} }
registationTokensTable, err := NewSQLiteRegistrationTokensTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteRegistrationsTokenTable: %w", err)
}
accountsTable, err := NewSQLiteAccountsTable(db, serverName) accountsTable, err := NewSQLiteAccountsTable(db, serverName)
if err != nil { if err != nil {
return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err) return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err)
@ -130,6 +133,7 @@ func NewUserDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperti
LoginTokenLifetime: loginTokenLifetime, LoginTokenLifetime: loginTokenLifetime,
BcryptCost: bcryptCost, BcryptCost: bcryptCost,
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
RegistrationTokens: registationTokensTable,
}, nil }, nil
} }

View file

@ -25,10 +25,20 @@ import (
"github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
clientapi "github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/dendrite/userapi/types"
) )
type RegistrationTokensTable interface {
RegistrationTokenExists(ctx context.Context, txn *sql.Tx, token string) (bool, error)
InsertRegistrationToken(ctx context.Context, txn *sql.Tx, registrationToken *clientapi.RegistrationToken) (bool, error)
ListRegistrationTokens(ctx context.Context, txn *sql.Tx, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error)
GetRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) (*clientapi.RegistrationToken, error)
DeleteRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) error
UpdateRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error)
}
type AccountDataTable interface { type AccountDataTable interface {
InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID, dataType string, content json.RawMessage) error InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID, dataType string, content json.RawMessage) error
SelectAccountData(ctx context.Context, localpart string, serverName spec.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error) SelectAccountData(ctx context.Context, localpart string, serverName spec.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)