mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-29 01:33:10 -06:00
Implement threepid sessions methods for UserInternalAPI
This commit is contained in:
parent
1f537ba1b9
commit
7a98afd072
15
internal/random.go
Normal file
15
internal/random.go
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
)
|
||||
|
||||
func GenerateBlob(blobLen int) (string, error) {
|
||||
b := make([]byte, blobLen)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// url-safe no padding
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
|
@ -555,6 +555,22 @@ func (u *testUserAPI) QueryOpenIDToken(ctx context.Context, req *userapi.QueryOp
|
|||
return nil
|
||||
}
|
||||
|
||||
func (u *testUserAPI) CreateSession(context.Context, *userapi.CreateSessionRequest, *userapi.CreateSessionResponse) error {
|
||||
return nil
|
||||
}
|
||||
func (u *testUserAPI) ValidateSession(context.Context, *userapi.ValidateSessionRequest, struct{}) error {
|
||||
return nil
|
||||
}
|
||||
func (u *testUserAPI) GetThreePidForSession(context.Context, *userapi.SessionOwnership, *userapi.GetThreePidForSessionResponse) error {
|
||||
return nil
|
||||
}
|
||||
func (u *testUserAPI) DeleteSession(context.Context, *userapi.SessionOwnership, struct{}) error {
|
||||
return nil
|
||||
}
|
||||
func (u *testUserAPI) IsSessionValidated(context.Context, *userapi.SessionOwnership, *userapi.IsSessionValidatedResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type testRoomserverAPI struct {
|
||||
// use a trace API as it implements method stubs so we don't need to have them here.
|
||||
// We'll override the functions we care about.
|
||||
|
|
|
|||
|
|
@ -398,6 +398,22 @@ func (u *testUserAPI) QueryOpenIDToken(ctx context.Context, req *userapi.QueryOp
|
|||
return nil
|
||||
}
|
||||
|
||||
func (u *testUserAPI) CreateSession(context.Context, *userapi.CreateSessionRequest, *userapi.CreateSessionResponse) error {
|
||||
return nil
|
||||
}
|
||||
func (u *testUserAPI) ValidateSession(context.Context, *userapi.ValidateSessionRequest, struct{}) error {
|
||||
return nil
|
||||
}
|
||||
func (u *testUserAPI) GetThreePidForSession(context.Context, *userapi.SessionOwnership, *userapi.GetThreePidForSessionResponse) error {
|
||||
return nil
|
||||
}
|
||||
func (u *testUserAPI) DeleteSession(context.Context, *userapi.SessionOwnership, struct{}) error {
|
||||
return nil
|
||||
}
|
||||
func (u *testUserAPI) IsSessionValidated(context.Context, *userapi.SessionOwnership, *userapi.IsSessionValidatedResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type testRoomserverAPI struct {
|
||||
// use a trace API as it implements method stubs so we don't need to have them here.
|
||||
// We'll override the functions we care about.
|
||||
|
|
|
|||
|
|
@ -27,8 +27,10 @@ import (
|
|||
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/mail"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/devices"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/threepid"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
|
@ -37,10 +39,12 @@ import (
|
|||
type UserInternalAPI struct {
|
||||
AccountDB accounts.Database
|
||||
DeviceDB devices.Database
|
||||
ThreePidDB threepid.Database
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
// AppServices is the list of all registered AS
|
||||
AppServices []config.ApplicationService
|
||||
KeyAPI keyapi.KeyInternalAPI
|
||||
Mail mail.Mailer
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
|
||||
|
|
|
|||
121
userapi/internal/threepid.go
Normal file
121
userapi/internal/threepid.go
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/mail"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/threepid"
|
||||
)
|
||||
|
||||
const (
|
||||
sessionIdByteLength = 32
|
||||
tokenByteLength = 48
|
||||
)
|
||||
|
||||
var ErrBadSession = errors.New("provided sid, client_secret and token does not point to valid session")
|
||||
|
||||
func (a *UserInternalAPI) CreateSession(ctx context.Context, req *api.CreateSessionRequest, res *api.CreateSessionResponse) error {
|
||||
s, err := a.ThreePidDB.GetSessionByThreePidAndSecret(ctx, req.ThreePid, req.ClientSecret)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
token, err := internal.GenerateBlob(tokenByteLength)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sid, err := internal.GenerateBlob(sessionIdByteLength)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s = &api.Session{
|
||||
Sid: sid,
|
||||
ClientSecret: req.ClientSecret,
|
||||
ThreePid: req.ThreePid,
|
||||
SendAttempt: req.SendAttempt,
|
||||
Token: token,
|
||||
NextLink: req.NextLink,
|
||||
}
|
||||
err = a.ThreePidDB.InsertSession(ctx, s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if req.SendAttempt > s.SendAttempt {
|
||||
err = a.ThreePidDB.UpdateSendAttemptNextLink(ctx, s.Sid, req.NextLink)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
res.Sid = s.Sid
|
||||
return nil
|
||||
}
|
||||
}
|
||||
res.Sid = s.Sid
|
||||
// TODO - if we fail sending email, send_attempt for next requests must be bumped,
|
||||
// otherwise we will just return nil from this function and not sent email
|
||||
return a.Mail.Send(&mail.Mail{
|
||||
To: s.ThreePid,
|
||||
Link: s.NextLink,
|
||||
Token: s.Token,
|
||||
Extra: req.Extra,
|
||||
}, req.SessionType)
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) ValidateSession(ctx context.Context, req *api.ValidateSessionRequest, res struct{}) error {
|
||||
s, err := getSessionByOwnership(ctx, &req.SessionOwnership, a.ThreePidDB)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.Token != req.Token {
|
||||
return ErrBadSession
|
||||
}
|
||||
return a.ThreePidDB.ValidateSession(ctx, s.Sid, int(time.Now().Unix()))
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) GetThreePidForSession(ctx context.Context, req *api.SessionOwnership, res *api.GetThreePidForSessionResponse) error {
|
||||
s, err := getSessionByOwnership(ctx, req, a.ThreePidDB)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.ThreePid = s.ThreePid
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) DeleteSession(ctx context.Context, req *api.SessionOwnership, res struct{}) error {
|
||||
s, err := getSessionByOwnership(ctx, req, a.ThreePidDB)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return a.ThreePidDB.DeleteSession(ctx, s.Sid)
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) IsSessionValidated(ctx context.Context, req *api.SessionOwnership, res *api.IsSessionValidatedResponse) error {
|
||||
s, err := getSessionByOwnership(ctx, req, a.ThreePidDB)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.Validated = s.Validated
|
||||
res.ValidatedAt = s.ValidatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func getSessionByOwnership(ctx context.Context, so *api.SessionOwnership, d threepid.Database) (*api.Session, error) {
|
||||
s, err := d.GetSession(ctx, so.Sid)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, ErrBadSession
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if s.ClientSecret != so.ClientSecret {
|
||||
return nil, ErrBadSession
|
||||
}
|
||||
return s, err
|
||||
}
|
||||
|
|
@ -44,6 +44,12 @@ const (
|
|||
QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
|
||||
QuerySearchProfilesPath = "/userapi/querySearchProfiles"
|
||||
QueryOpenIDTokenPath = "/userapi/queryOpenIDToken"
|
||||
|
||||
CreateSessionPath = "/userapi/createSession"
|
||||
ValidateSessionPath = "/userapi/validateSession"
|
||||
GetThreePidForSessionPath = "/userapi/getThreePidForSession"
|
||||
DeleteSessionPath = "/userapi/deleteSession"
|
||||
IsSessionValidatedPath = "/userapi/isSessionValidated"
|
||||
)
|
||||
|
||||
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
||||
|
|
@ -225,3 +231,43 @@ func (h *httpUserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.Que
|
|||
apiURL := h.apiURL + QueryOpenIDTokenPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
||||
func (h *httpUserInternalAPI) CreateSession(ctx context.Context, req *api.CreateSessionRequest, res *api.CreateSessionResponse) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "CreateSession")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.apiURL + CreateSessionPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
||||
func (h *httpUserInternalAPI) ValidateSession(ctx context.Context, req *api.ValidateSessionRequest, res struct{}) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "ValidateSession")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.apiURL + ValidateSessionPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
||||
func (h *httpUserInternalAPI) GetThreePidForSession(ctx context.Context, req *api.SessionOwnership, res *api.GetThreePidForSessionResponse) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "GetThreePidForSession")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.apiURL + GetThreePidForSessionPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
||||
func (h *httpUserInternalAPI) DeleteSession(ctx context.Context, req *api.SessionOwnership, res struct{}) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "DeleteSession")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.apiURL + DeleteSessionPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
||||
func (h *httpUserInternalAPI) IsSessionValidated(ctx context.Context, req *api.SessionOwnership, res *api.IsSessionValidatedResponse) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "IsSessionValidated")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.apiURL + IsSessionValidatedPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,8 +21,10 @@ import (
|
|||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/internal"
|
||||
"github.com/matrix-org/dendrite/userapi/inthttp"
|
||||
"github.com/matrix-org/dendrite/userapi/mail"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/devices"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/threepid"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
|
@ -36,17 +38,26 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) {
|
|||
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
||||
func NewInternalAPI(
|
||||
accountDB accounts.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI,
|
||||
) api.UserInternalAPI {
|
||||
) *internal.UserInternalAPI {
|
||||
deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panicf("failed to connect to device db")
|
||||
}
|
||||
|
||||
threepidDb, err := threepid.NewDatabase(&cfg.ThreepidDatabase)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panicf("failed to connect to threepid db")
|
||||
}
|
||||
mailer, err := mail.NewMailer(cfg)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panicf("failed to crate Mailer")
|
||||
}
|
||||
return &internal.UserInternalAPI{
|
||||
AccountDB: accountDB,
|
||||
DeviceDB: deviceDB,
|
||||
ThreePidDB: threepidDb,
|
||||
ServerName: cfg.Matrix.ServerName,
|
||||
AppServices: appServices,
|
||||
KeyAPI: keyAPI,
|
||||
Mail: mailer,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,9 +13,12 @@ import (
|
|||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/internal"
|
||||
"github.com/matrix-org/dendrite/userapi/inthttp"
|
||||
"github.com/matrix-org/dendrite/userapi/mail"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matryer/is"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
|
|
@ -23,13 +26,37 @@ const (
|
|||
serverName = gomatrixserverlib.ServerName("example.com")
|
||||
)
|
||||
|
||||
func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database) {
|
||||
var (
|
||||
testReq = &api.CreateSessionRequest{
|
||||
ClientSecret: "foobar",
|
||||
NextLink: "http://foobar.com",
|
||||
ThreePid: "foo@bar.com",
|
||||
Extra: []string{},
|
||||
SendAttempt: 0,
|
||||
}
|
||||
ctx = context.Background()
|
||||
mailer = &testMailer{
|
||||
c: map[api.ThreepidSessionType]chan *mail.Mail{
|
||||
api.Password: make(chan *mail.Mail, 3),
|
||||
api.Verification: make(chan *mail.Mail, 3),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
type testMailer struct {
|
||||
c map[api.ThreepidSessionType]chan *mail.Mail
|
||||
}
|
||||
|
||||
func (tm *testMailer) Send(s *mail.Mail, t api.ThreepidSessionType) error {
|
||||
tm.c[t] <- s
|
||||
return nil
|
||||
}
|
||||
|
||||
func mustMakeInternalAPI(is *is.I) (*internal.UserInternalAPI, accounts.Database) {
|
||||
accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{
|
||||
ConnectionString: "file::memory:",
|
||||
}, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create account DB: %s", err)
|
||||
}
|
||||
is.NoErr(err)
|
||||
cfg := &config.UserAPI{
|
||||
DeviceDatabase: config.DatabaseOptions{
|
||||
ConnectionString: "file::memory:",
|
||||
|
|
@ -39,15 +66,24 @@ func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database)
|
|||
Matrix: &config.Global{
|
||||
ServerName: serverName,
|
||||
},
|
||||
ThreepidDatabase: config.DatabaseOptions{
|
||||
ConnectionString: "file::memory:",
|
||||
MaxOpenConnections: 1,
|
||||
MaxIdleConnections: 1,
|
||||
},
|
||||
Email: config.EmailConf{
|
||||
TemplatesPath: "../res/default",
|
||||
},
|
||||
}
|
||||
|
||||
return userapi.NewInternalAPI(accountDB, cfg, nil, nil), accountDB
|
||||
}
|
||||
|
||||
func TestQueryProfile(t *testing.T) {
|
||||
is := is.New(t)
|
||||
aliceAvatarURL := "mxc://example.com/alice"
|
||||
aliceDisplayName := "Alice"
|
||||
userAPI, accountDB := MustMakeInternalAPI(t)
|
||||
userAPI, accountDB := mustMakeInternalAPI(is)
|
||||
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make account: %s", err)
|
||||
|
|
@ -119,3 +155,101 @@ func TestQueryProfile(t *testing.T) {
|
|||
runCases(userAPI)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateSession(t *testing.T) {
|
||||
is := is.New(t)
|
||||
internalApi, _ := mustMakeInternalAPI(is)
|
||||
mustCreateSession(is, internalApi)
|
||||
}
|
||||
|
||||
func TestCreateSession_Twice(t *testing.T) {
|
||||
is := is.New(t)
|
||||
internalApi, _ := mustMakeInternalAPI(is)
|
||||
mustCreateSession(is, internalApi)
|
||||
resp := api.CreateSessionResponse{}
|
||||
err := internalApi.CreateSession(ctx, testReq, &resp)
|
||||
is.NoErr(err)
|
||||
is.Equal(len(resp.Sid), 43)
|
||||
select {
|
||||
case <-mailer.c[api.Verification]:
|
||||
t.Fatal("email was received, but sent attempt was not increased")
|
||||
default:
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSession_Twice_IncreaseSendAttempt(t *testing.T) {
|
||||
is := is.New(t)
|
||||
internalApi, _ := mustMakeInternalAPI(is)
|
||||
mustCreateSession(is, internalApi)
|
||||
resp := api.CreateSessionResponse{}
|
||||
testReqBumped := *testReq
|
||||
testReqBumped.SendAttempt = 1
|
||||
err := internalApi.CreateSession(ctx, &testReqBumped, &resp)
|
||||
is.NoErr(err)
|
||||
is.Equal(len(resp.Sid), 43)
|
||||
sub := <-mailer.c[api.Verification]
|
||||
is.Equal(len(sub.Token), 64)
|
||||
is.Equal(sub.To, testReq.ThreePid)
|
||||
}
|
||||
|
||||
func TestValidateSession(t *testing.T) {
|
||||
is := is.New(t)
|
||||
internalApi, _ := mustMakeInternalAPI(is)
|
||||
s, token := mustCreateSession(is, internalApi)
|
||||
mustValidateSesson(is, internalApi, testReq.ClientSecret, token, s.Sid)
|
||||
}
|
||||
|
||||
func TestIsSessionValidated_InvalidatedSession(t *testing.T) {
|
||||
is := is.New(t)
|
||||
internalApi, _ := mustMakeInternalAPI(is)
|
||||
s, _ := mustCreateSession(is, internalApi)
|
||||
resp := api.IsSessionValidatedResponse{}
|
||||
err := internalApi.IsSessionValidated(ctx, &api.SessionOwnership{
|
||||
Sid: s.Sid,
|
||||
ClientSecret: testReq.ClientSecret,
|
||||
}, &resp)
|
||||
is.NoErr(err)
|
||||
is.Equal(resp.Validated, false)
|
||||
}
|
||||
|
||||
func TestIsSessionValidated_ValidatedSession(t *testing.T) {
|
||||
is := is.New(t)
|
||||
internalApi, _ := mustMakeInternalAPI(is)
|
||||
s, token := mustCreateSession(is, internalApi)
|
||||
resp := api.IsSessionValidatedResponse{}
|
||||
mustValidateSesson(is, internalApi, testReq.ClientSecret, token, s.Sid)
|
||||
err := internalApi.IsSessionValidated(ctx, &api.SessionOwnership{
|
||||
Sid: s.Sid,
|
||||
ClientSecret: testReq.ClientSecret,
|
||||
}, &resp)
|
||||
is.NoErr(err)
|
||||
is.Equal(resp.Validated, true)
|
||||
is.Equal(resp.ValidatedAt > 0, true)
|
||||
}
|
||||
|
||||
func mustCreateSession(is *is.I, i *internal.UserInternalAPI) (resp *api.CreateSessionResponse, token string) {
|
||||
resp = &api.CreateSessionResponse{}
|
||||
i.Mail = mailer
|
||||
err := i.CreateSession(ctx, testReq, resp)
|
||||
is.NoErr(err)
|
||||
is.Equal(len(resp.Sid), 43)
|
||||
sub := <-mailer.c[api.Verification]
|
||||
is.Equal(len(sub.Token), 64)
|
||||
is.Equal(sub.To, testReq.ThreePid)
|
||||
token = sub.Token
|
||||
return
|
||||
}
|
||||
|
||||
func mustValidateSesson(is *is.I, i *internal.UserInternalAPI, secret, token, sid string) {
|
||||
err := i.ValidateSession(ctx, &api.ValidateSessionRequest{
|
||||
SessionOwnership: api.SessionOwnership{
|
||||
Sid: sid,
|
||||
ClientSecret: secret,
|
||||
},
|
||||
Token: token,
|
||||
},
|
||||
struct{}{},
|
||||
)
|
||||
is.NoErr(err)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue