mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-11-22 14:21:55 -06:00
feat: admin APIs for token authenticated registration (#3101)
### Pull Request Checklist <!-- Please read https://matrix-org.github.io/dendrite/development/contributing before submitting your pull request --> * [x] I have added Go unit tests or [Complement integration tests](https://github.com/matrix-org/complement) for this PR _or_ I have justified why this PR doesn't need tests * [x] Pull request includes a [sign off below using a legally identifiable name](https://matrix-org.github.io/dendrite/development/contributing#sign-off) _or_ I have already signed off privately Signed-off-by: `Santhoshivan Amudhan santhoshivan23@gmail.com`
This commit is contained in:
parent
a734b112c6
commit
45082d4dce
|
@ -2,6 +2,7 @@ package clientapi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
@ -23,12 +24,649 @@ import (
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
|
capi "github.com/matrix-org/dendrite/clientapi/api"
|
||||||
"github.com/matrix-org/dendrite/test"
|
"github.com/matrix-org/dendrite/test"
|
||||||
"github.com/matrix-org/dendrite/test/testrig"
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
"github.com/matrix-org/dendrite/userapi"
|
"github.com/matrix-org/dendrite/userapi"
|
||||||
uapi "github.com/matrix-org/dendrite/userapi/api"
|
uapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestAdminCreateToken(t *testing.T) {
|
||||||
|
aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
|
||||||
|
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||||
|
ctx := context.Background()
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||||
|
cfg.ClientAPI.RegistrationRequiresToken = true
|
||||||
|
defer close()
|
||||||
|
natsInstance := jetstream.NATSInstance{}
|
||||||
|
routers := httputil.NewRouters()
|
||||||
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||||
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||||
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||||
|
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
|
||||||
|
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||||
|
accessTokens := map[*test.User]userDevice{
|
||||||
|
aliceAdmin: {},
|
||||||
|
bob: {},
|
||||||
|
}
|
||||||
|
createAccessTokens(t, accessTokens, userAPI, ctx, routers)
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
requestingUser *test.User
|
||||||
|
requestOpt test.HTTPRequestOpt
|
||||||
|
wantOK bool
|
||||||
|
withHeader bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Missing auth",
|
||||||
|
requestingUser: bob,
|
||||||
|
wantOK: false,
|
||||||
|
requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"token": "token1",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bob is denied access",
|
||||||
|
requestingUser: bob,
|
||||||
|
wantOK: false,
|
||||||
|
withHeader: true,
|
||||||
|
requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"token": "token2",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alice can create a token without specifyiing any information",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: true,
|
||||||
|
withHeader: true,
|
||||||
|
requestOpt: test.WithJSONBody(t, map[string]interface{}{}),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alice can to create a token specifying a name",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: true,
|
||||||
|
withHeader: true,
|
||||||
|
requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"token": "token3",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alice cannot to create a token that already exists",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: false,
|
||||||
|
withHeader: true,
|
||||||
|
requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"token": "token3",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alice can create a token specifying valid params",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: true,
|
||||||
|
withHeader: true,
|
||||||
|
requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"token": "token4",
|
||||||
|
"uses_allowed": 5,
|
||||||
|
"expiry_time": time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alice cannot create a token specifying invalid name",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: false,
|
||||||
|
withHeader: true,
|
||||||
|
requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"token": "token@",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alice cannot create a token specifying invalid uses_allowed",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: false,
|
||||||
|
withHeader: true,
|
||||||
|
requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"token": "token5",
|
||||||
|
"uses_allowed": -1,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alice cannot create a token specifying invalid expiry_time",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: false,
|
||||||
|
withHeader: true,
|
||||||
|
requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"token": "token6",
|
||||||
|
"expiry_time": time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alice cannot to create a token specifying invalid length",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: false,
|
||||||
|
withHeader: true,
|
||||||
|
requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"length": 80,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/registrationTokens/new")
|
||||||
|
if tc.requestOpt != nil {
|
||||||
|
req = test.NewRequest(t, http.MethodPost, "/_dendrite/admin/registrationTokens/new", tc.requestOpt)
|
||||||
|
}
|
||||||
|
if tc.withHeader {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken)
|
||||||
|
}
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
routers.DendriteAdmin.ServeHTTP(rec, req)
|
||||||
|
t.Logf("%s", rec.Body.String())
|
||||||
|
if tc.wantOK && rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminListRegistrationTokens(t *testing.T) {
|
||||||
|
aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
|
||||||
|
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||||
|
ctx := context.Background()
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||||
|
cfg.ClientAPI.RegistrationRequiresToken = true
|
||||||
|
defer close()
|
||||||
|
natsInstance := jetstream.NATSInstance{}
|
||||||
|
routers := httputil.NewRouters()
|
||||||
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||||
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||||
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||||
|
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
|
||||||
|
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||||
|
accessTokens := map[*test.User]userDevice{
|
||||||
|
aliceAdmin: {},
|
||||||
|
bob: {},
|
||||||
|
}
|
||||||
|
tokens := []capi.RegistrationToken{
|
||||||
|
{
|
||||||
|
Token: getPointer("valid"),
|
||||||
|
UsesAllowed: getPointer(int32(10)),
|
||||||
|
ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
|
||||||
|
Pending: getPointer(int32(0)),
|
||||||
|
Completed: getPointer(int32(0)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Token: getPointer("invalid"),
|
||||||
|
UsesAllowed: getPointer(int32(10)),
|
||||||
|
ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
|
||||||
|
Pending: getPointer(int32(0)),
|
||||||
|
Completed: getPointer(int32(0)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tkn := range tokens {
|
||||||
|
tkn := tkn
|
||||||
|
userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn)
|
||||||
|
}
|
||||||
|
createAccessTokens(t, accessTokens, userAPI, ctx, routers)
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
requestingUser *test.User
|
||||||
|
valid string
|
||||||
|
isValidSpecified bool
|
||||||
|
wantOK bool
|
||||||
|
withHeader bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Missing auth",
|
||||||
|
requestingUser: bob,
|
||||||
|
wantOK: false,
|
||||||
|
isValidSpecified: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bob is denied access",
|
||||||
|
requestingUser: bob,
|
||||||
|
wantOK: false,
|
||||||
|
withHeader: true,
|
||||||
|
isValidSpecified: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alice can list all tokens",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: true,
|
||||||
|
withHeader: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alice can list all valid tokens",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: true,
|
||||||
|
withHeader: true,
|
||||||
|
valid: "true",
|
||||||
|
isValidSpecified: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alice can list all invalid tokens",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: true,
|
||||||
|
withHeader: true,
|
||||||
|
valid: "false",
|
||||||
|
isValidSpecified: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No response when valid has a bad value",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: false,
|
||||||
|
withHeader: true,
|
||||||
|
valid: "trueee",
|
||||||
|
isValidSpecified: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
var path string
|
||||||
|
if tc.isValidSpecified {
|
||||||
|
path = fmt.Sprintf("/_dendrite/admin/registrationTokens?valid=%v", tc.valid)
|
||||||
|
} else {
|
||||||
|
path = "/_dendrite/admin/registrationTokens"
|
||||||
|
}
|
||||||
|
req := test.NewRequest(t, http.MethodGet, path)
|
||||||
|
if tc.withHeader {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken)
|
||||||
|
}
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
routers.DendriteAdmin.ServeHTTP(rec, req)
|
||||||
|
t.Logf("%s", rec.Body.String())
|
||||||
|
if tc.wantOK && rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminGetRegistrationToken(t *testing.T) {
|
||||||
|
aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
|
||||||
|
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||||
|
ctx := context.Background()
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||||
|
cfg.ClientAPI.RegistrationRequiresToken = true
|
||||||
|
defer close()
|
||||||
|
natsInstance := jetstream.NATSInstance{}
|
||||||
|
routers := httputil.NewRouters()
|
||||||
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||||
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||||
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||||
|
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
|
||||||
|
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||||
|
accessTokens := map[*test.User]userDevice{
|
||||||
|
aliceAdmin: {},
|
||||||
|
bob: {},
|
||||||
|
}
|
||||||
|
tokens := []capi.RegistrationToken{
|
||||||
|
{
|
||||||
|
Token: getPointer("alice_token1"),
|
||||||
|
UsesAllowed: getPointer(int32(10)),
|
||||||
|
ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
|
||||||
|
Pending: getPointer(int32(0)),
|
||||||
|
Completed: getPointer(int32(0)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Token: getPointer("alice_token2"),
|
||||||
|
UsesAllowed: getPointer(int32(10)),
|
||||||
|
ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
|
||||||
|
Pending: getPointer(int32(0)),
|
||||||
|
Completed: getPointer(int32(0)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tkn := range tokens {
|
||||||
|
tkn := tkn
|
||||||
|
userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn)
|
||||||
|
}
|
||||||
|
createAccessTokens(t, accessTokens, userAPI, ctx, routers)
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
requestingUser *test.User
|
||||||
|
token string
|
||||||
|
wantOK bool
|
||||||
|
withHeader bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Missing auth",
|
||||||
|
requestingUser: bob,
|
||||||
|
wantOK: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bob is denied access",
|
||||||
|
requestingUser: bob,
|
||||||
|
wantOK: false,
|
||||||
|
withHeader: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alice can GET alice_token1",
|
||||||
|
token: "alice_token1",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: true,
|
||||||
|
withHeader: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alice can GET alice_token2",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: true,
|
||||||
|
withHeader: true,
|
||||||
|
token: "alice_token2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alice cannot GET a token that does not exists",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: false,
|
||||||
|
withHeader: true,
|
||||||
|
token: "alice_token3",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
path := fmt.Sprintf("/_dendrite/admin/registrationTokens/%s", tc.token)
|
||||||
|
req := test.NewRequest(t, http.MethodGet, path)
|
||||||
|
if tc.withHeader {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken)
|
||||||
|
}
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
routers.DendriteAdmin.ServeHTTP(rec, req)
|
||||||
|
t.Logf("%s", rec.Body.String())
|
||||||
|
if tc.wantOK && rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminDeleteRegistrationToken(t *testing.T) {
|
||||||
|
aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
|
||||||
|
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||||
|
ctx := context.Background()
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||||
|
cfg.ClientAPI.RegistrationRequiresToken = true
|
||||||
|
defer close()
|
||||||
|
natsInstance := jetstream.NATSInstance{}
|
||||||
|
routers := httputil.NewRouters()
|
||||||
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||||
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||||
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||||
|
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
|
||||||
|
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||||
|
accessTokens := map[*test.User]userDevice{
|
||||||
|
aliceAdmin: {},
|
||||||
|
bob: {},
|
||||||
|
}
|
||||||
|
tokens := []capi.RegistrationToken{
|
||||||
|
{
|
||||||
|
Token: getPointer("alice_token1"),
|
||||||
|
UsesAllowed: getPointer(int32(10)),
|
||||||
|
ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
|
||||||
|
Pending: getPointer(int32(0)),
|
||||||
|
Completed: getPointer(int32(0)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Token: getPointer("alice_token2"),
|
||||||
|
UsesAllowed: getPointer(int32(10)),
|
||||||
|
ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
|
||||||
|
Pending: getPointer(int32(0)),
|
||||||
|
Completed: getPointer(int32(0)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tkn := range tokens {
|
||||||
|
tkn := tkn
|
||||||
|
userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn)
|
||||||
|
}
|
||||||
|
createAccessTokens(t, accessTokens, userAPI, ctx, routers)
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
requestingUser *test.User
|
||||||
|
token string
|
||||||
|
wantOK bool
|
||||||
|
withHeader bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Missing auth",
|
||||||
|
requestingUser: bob,
|
||||||
|
wantOK: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bob is denied access",
|
||||||
|
requestingUser: bob,
|
||||||
|
wantOK: false,
|
||||||
|
withHeader: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alice can DELETE alice_token1",
|
||||||
|
token: "alice_token1",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: true,
|
||||||
|
withHeader: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alice can DELETE alice_token2",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: true,
|
||||||
|
withHeader: true,
|
||||||
|
token: "alice_token2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
path := fmt.Sprintf("/_dendrite/admin/registrationTokens/%s", tc.token)
|
||||||
|
req := test.NewRequest(t, http.MethodDelete, path)
|
||||||
|
if tc.withHeader {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken)
|
||||||
|
}
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
routers.DendriteAdmin.ServeHTTP(rec, req)
|
||||||
|
t.Logf("%s", rec.Body.String())
|
||||||
|
if tc.wantOK && rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminUpdateRegistrationToken(t *testing.T) {
|
||||||
|
aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
|
||||||
|
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||||
|
ctx := context.Background()
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||||
|
cfg.ClientAPI.RegistrationRequiresToken = true
|
||||||
|
defer close()
|
||||||
|
natsInstance := jetstream.NATSInstance{}
|
||||||
|
routers := httputil.NewRouters()
|
||||||
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||||
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||||
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||||
|
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
|
||||||
|
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||||
|
accessTokens := map[*test.User]userDevice{
|
||||||
|
aliceAdmin: {},
|
||||||
|
bob: {},
|
||||||
|
}
|
||||||
|
createAccessTokens(t, accessTokens, userAPI, ctx, routers)
|
||||||
|
tokens := []capi.RegistrationToken{
|
||||||
|
{
|
||||||
|
Token: getPointer("alice_token1"),
|
||||||
|
UsesAllowed: getPointer(int32(10)),
|
||||||
|
ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
|
||||||
|
Pending: getPointer(int32(0)),
|
||||||
|
Completed: getPointer(int32(0)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Token: getPointer("alice_token2"),
|
||||||
|
UsesAllowed: getPointer(int32(10)),
|
||||||
|
ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
|
||||||
|
Pending: getPointer(int32(0)),
|
||||||
|
Completed: getPointer(int32(0)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tkn := range tokens {
|
||||||
|
tkn := tkn
|
||||||
|
userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn)
|
||||||
|
}
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
requestingUser *test.User
|
||||||
|
method string
|
||||||
|
token string
|
||||||
|
requestOpt test.HTTPRequestOpt
|
||||||
|
wantOK bool
|
||||||
|
withHeader bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Missing auth",
|
||||||
|
requestingUser: bob,
|
||||||
|
wantOK: false,
|
||||||
|
token: "alice_token1",
|
||||||
|
requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"uses_allowed": 10,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bob is denied access",
|
||||||
|
requestingUser: bob,
|
||||||
|
wantOK: false,
|
||||||
|
withHeader: true,
|
||||||
|
token: "alice_token1",
|
||||||
|
requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"uses_allowed": 10,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alice can UPDATE a token's uses_allowed property",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: true,
|
||||||
|
withHeader: true,
|
||||||
|
token: "alice_token1",
|
||||||
|
requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"uses_allowed": 10,
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alice can UPDATE a token's expiry_time property",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: true,
|
||||||
|
withHeader: true,
|
||||||
|
token: "alice_token2",
|
||||||
|
requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"expiry_time": time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alice can UPDATE a token's uses_allowed and expiry_time property",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: false,
|
||||||
|
withHeader: true,
|
||||||
|
token: "alice_token1",
|
||||||
|
requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"uses_allowed": 20,
|
||||||
|
"expiry_time": time.Now().Add(10*24*time.Hour).UnixNano() / int64(time.Millisecond),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alice CANNOT update a token with invalid properties",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: false,
|
||||||
|
withHeader: true,
|
||||||
|
token: "alice_token2",
|
||||||
|
requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"uses_allowed": -5,
|
||||||
|
"expiry_time": time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alice CANNOT UPDATE a token that does not exist",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: false,
|
||||||
|
withHeader: true,
|
||||||
|
token: "alice_token9",
|
||||||
|
requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"uses_allowed": 100,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alice can UPDATE token specifying uses_allowed as null - Valid for infinite uses",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: false,
|
||||||
|
withHeader: true,
|
||||||
|
token: "alice_token1",
|
||||||
|
requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"uses_allowed": nil,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alice can UPDATE token specifying expiry_time AS null - Valid for infinite time",
|
||||||
|
requestingUser: aliceAdmin,
|
||||||
|
wantOK: false,
|
||||||
|
withHeader: true,
|
||||||
|
token: "alice_token1",
|
||||||
|
requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"expiry_time": nil,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
path := fmt.Sprintf("/_dendrite/admin/registrationTokens/%s", tc.token)
|
||||||
|
req := test.NewRequest(t, http.MethodPut, path)
|
||||||
|
if tc.requestOpt != nil {
|
||||||
|
req = test.NewRequest(t, http.MethodPut, path, tc.requestOpt)
|
||||||
|
}
|
||||||
|
if tc.withHeader {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken)
|
||||||
|
}
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
routers.DendriteAdmin.ServeHTTP(rec, req)
|
||||||
|
t.Logf("%s", rec.Body.String())
|
||||||
|
if tc.wantOK && rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPointer[T any](s T) *T {
|
||||||
|
return &s
|
||||||
|
}
|
||||||
|
|
||||||
func TestAdminResetPassword(t *testing.T) {
|
func TestAdminResetPassword(t *testing.T) {
|
||||||
aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
|
aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
|
||||||
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||||
|
|
|
@ -21,3 +21,11 @@ type ExtraPublicRoomsProvider interface {
|
||||||
// Rooms returns the extra rooms. This is called on-demand by clients, so cache appropriately.
|
// Rooms returns the extra rooms. This is called on-demand by clients, so cache appropriately.
|
||||||
Rooms() []fclient.PublicRoom
|
Rooms() []fclient.PublicRoom
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type RegistrationToken struct {
|
||||||
|
Token *string `json:"token"`
|
||||||
|
UsesAllowed *int32 `json:"uses_allowed"`
|
||||||
|
Pending *int32 `json:"pending"`
|
||||||
|
Completed *int32 `json:"completed"`
|
||||||
|
ExpiryTime *int64 `json:"expiry_time"`
|
||||||
|
}
|
||||||
|
|
|
@ -6,6 +6,8 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
@ -16,14 +18,254 @@ import (
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/nats-io/nats.go"
|
"github.com/nats-io/nats.go"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/constraints"
|
||||||
|
|
||||||
|
clientapi "github.com/matrix-org/dendrite/clientapi/api"
|
||||||
"github.com/matrix-org/dendrite/internal/httputil"
|
"github.com/matrix-org/dendrite/internal/httputil"
|
||||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$")
|
||||||
|
|
||||||
|
func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
||||||
|
if !cfg.RegistrationRequiresToken {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
JSON: spec.Forbidden("Registration via tokens is not enabled on this homeserver"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
request := struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
UsesAllowed *int32 `json:"uses_allowed,omitempty"`
|
||||||
|
ExpiryTime *int64 `json:"expiry_time,omitempty"`
|
||||||
|
Length int32 `json:"length"`
|
||||||
|
}{}
|
||||||
|
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.BadJSON(fmt.Sprintf("Failed to decode request body: %s", err)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
token := request.Token
|
||||||
|
usesAllowed := request.UsesAllowed
|
||||||
|
expiryTime := request.ExpiryTime
|
||||||
|
length := request.Length
|
||||||
|
|
||||||
|
if len(token) == 0 {
|
||||||
|
if length == 0 {
|
||||||
|
// length not provided in request. Assign default value of 16.
|
||||||
|
length = 16
|
||||||
|
}
|
||||||
|
// token not present in request body. Hence, generate a random token.
|
||||||
|
if length <= 0 || length > 64 {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.BadJSON("length must be greater than zero and not greater than 64"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
token = util.RandomString(int(length))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(token) > 64 {
|
||||||
|
//Token present in request body, but is too long.
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.BadJSON("token must not be longer than 64"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
isTokenValid := validRegistrationTokenRegex.Match([]byte(token))
|
||||||
|
if !isTokenValid {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.BadJSON("token must consist only of characters matched by the regex [A-Za-z0-9-_]"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// At this point, we have a valid token, either through request body or through random generation.
|
||||||
|
if usesAllowed != nil && *usesAllowed < 0 {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if expiryTime != nil && spec.Timestamp(*expiryTime).Time().Before(time.Now()) {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.BadJSON("expiry_time must not be in the past"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pending := int32(0)
|
||||||
|
completed := int32(0)
|
||||||
|
// If usesAllowed or expiryTime is 0, it means they are not present in the request. NULL (indicating unlimited uses / no expiration will be persisted in DB)
|
||||||
|
registrationToken := &clientapi.RegistrationToken{
|
||||||
|
Token: &token,
|
||||||
|
UsesAllowed: usesAllowed,
|
||||||
|
Pending: &pending,
|
||||||
|
Completed: &completed,
|
||||||
|
ExpiryTime: expiryTime,
|
||||||
|
}
|
||||||
|
created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), registrationToken)
|
||||||
|
if !created {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusConflict,
|
||||||
|
JSON: map[string]string{
|
||||||
|
"error": fmt.Sprintf("token: %s already exists", token),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: map[string]interface{}{
|
||||||
|
"token": token,
|
||||||
|
"uses_allowed": getReturnValue(usesAllowed),
|
||||||
|
"pending": pending,
|
||||||
|
"completed": completed,
|
||||||
|
"expiry_time": getReturnValue(expiryTime),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getReturnValue[t constraints.Integer](in *t) any {
|
||||||
|
if in == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return *in
|
||||||
|
}
|
||||||
|
|
||||||
|
func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
||||||
|
queryParams := req.URL.Query()
|
||||||
|
returnAll := true
|
||||||
|
valid := true
|
||||||
|
validQuery, ok := queryParams["valid"]
|
||||||
|
if ok {
|
||||||
|
returnAll = false
|
||||||
|
validValue, err := strconv.ParseBool(validQuery[0])
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.BadJSON("invalid 'valid' query parameter"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
valid = validValue
|
||||||
|
}
|
||||||
|
tokens, err := userAPI.PerformAdminListRegistrationTokens(req.Context(), returnAll, valid)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: spec.ErrorUnknown,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: map[string]interface{}{
|
||||||
|
"registration_tokens": tokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func AdminGetRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
tokenText := vars["token"]
|
||||||
|
token, err := userAPI.PerformAdminGetRegistrationToken(req.Context(), tokenText)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusNotFound,
|
||||||
|
JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: token,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func AdminDeleteRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
tokenText := vars["token"]
|
||||||
|
err = userAPI.PerformAdminDeleteRegistrationToken(req.Context(), tokenText)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: map[string]interface{}{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
tokenText := vars["token"]
|
||||||
|
request := make(map[string]*int64)
|
||||||
|
if err = json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.BadJSON(fmt.Sprintf("Failed to decode request body: %s", err)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
newAttributes := make(map[string]interface{})
|
||||||
|
usesAllowed, ok := request["uses_allowed"]
|
||||||
|
if ok {
|
||||||
|
// Only add usesAllowed to newAtrributes if it is present and valid
|
||||||
|
if usesAllowed != nil && *usesAllowed < 0 {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
newAttributes["usesAllowed"] = usesAllowed
|
||||||
|
}
|
||||||
|
expiryTime, ok := request["expiry_time"]
|
||||||
|
if ok {
|
||||||
|
// Only add expiryTime to newAtrributes if it is present and valid
|
||||||
|
if expiryTime != nil && spec.Timestamp(*expiryTime).Time().Before(time.Now()) {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.BadJSON("expiry_time must not be in the past"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
newAttributes["expiryTime"] = expiryTime
|
||||||
|
}
|
||||||
|
if len(newAttributes) == 0 {
|
||||||
|
// No attributes to update. Return existing token
|
||||||
|
return AdminGetRegistrationToken(req, cfg, userAPI)
|
||||||
|
}
|
||||||
|
updatedToken, err := userAPI.PerformAdminUpdateRegistrationToken(req.Context(), tokenText, newAttributes)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusNotFound,
|
||||||
|
JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: *updatedToken,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
|
func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -162,6 +162,36 @@ func Setup(
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||||
}
|
}
|
||||||
|
dendriteAdminRouter.Handle("/admin/registrationTokens/new",
|
||||||
|
httputil.MakeAdminAPI("admin_registration_tokens_new", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
return AdminCreateNewRegistrationToken(req, cfg, userAPI)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
|
dendriteAdminRouter.Handle("/admin/registrationTokens",
|
||||||
|
httputil.MakeAdminAPI("admin_list_registration_tokens", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
return AdminListRegistrationTokens(req, cfg, userAPI)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
|
dendriteAdminRouter.Handle("/admin/registrationTokens/{token}",
|
||||||
|
httputil.MakeAdminAPI("admin_get_registration_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
switch req.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
return AdminGetRegistrationToken(req, cfg, userAPI)
|
||||||
|
case http.MethodPut:
|
||||||
|
return AdminUpdateRegistrationToken(req, cfg, userAPI)
|
||||||
|
case http.MethodDelete:
|
||||||
|
return AdminDeleteRegistrationToken(req, cfg, userAPI)
|
||||||
|
default:
|
||||||
|
return util.MatrixErrorResponse(
|
||||||
|
404,
|
||||||
|
string(spec.ErrorNotFound),
|
||||||
|
"unknown method",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodGet, http.MethodPut, http.MethodDelete, http.MethodOptions)
|
||||||
|
|
||||||
dendriteAdminRouter.Handle("/admin/evacuateRoom/{roomID}",
|
dendriteAdminRouter.Handle("/admin/evacuateRoom/{roomID}",
|
||||||
httputil.MakeAdminAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAdminAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
|
|
@ -13,6 +13,10 @@ type ClientAPI struct {
|
||||||
// secrets)
|
// secrets)
|
||||||
RegistrationDisabled bool `yaml:"registration_disabled"`
|
RegistrationDisabled bool `yaml:"registration_disabled"`
|
||||||
|
|
||||||
|
// If set, requires users to submit a token during registration.
|
||||||
|
// Tokens can be managed using admin API.
|
||||||
|
RegistrationRequiresToken bool `yaml:"registration_requires_token"`
|
||||||
|
|
||||||
// Enable registration without captcha verification or shared secret.
|
// Enable registration without captcha verification or shared secret.
|
||||||
// This option is populated by the -really-enable-open-registration
|
// This option is populated by the -really-enable-open-registration
|
||||||
// command line parameter as it is not recommended.
|
// command line parameter as it is not recommended.
|
||||||
|
@ -56,6 +60,7 @@ type ClientAPI struct {
|
||||||
|
|
||||||
func (c *ClientAPI) Defaults(opts DefaultOpts) {
|
func (c *ClientAPI) Defaults(opts DefaultOpts) {
|
||||||
c.RegistrationSharedSecret = ""
|
c.RegistrationSharedSecret = ""
|
||||||
|
c.RegistrationRequiresToken = false
|
||||||
c.RecaptchaPublicKey = ""
|
c.RecaptchaPublicKey = ""
|
||||||
c.RecaptchaPrivateKey = ""
|
c.RecaptchaPrivateKey = ""
|
||||||
c.RecaptchaEnabled = false
|
c.RecaptchaEnabled = false
|
||||||
|
|
|
@ -27,6 +27,7 @@ import (
|
||||||
"github.com/matrix-org/gomatrixserverlib/fclient"
|
"github.com/matrix-org/gomatrixserverlib/fclient"
|
||||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
|
|
||||||
|
clientapi "github.com/matrix-org/dendrite/clientapi/api"
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/internal/pushrules"
|
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||||
)
|
)
|
||||||
|
@ -94,6 +95,11 @@ type ClientUserAPI interface {
|
||||||
QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error
|
QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error
|
||||||
QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error)
|
QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error)
|
||||||
QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error
|
QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error
|
||||||
|
PerformAdminCreateRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error)
|
||||||
|
PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error)
|
||||||
|
PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error)
|
||||||
|
PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error
|
||||||
|
PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error)
|
||||||
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
|
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
|
||||||
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
|
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
|
||||||
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
|
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
|
||||||
|
|
|
@ -33,6 +33,7 @@ import (
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
|
clientapi "github.com/matrix-org/dendrite/clientapi/api"
|
||||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
"github.com/matrix-org/dendrite/internal/pushgateway"
|
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||||
|
@ -63,6 +64,37 @@ type UserInternalAPI struct {
|
||||||
Updater *DeviceListUpdater
|
Updater *DeviceListUpdater
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) PerformAdminCreateRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error) {
|
||||||
|
exists, err := a.DB.RegistrationTokenExists(ctx, *registrationToken.Token)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if exists {
|
||||||
|
return false, fmt.Errorf("token: %s already exists", *registrationToken.Token)
|
||||||
|
}
|
||||||
|
_, err = a.DB.InsertRegistrationToken(ctx, registrationToken)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("Error creating token: %s"+err.Error(), *registrationToken.Token)
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) {
|
||||||
|
return a.DB.ListRegistrationTokens(ctx, returnAll, valid)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) {
|
||||||
|
return a.DB.GetRegistrationToken(ctx, tokenString)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error {
|
||||||
|
return a.DB.DeleteRegistrationToken(ctx, tokenString)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) {
|
||||||
|
return a.DB.UpdateRegistrationToken(ctx, tokenString, newAttributes)
|
||||||
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
|
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
|
||||||
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -23,6 +23,7 @@ import (
|
||||||
"github.com/matrix-org/gomatrixserverlib/fclient"
|
"github.com/matrix-org/gomatrixserverlib/fclient"
|
||||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
|
|
||||||
|
clientapi "github.com/matrix-org/dendrite/clientapi/api"
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/internal/pushrules"
|
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
@ -30,6 +31,15 @@ import (
|
||||||
"github.com/matrix-org/dendrite/userapi/types"
|
"github.com/matrix-org/dendrite/userapi/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type RegistrationTokens interface {
|
||||||
|
RegistrationTokenExists(ctx context.Context, token string) (bool, error)
|
||||||
|
InsertRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error)
|
||||||
|
ListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error)
|
||||||
|
GetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error)
|
||||||
|
DeleteRegistrationToken(ctx context.Context, tokenString string) error
|
||||||
|
UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error)
|
||||||
|
}
|
||||||
|
|
||||||
type Profile interface {
|
type Profile interface {
|
||||||
GetProfileByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (*authtypes.Profile, error)
|
GetProfileByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (*authtypes.Profile, error)
|
||||||
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
||||||
|
@ -144,6 +154,7 @@ type UserDatabase interface {
|
||||||
Pusher
|
Pusher
|
||||||
Statistics
|
Statistics
|
||||||
ThreePID
|
ThreePID
|
||||||
|
RegistrationTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
type KeyChangeDatabase interface {
|
type KeyChangeDatabase interface {
|
||||||
|
|
222
userapi/storage/postgres/registration_tokens_table.go
Normal file
222
userapi/storage/postgres/registration_tokens_table.go
Normal file
|
@ -0,0 +1,222 @@
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/api"
|
||||||
|
internal "github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
"golang.org/x/exp/constraints"
|
||||||
|
)
|
||||||
|
|
||||||
|
const registrationTokensSchema = `
|
||||||
|
CREATE TABLE IF NOT EXISTS userapi_registration_tokens (
|
||||||
|
token TEXT PRIMARY KEY,
|
||||||
|
pending BIGINT,
|
||||||
|
completed BIGINT,
|
||||||
|
uses_allowed BIGINT,
|
||||||
|
expiry_time BIGINT
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
const selectTokenSQL = "" +
|
||||||
|
"SELECT token FROM userapi_registration_tokens WHERE token = $1"
|
||||||
|
|
||||||
|
const insertTokenSQL = "" +
|
||||||
|
"INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)"
|
||||||
|
|
||||||
|
const listAllTokensSQL = "" +
|
||||||
|
"SELECT * FROM userapi_registration_tokens"
|
||||||
|
|
||||||
|
const listValidTokensSQL = "" +
|
||||||
|
"SELECT * FROM userapi_registration_tokens WHERE" +
|
||||||
|
"(uses_allowed > pending + completed OR uses_allowed IS NULL) AND" +
|
||||||
|
"(expiry_time > $1 OR expiry_time IS NULL)"
|
||||||
|
|
||||||
|
const listInvalidTokensSQL = "" +
|
||||||
|
"SELECT * FROM userapi_registration_tokens WHERE" +
|
||||||
|
"(uses_allowed <= pending + completed OR expiry_time <= $1)"
|
||||||
|
|
||||||
|
const getTokenSQL = "" +
|
||||||
|
"SELECT pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens WHERE token = $1"
|
||||||
|
|
||||||
|
const deleteTokenSQL = "" +
|
||||||
|
"DELETE FROM userapi_registration_tokens WHERE token = $1"
|
||||||
|
|
||||||
|
const updateTokenUsesAllowedAndExpiryTimeSQL = "" +
|
||||||
|
"UPDATE userapi_registration_tokens SET uses_allowed = $2, expiry_time = $3 WHERE token = $1"
|
||||||
|
|
||||||
|
const updateTokenUsesAllowedSQL = "" +
|
||||||
|
"UPDATE userapi_registration_tokens SET uses_allowed = $2 WHERE token = $1"
|
||||||
|
|
||||||
|
const updateTokenExpiryTimeSQL = "" +
|
||||||
|
"UPDATE userapi_registration_tokens SET expiry_time = $2 WHERE token = $1"
|
||||||
|
|
||||||
|
type registrationTokenStatements struct {
|
||||||
|
selectTokenStatement *sql.Stmt
|
||||||
|
insertTokenStatement *sql.Stmt
|
||||||
|
listAllTokensStatement *sql.Stmt
|
||||||
|
listValidTokensStatement *sql.Stmt
|
||||||
|
listInvalidTokenStatement *sql.Stmt
|
||||||
|
getTokenStatement *sql.Stmt
|
||||||
|
deleteTokenStatement *sql.Stmt
|
||||||
|
updateTokenUsesAllowedAndExpiryTimeStatement *sql.Stmt
|
||||||
|
updateTokenUsesAllowedStatement *sql.Stmt
|
||||||
|
updateTokenExpiryTimeStatement *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) {
|
||||||
|
s := ®istrationTokenStatements{}
|
||||||
|
_, err := db.Exec(registrationTokensSchema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, sqlutil.StatementList{
|
||||||
|
{&s.selectTokenStatement, selectTokenSQL},
|
||||||
|
{&s.insertTokenStatement, insertTokenSQL},
|
||||||
|
{&s.listAllTokensStatement, listAllTokensSQL},
|
||||||
|
{&s.listValidTokensStatement, listValidTokensSQL},
|
||||||
|
{&s.listInvalidTokenStatement, listInvalidTokensSQL},
|
||||||
|
{&s.getTokenStatement, getTokenSQL},
|
||||||
|
{&s.deleteTokenStatement, deleteTokenSQL},
|
||||||
|
{&s.updateTokenUsesAllowedAndExpiryTimeStatement, updateTokenUsesAllowedAndExpiryTimeSQL},
|
||||||
|
{&s.updateTokenUsesAllowedStatement, updateTokenUsesAllowedSQL},
|
||||||
|
{&s.updateTokenExpiryTimeStatement, updateTokenExpiryTimeSQL},
|
||||||
|
}.Prepare(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) {
|
||||||
|
var existingToken string
|
||||||
|
stmt := sqlutil.TxStmt(tx, s.selectTokenStatement)
|
||||||
|
err := stmt.QueryRowContext(ctx, token).Scan(&existingToken)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, registrationToken *api.RegistrationToken) (bool, error) {
|
||||||
|
stmt := sqlutil.TxStmt(tx, s.insertTokenStatement)
|
||||||
|
_, err := stmt.ExecContext(
|
||||||
|
ctx,
|
||||||
|
*registrationToken.Token,
|
||||||
|
getInsertValue(registrationToken.UsesAllowed),
|
||||||
|
getInsertValue(registrationToken.ExpiryTime),
|
||||||
|
*registrationToken.Pending,
|
||||||
|
*registrationToken.Completed)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getInsertValue[t constraints.Integer](in *t) any {
|
||||||
|
if in == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return *in
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) {
|
||||||
|
var stmt *sql.Stmt
|
||||||
|
var tokens []api.RegistrationToken
|
||||||
|
var tokenString string
|
||||||
|
var pending, completed, usesAllowed *int32
|
||||||
|
var expiryTime *int64
|
||||||
|
var rows *sql.Rows
|
||||||
|
var err error
|
||||||
|
if returnAll {
|
||||||
|
stmt = sqlutil.TxStmt(tx, s.listAllTokensStatement)
|
||||||
|
rows, err = stmt.QueryContext(ctx)
|
||||||
|
} else if valid {
|
||||||
|
stmt = sqlutil.TxStmt(tx, s.listValidTokensStatement)
|
||||||
|
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
|
||||||
|
} else {
|
||||||
|
stmt = sqlutil.TxStmt(tx, s.listInvalidTokenStatement)
|
||||||
|
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return tokens, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime)
|
||||||
|
if err != nil {
|
||||||
|
return tokens, err
|
||||||
|
}
|
||||||
|
tokenString := tokenString
|
||||||
|
pending := pending
|
||||||
|
completed := completed
|
||||||
|
usesAllowed := usesAllowed
|
||||||
|
expiryTime := expiryTime
|
||||||
|
|
||||||
|
tokenMap := api.RegistrationToken{
|
||||||
|
Token: &tokenString,
|
||||||
|
Pending: pending,
|
||||||
|
Completed: completed,
|
||||||
|
UsesAllowed: usesAllowed,
|
||||||
|
ExpiryTime: expiryTime,
|
||||||
|
}
|
||||||
|
tokens = append(tokens, tokenMap)
|
||||||
|
}
|
||||||
|
return tokens, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) {
|
||||||
|
stmt := sqlutil.TxStmt(tx, s.getTokenStatement)
|
||||||
|
var pending, completed, usesAllowed *int32
|
||||||
|
var expiryTime *int64
|
||||||
|
err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
token := api.RegistrationToken{
|
||||||
|
Token: &tokenString,
|
||||||
|
Pending: pending,
|
||||||
|
Completed: completed,
|
||||||
|
UsesAllowed: usesAllowed,
|
||||||
|
ExpiryTime: expiryTime,
|
||||||
|
}
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error {
|
||||||
|
stmt := sqlutil.TxStmt(tx, s.deleteTokenStatement)
|
||||||
|
_, err := stmt.ExecContext(ctx, tokenString)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *registrationTokenStatements) UpdateRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*api.RegistrationToken, error) {
|
||||||
|
var stmt *sql.Stmt
|
||||||
|
usesAllowed, usesAllowedPresent := newAttributes["usesAllowed"]
|
||||||
|
expiryTime, expiryTimePresent := newAttributes["expiryTime"]
|
||||||
|
if usesAllowedPresent && expiryTimePresent {
|
||||||
|
stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedAndExpiryTimeStatement)
|
||||||
|
_, err := stmt.ExecContext(ctx, tokenString, usesAllowed, expiryTime)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else if usesAllowedPresent {
|
||||||
|
stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedStatement)
|
||||||
|
_, err := stmt.ExecContext(ctx, tokenString, usesAllowed)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else if expiryTimePresent {
|
||||||
|
stmt = sqlutil.TxStmt(tx, s.updateTokenExpiryTimeStatement)
|
||||||
|
_, err := stmt.ExecContext(ctx, tokenString, expiryTime)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.GetRegistrationToken(ctx, tx, tokenString)
|
||||||
|
}
|
|
@ -53,6 +53,10 @@ func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
registationTokensTable, err := NewPostgresRegistrationTokensTable(db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("NewPostgresRegistrationsTokenTable: %w", err)
|
||||||
|
}
|
||||||
accountsTable, err := NewPostgresAccountsTable(db, serverName)
|
accountsTable, err := NewPostgresAccountsTable(db, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err)
|
return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err)
|
||||||
|
@ -125,6 +129,7 @@ func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *
|
||||||
ThreePIDs: threePIDTable,
|
ThreePIDs: threePIDTable,
|
||||||
Pushers: pusherTable,
|
Pushers: pusherTable,
|
||||||
Notifications: notificationsTable,
|
Notifications: notificationsTable,
|
||||||
|
RegistrationTokens: registationTokensTable,
|
||||||
Stats: statsTable,
|
Stats: statsTable,
|
||||||
ServerName: serverName,
|
ServerName: serverName,
|
||||||
DB: db,
|
DB: db,
|
||||||
|
|
|
@ -31,6 +31,7 @@ import (
|
||||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
|
clientapi "github.com/matrix-org/dendrite/clientapi/api"
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/internal/pushrules"
|
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
@ -43,6 +44,7 @@ import (
|
||||||
type Database struct {
|
type Database struct {
|
||||||
DB *sql.DB
|
DB *sql.DB
|
||||||
Writer sqlutil.Writer
|
Writer sqlutil.Writer
|
||||||
|
RegistrationTokens tables.RegistrationTokensTable
|
||||||
Accounts tables.AccountsTable
|
Accounts tables.AccountsTable
|
||||||
Profiles tables.ProfileTable
|
Profiles tables.ProfileTable
|
||||||
AccountDatas tables.AccountDataTable
|
AccountDatas tables.AccountDataTable
|
||||||
|
@ -78,6 +80,42 @@ const (
|
||||||
loginTokenByteLength = 32
|
loginTokenByteLength = 32
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (d *Database) RegistrationTokenExists(ctx context.Context, token string) (bool, error) {
|
||||||
|
return d.RegistrationTokens.RegistrationTokenExists(ctx, nil, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) InsertRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (created bool, err error) {
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
created, err = d.RegistrationTokens.InsertRegistrationToken(ctx, txn, registrationToken)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) ListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) {
|
||||||
|
return d.RegistrationTokens.ListRegistrationTokens(ctx, nil, returnAll, valid)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) {
|
||||||
|
return d.RegistrationTokens.GetRegistrationToken(ctx, nil, tokenString)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) DeleteRegistrationToken(ctx context.Context, tokenString string) (err error) {
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
err = d.RegistrationTokens.DeleteRegistrationToken(ctx, nil, tokenString)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (updatedToken *clientapi.RegistrationToken, err error) {
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
updatedToken, err = d.RegistrationTokens.UpdateRegistrationToken(ctx, txn, tokenString, newAttributes)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// GetAccountByPassword returns the account associated with the given localpart and password.
|
// GetAccountByPassword returns the account associated with the given localpart and password.
|
||||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||||
func (d *Database) GetAccountByPassword(
|
func (d *Database) GetAccountByPassword(
|
||||||
|
|
222
userapi/storage/sqlite3/registration_tokens_table.go
Normal file
222
userapi/storage/sqlite3/registration_tokens_table.go
Normal file
|
@ -0,0 +1,222 @@
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/api"
|
||||||
|
internal "github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
"golang.org/x/exp/constraints"
|
||||||
|
)
|
||||||
|
|
||||||
|
const registrationTokensSchema = `
|
||||||
|
CREATE TABLE IF NOT EXISTS userapi_registration_tokens (
|
||||||
|
token TEXT PRIMARY KEY,
|
||||||
|
pending BIGINT,
|
||||||
|
completed BIGINT,
|
||||||
|
uses_allowed BIGINT,
|
||||||
|
expiry_time BIGINT
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
const selectTokenSQL = "" +
|
||||||
|
"SELECT token FROM userapi_registration_tokens WHERE token = $1"
|
||||||
|
|
||||||
|
const insertTokenSQL = "" +
|
||||||
|
"INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)"
|
||||||
|
|
||||||
|
const listAllTokensSQL = "" +
|
||||||
|
"SELECT * FROM userapi_registration_tokens"
|
||||||
|
|
||||||
|
const listValidTokensSQL = "" +
|
||||||
|
"SELECT * FROM userapi_registration_tokens WHERE" +
|
||||||
|
"(uses_allowed > pending + completed OR uses_allowed IS NULL) AND" +
|
||||||
|
"(expiry_time > $1 OR expiry_time IS NULL)"
|
||||||
|
|
||||||
|
const listInvalidTokensSQL = "" +
|
||||||
|
"SELECT * FROM userapi_registration_tokens WHERE" +
|
||||||
|
"(uses_allowed <= pending + completed OR expiry_time <= $1)"
|
||||||
|
|
||||||
|
const getTokenSQL = "" +
|
||||||
|
"SELECT pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens WHERE token = $1"
|
||||||
|
|
||||||
|
const deleteTokenSQL = "" +
|
||||||
|
"DELETE FROM userapi_registration_tokens WHERE token = $1"
|
||||||
|
|
||||||
|
const updateTokenUsesAllowedAndExpiryTimeSQL = "" +
|
||||||
|
"UPDATE userapi_registration_tokens SET uses_allowed = $2, expiry_time = $3 WHERE token = $1"
|
||||||
|
|
||||||
|
const updateTokenUsesAllowedSQL = "" +
|
||||||
|
"UPDATE userapi_registration_tokens SET uses_allowed = $2 WHERE token = $1"
|
||||||
|
|
||||||
|
const updateTokenExpiryTimeSQL = "" +
|
||||||
|
"UPDATE userapi_registration_tokens SET expiry_time = $2 WHERE token = $1"
|
||||||
|
|
||||||
|
type registrationTokenStatements struct {
|
||||||
|
selectTokenStatement *sql.Stmt
|
||||||
|
insertTokenStatement *sql.Stmt
|
||||||
|
listAllTokensStatement *sql.Stmt
|
||||||
|
listValidTokensStatement *sql.Stmt
|
||||||
|
listInvalidTokenStatement *sql.Stmt
|
||||||
|
getTokenStatement *sql.Stmt
|
||||||
|
deleteTokenStatement *sql.Stmt
|
||||||
|
updateTokenUsesAllowedAndExpiryTimeStatement *sql.Stmt
|
||||||
|
updateTokenUsesAllowedStatement *sql.Stmt
|
||||||
|
updateTokenExpiryTimeStatement *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSQLiteRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) {
|
||||||
|
s := ®istrationTokenStatements{}
|
||||||
|
_, err := db.Exec(registrationTokensSchema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, sqlutil.StatementList{
|
||||||
|
{&s.selectTokenStatement, selectTokenSQL},
|
||||||
|
{&s.insertTokenStatement, insertTokenSQL},
|
||||||
|
{&s.listAllTokensStatement, listAllTokensSQL},
|
||||||
|
{&s.listValidTokensStatement, listValidTokensSQL},
|
||||||
|
{&s.listInvalidTokenStatement, listInvalidTokensSQL},
|
||||||
|
{&s.getTokenStatement, getTokenSQL},
|
||||||
|
{&s.deleteTokenStatement, deleteTokenSQL},
|
||||||
|
{&s.updateTokenUsesAllowedAndExpiryTimeStatement, updateTokenUsesAllowedAndExpiryTimeSQL},
|
||||||
|
{&s.updateTokenUsesAllowedStatement, updateTokenUsesAllowedSQL},
|
||||||
|
{&s.updateTokenExpiryTimeStatement, updateTokenExpiryTimeSQL},
|
||||||
|
}.Prepare(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) {
|
||||||
|
var existingToken string
|
||||||
|
stmt := sqlutil.TxStmt(tx, s.selectTokenStatement)
|
||||||
|
err := stmt.QueryRowContext(ctx, token).Scan(&existingToken)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, registrationToken *api.RegistrationToken) (bool, error) {
|
||||||
|
stmt := sqlutil.TxStmt(tx, s.insertTokenStatement)
|
||||||
|
_, err := stmt.ExecContext(
|
||||||
|
ctx,
|
||||||
|
*registrationToken.Token,
|
||||||
|
getInsertValue(registrationToken.UsesAllowed),
|
||||||
|
getInsertValue(registrationToken.ExpiryTime),
|
||||||
|
*registrationToken.Pending,
|
||||||
|
*registrationToken.Completed)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getInsertValue[t constraints.Integer](in *t) any {
|
||||||
|
if in == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return *in
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) {
|
||||||
|
var stmt *sql.Stmt
|
||||||
|
var tokens []api.RegistrationToken
|
||||||
|
var tokenString string
|
||||||
|
var pending, completed, usesAllowed *int32
|
||||||
|
var expiryTime *int64
|
||||||
|
var rows *sql.Rows
|
||||||
|
var err error
|
||||||
|
if returnAll {
|
||||||
|
stmt = sqlutil.TxStmt(tx, s.listAllTokensStatement)
|
||||||
|
rows, err = stmt.QueryContext(ctx)
|
||||||
|
} else if valid {
|
||||||
|
stmt = sqlutil.TxStmt(tx, s.listValidTokensStatement)
|
||||||
|
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
|
||||||
|
} else {
|
||||||
|
stmt = sqlutil.TxStmt(tx, s.listInvalidTokenStatement)
|
||||||
|
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return tokens, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime)
|
||||||
|
if err != nil {
|
||||||
|
return tokens, err
|
||||||
|
}
|
||||||
|
tokenString := tokenString
|
||||||
|
pending := pending
|
||||||
|
completed := completed
|
||||||
|
usesAllowed := usesAllowed
|
||||||
|
expiryTime := expiryTime
|
||||||
|
|
||||||
|
tokenMap := api.RegistrationToken{
|
||||||
|
Token: &tokenString,
|
||||||
|
Pending: pending,
|
||||||
|
Completed: completed,
|
||||||
|
UsesAllowed: usesAllowed,
|
||||||
|
ExpiryTime: expiryTime,
|
||||||
|
}
|
||||||
|
tokens = append(tokens, tokenMap)
|
||||||
|
}
|
||||||
|
return tokens, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) {
|
||||||
|
stmt := sqlutil.TxStmt(tx, s.getTokenStatement)
|
||||||
|
var pending, completed, usesAllowed *int32
|
||||||
|
var expiryTime *int64
|
||||||
|
err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
token := api.RegistrationToken{
|
||||||
|
Token: &tokenString,
|
||||||
|
Pending: pending,
|
||||||
|
Completed: completed,
|
||||||
|
UsesAllowed: usesAllowed,
|
||||||
|
ExpiryTime: expiryTime,
|
||||||
|
}
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error {
|
||||||
|
stmt := sqlutil.TxStmt(tx, s.deleteTokenStatement)
|
||||||
|
_, err := stmt.ExecContext(ctx, tokenString)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *registrationTokenStatements) UpdateRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*api.RegistrationToken, error) {
|
||||||
|
var stmt *sql.Stmt
|
||||||
|
usesAllowed, usesAllowedPresent := newAttributes["usesAllowed"]
|
||||||
|
expiryTime, expiryTimePresent := newAttributes["expiryTime"]
|
||||||
|
if usesAllowedPresent && expiryTimePresent {
|
||||||
|
stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedAndExpiryTimeStatement)
|
||||||
|
_, err := stmt.ExecContext(ctx, tokenString, usesAllowed, expiryTime)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else if usesAllowedPresent {
|
||||||
|
stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedStatement)
|
||||||
|
_, err := stmt.ExecContext(ctx, tokenString, usesAllowed)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else if expiryTimePresent {
|
||||||
|
stmt = sqlutil.TxStmt(tx, s.updateTokenExpiryTimeStatement)
|
||||||
|
_, err := stmt.ExecContext(ctx, tokenString, expiryTime)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.GetRegistrationToken(ctx, tx, tokenString)
|
||||||
|
}
|
|
@ -50,7 +50,10 @@ func NewUserDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperti
|
||||||
if err = m.Up(ctx); err != nil {
|
if err = m.Up(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
registationTokensTable, err := NewSQLiteRegistrationTokensTable(db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("NewSQLiteRegistrationsTokenTable: %w", err)
|
||||||
|
}
|
||||||
accountsTable, err := NewSQLiteAccountsTable(db, serverName)
|
accountsTable, err := NewSQLiteAccountsTable(db, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err)
|
return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err)
|
||||||
|
@ -130,6 +133,7 @@ func NewUserDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperti
|
||||||
LoginTokenLifetime: loginTokenLifetime,
|
LoginTokenLifetime: loginTokenLifetime,
|
||||||
BcryptCost: bcryptCost,
|
BcryptCost: bcryptCost,
|
||||||
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||||
|
RegistrationTokens: registationTokensTable,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -25,10 +25,20 @@ import (
|
||||||
"github.com/matrix-org/gomatrixserverlib/fclient"
|
"github.com/matrix-org/gomatrixserverlib/fclient"
|
||||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
|
|
||||||
|
clientapi "github.com/matrix-org/dendrite/clientapi/api"
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/userapi/types"
|
"github.com/matrix-org/dendrite/userapi/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type RegistrationTokensTable interface {
|
||||||
|
RegistrationTokenExists(ctx context.Context, txn *sql.Tx, token string) (bool, error)
|
||||||
|
InsertRegistrationToken(ctx context.Context, txn *sql.Tx, registrationToken *clientapi.RegistrationToken) (bool, error)
|
||||||
|
ListRegistrationTokens(ctx context.Context, txn *sql.Tx, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error)
|
||||||
|
GetRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) (*clientapi.RegistrationToken, error)
|
||||||
|
DeleteRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) error
|
||||||
|
UpdateRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error)
|
||||||
|
}
|
||||||
|
|
||||||
type AccountDataTable interface {
|
type AccountDataTable interface {
|
||||||
InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID, dataType string, content json.RawMessage) error
|
InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID, dataType string, content json.RawMessage) error
|
||||||
SelectAccountData(ctx context.Context, localpart string, serverName spec.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)
|
SelectAccountData(ctx context.Context, localpart string, serverName spec.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)
|
||||||
|
|
Loading…
Reference in a new issue