Implement threepid sessions methods for UserInternalAPI

This commit is contained in:
Piotr Kozimor 2021-08-10 10:43:47 +02:00
parent 1f537ba1b9
commit 7a98afd072
8 changed files with 370 additions and 7 deletions

15
internal/random.go Normal file
View file

@ -0,0 +1,15 @@
package internal
import (
"crypto/rand"
"encoding/base64"
)
func GenerateBlob(blobLen int) (string, error) {
b := make([]byte, blobLen)
if _, err := rand.Read(b); err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}

View file

@ -555,6 +555,22 @@ func (u *testUserAPI) QueryOpenIDToken(ctx context.Context, req *userapi.QueryOp
return nil
}
func (u *testUserAPI) CreateSession(context.Context, *userapi.CreateSessionRequest, *userapi.CreateSessionResponse) error {
return nil
}
func (u *testUserAPI) ValidateSession(context.Context, *userapi.ValidateSessionRequest, struct{}) error {
return nil
}
func (u *testUserAPI) GetThreePidForSession(context.Context, *userapi.SessionOwnership, *userapi.GetThreePidForSessionResponse) error {
return nil
}
func (u *testUserAPI) DeleteSession(context.Context, *userapi.SessionOwnership, struct{}) error {
return nil
}
func (u *testUserAPI) IsSessionValidated(context.Context, *userapi.SessionOwnership, *userapi.IsSessionValidatedResponse) error {
return nil
}
type testRoomserverAPI struct {
// use a trace API as it implements method stubs so we don't need to have them here.
// We'll override the functions we care about.

View file

@ -398,6 +398,22 @@ func (u *testUserAPI) QueryOpenIDToken(ctx context.Context, req *userapi.QueryOp
return nil
}
func (u *testUserAPI) CreateSession(context.Context, *userapi.CreateSessionRequest, *userapi.CreateSessionResponse) error {
return nil
}
func (u *testUserAPI) ValidateSession(context.Context, *userapi.ValidateSessionRequest, struct{}) error {
return nil
}
func (u *testUserAPI) GetThreePidForSession(context.Context, *userapi.SessionOwnership, *userapi.GetThreePidForSessionResponse) error {
return nil
}
func (u *testUserAPI) DeleteSession(context.Context, *userapi.SessionOwnership, struct{}) error {
return nil
}
func (u *testUserAPI) IsSessionValidated(context.Context, *userapi.SessionOwnership, *userapi.IsSessionValidatedResponse) error {
return nil
}
type testRoomserverAPI struct {
// use a trace API as it implements method stubs so we don't need to have them here.
// We'll override the functions we care about.

View file

@ -27,8 +27,10 @@ import (
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/mail"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/dendrite/userapi/storage/devices"
"github.com/matrix-org/dendrite/userapi/storage/threepid"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
@ -37,10 +39,12 @@ import (
type UserInternalAPI struct {
AccountDB accounts.Database
DeviceDB devices.Database
ThreePidDB threepid.Database
ServerName gomatrixserverlib.ServerName
// AppServices is the list of all registered AS
AppServices []config.ApplicationService
KeyAPI keyapi.KeyInternalAPI
Mail mail.Mailer
}
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {

View file

@ -0,0 +1,121 @@
package internal
import (
"context"
"database/sql"
"errors"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/mail"
"github.com/matrix-org/dendrite/userapi/storage/threepid"
)
const (
sessionIdByteLength = 32
tokenByteLength = 48
)
var ErrBadSession = errors.New("provided sid, client_secret and token does not point to valid session")
func (a *UserInternalAPI) CreateSession(ctx context.Context, req *api.CreateSessionRequest, res *api.CreateSessionResponse) error {
s, err := a.ThreePidDB.GetSessionByThreePidAndSecret(ctx, req.ThreePid, req.ClientSecret)
if err != nil {
if err == sql.ErrNoRows {
token, err := internal.GenerateBlob(tokenByteLength)
if err != nil {
return err
}
sid, err := internal.GenerateBlob(sessionIdByteLength)
if err != nil {
return err
}
s = &api.Session{
Sid: sid,
ClientSecret: req.ClientSecret,
ThreePid: req.ThreePid,
SendAttempt: req.SendAttempt,
Token: token,
NextLink: req.NextLink,
}
err = a.ThreePidDB.InsertSession(ctx, s)
if err != nil {
return err
}
} else {
return err
}
} else {
if req.SendAttempt > s.SendAttempt {
err = a.ThreePidDB.UpdateSendAttemptNextLink(ctx, s.Sid, req.NextLink)
if err != nil {
return err
}
} else {
res.Sid = s.Sid
return nil
}
}
res.Sid = s.Sid
// TODO - if we fail sending email, send_attempt for next requests must be bumped,
// otherwise we will just return nil from this function and not sent email
return a.Mail.Send(&mail.Mail{
To: s.ThreePid,
Link: s.NextLink,
Token: s.Token,
Extra: req.Extra,
}, req.SessionType)
}
func (a *UserInternalAPI) ValidateSession(ctx context.Context, req *api.ValidateSessionRequest, res struct{}) error {
s, err := getSessionByOwnership(ctx, &req.SessionOwnership, a.ThreePidDB)
if err != nil {
return err
}
if s.Token != req.Token {
return ErrBadSession
}
return a.ThreePidDB.ValidateSession(ctx, s.Sid, int(time.Now().Unix()))
}
func (a *UserInternalAPI) GetThreePidForSession(ctx context.Context, req *api.SessionOwnership, res *api.GetThreePidForSessionResponse) error {
s, err := getSessionByOwnership(ctx, req, a.ThreePidDB)
if err != nil {
return err
}
res.ThreePid = s.ThreePid
return nil
}
func (a *UserInternalAPI) DeleteSession(ctx context.Context, req *api.SessionOwnership, res struct{}) error {
s, err := getSessionByOwnership(ctx, req, a.ThreePidDB)
if err != nil {
return err
}
return a.ThreePidDB.DeleteSession(ctx, s.Sid)
}
func (a *UserInternalAPI) IsSessionValidated(ctx context.Context, req *api.SessionOwnership, res *api.IsSessionValidatedResponse) error {
s, err := getSessionByOwnership(ctx, req, a.ThreePidDB)
if err != nil {
return err
}
res.Validated = s.Validated
res.ValidatedAt = s.ValidatedAt
return nil
}
func getSessionByOwnership(ctx context.Context, so *api.SessionOwnership, d threepid.Database) (*api.Session, error) {
s, err := d.GetSession(ctx, so.Sid)
if err != nil {
if err == sql.ErrNoRows {
return nil, ErrBadSession
}
return nil, err
}
if s.ClientSecret != so.ClientSecret {
return nil, ErrBadSession
}
return s, err
}

View file

@ -44,6 +44,12 @@ const (
QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
QuerySearchProfilesPath = "/userapi/querySearchProfiles"
QueryOpenIDTokenPath = "/userapi/queryOpenIDToken"
CreateSessionPath = "/userapi/createSession"
ValidateSessionPath = "/userapi/validateSession"
GetThreePidForSessionPath = "/userapi/getThreePidForSession"
DeleteSessionPath = "/userapi/deleteSession"
IsSessionValidatedPath = "/userapi/isSessionValidated"
)
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
@ -225,3 +231,43 @@ func (h *httpUserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.Que
apiURL := h.apiURL + QueryOpenIDTokenPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) CreateSession(ctx context.Context, req *api.CreateSessionRequest, res *api.CreateSessionResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "CreateSession")
defer span.Finish()
apiURL := h.apiURL + CreateSessionPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) ValidateSession(ctx context.Context, req *api.ValidateSessionRequest, res struct{}) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "ValidateSession")
defer span.Finish()
apiURL := h.apiURL + ValidateSessionPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) GetThreePidForSession(ctx context.Context, req *api.SessionOwnership, res *api.GetThreePidForSessionResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "GetThreePidForSession")
defer span.Finish()
apiURL := h.apiURL + GetThreePidForSessionPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) DeleteSession(ctx context.Context, req *api.SessionOwnership, res struct{}) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "DeleteSession")
defer span.Finish()
apiURL := h.apiURL + DeleteSessionPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) IsSessionValidated(ctx context.Context, req *api.SessionOwnership, res *api.IsSessionValidatedResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "IsSessionValidated")
defer span.Finish()
apiURL := h.apiURL + IsSessionValidatedPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}

View file

@ -21,8 +21,10 @@ import (
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/internal"
"github.com/matrix-org/dendrite/userapi/inthttp"
"github.com/matrix-org/dendrite/userapi/mail"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/dendrite/userapi/storage/devices"
"github.com/matrix-org/dendrite/userapi/storage/threepid"
"github.com/sirupsen/logrus"
)
@ -36,17 +38,26 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) {
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI(
accountDB accounts.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI,
) api.UserInternalAPI {
) *internal.UserInternalAPI {
deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to device db")
}
threepidDb, err := threepid.NewDatabase(&cfg.ThreepidDatabase)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to threepid db")
}
mailer, err := mail.NewMailer(cfg)
if err != nil {
logrus.WithError(err).Panicf("failed to crate Mailer")
}
return &internal.UserInternalAPI{
AccountDB: accountDB,
DeviceDB: deviceDB,
ThreePidDB: threepidDb,
ServerName: cfg.Matrix.ServerName,
AppServices: appServices,
KeyAPI: keyAPI,
Mail: mailer,
}
}

View file

@ -13,9 +13,12 @@ import (
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/internal"
"github.com/matrix-org/dendrite/userapi/inthttp"
"github.com/matrix-org/dendrite/userapi/mail"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matryer/is"
"golang.org/x/crypto/bcrypt"
)
@ -23,13 +26,37 @@ const (
serverName = gomatrixserverlib.ServerName("example.com")
)
func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database) {
var (
testReq = &api.CreateSessionRequest{
ClientSecret: "foobar",
NextLink: "http://foobar.com",
ThreePid: "foo@bar.com",
Extra: []string{},
SendAttempt: 0,
}
ctx = context.Background()
mailer = &testMailer{
c: map[api.ThreepidSessionType]chan *mail.Mail{
api.Password: make(chan *mail.Mail, 3),
api.Verification: make(chan *mail.Mail, 3),
},
}
)
type testMailer struct {
c map[api.ThreepidSessionType]chan *mail.Mail
}
func (tm *testMailer) Send(s *mail.Mail, t api.ThreepidSessionType) error {
tm.c[t] <- s
return nil
}
func mustMakeInternalAPI(is *is.I) (*internal.UserInternalAPI, accounts.Database) {
accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{
ConnectionString: "file::memory:",
}, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS)
if err != nil {
t.Fatalf("failed to create account DB: %s", err)
}
is.NoErr(err)
cfg := &config.UserAPI{
DeviceDatabase: config.DatabaseOptions{
ConnectionString: "file::memory:",
@ -39,15 +66,24 @@ func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database)
Matrix: &config.Global{
ServerName: serverName,
},
ThreepidDatabase: config.DatabaseOptions{
ConnectionString: "file::memory:",
MaxOpenConnections: 1,
MaxIdleConnections: 1,
},
Email: config.EmailConf{
TemplatesPath: "../res/default",
},
}
return userapi.NewInternalAPI(accountDB, cfg, nil, nil), accountDB
}
func TestQueryProfile(t *testing.T) {
is := is.New(t)
aliceAvatarURL := "mxc://example.com/alice"
aliceDisplayName := "Alice"
userAPI, accountDB := MustMakeInternalAPI(t)
userAPI, accountDB := mustMakeInternalAPI(is)
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "")
if err != nil {
t.Fatalf("failed to make account: %s", err)
@ -119,3 +155,101 @@ func TestQueryProfile(t *testing.T) {
runCases(userAPI)
})
}
func TestCreateSession(t *testing.T) {
is := is.New(t)
internalApi, _ := mustMakeInternalAPI(is)
mustCreateSession(is, internalApi)
}
func TestCreateSession_Twice(t *testing.T) {
is := is.New(t)
internalApi, _ := mustMakeInternalAPI(is)
mustCreateSession(is, internalApi)
resp := api.CreateSessionResponse{}
err := internalApi.CreateSession(ctx, testReq, &resp)
is.NoErr(err)
is.Equal(len(resp.Sid), 43)
select {
case <-mailer.c[api.Verification]:
t.Fatal("email was received, but sent attempt was not increased")
default:
break
}
}
func TestCreateSession_Twice_IncreaseSendAttempt(t *testing.T) {
is := is.New(t)
internalApi, _ := mustMakeInternalAPI(is)
mustCreateSession(is, internalApi)
resp := api.CreateSessionResponse{}
testReqBumped := *testReq
testReqBumped.SendAttempt = 1
err := internalApi.CreateSession(ctx, &testReqBumped, &resp)
is.NoErr(err)
is.Equal(len(resp.Sid), 43)
sub := <-mailer.c[api.Verification]
is.Equal(len(sub.Token), 64)
is.Equal(sub.To, testReq.ThreePid)
}
func TestValidateSession(t *testing.T) {
is := is.New(t)
internalApi, _ := mustMakeInternalAPI(is)
s, token := mustCreateSession(is, internalApi)
mustValidateSesson(is, internalApi, testReq.ClientSecret, token, s.Sid)
}
func TestIsSessionValidated_InvalidatedSession(t *testing.T) {
is := is.New(t)
internalApi, _ := mustMakeInternalAPI(is)
s, _ := mustCreateSession(is, internalApi)
resp := api.IsSessionValidatedResponse{}
err := internalApi.IsSessionValidated(ctx, &api.SessionOwnership{
Sid: s.Sid,
ClientSecret: testReq.ClientSecret,
}, &resp)
is.NoErr(err)
is.Equal(resp.Validated, false)
}
func TestIsSessionValidated_ValidatedSession(t *testing.T) {
is := is.New(t)
internalApi, _ := mustMakeInternalAPI(is)
s, token := mustCreateSession(is, internalApi)
resp := api.IsSessionValidatedResponse{}
mustValidateSesson(is, internalApi, testReq.ClientSecret, token, s.Sid)
err := internalApi.IsSessionValidated(ctx, &api.SessionOwnership{
Sid: s.Sid,
ClientSecret: testReq.ClientSecret,
}, &resp)
is.NoErr(err)
is.Equal(resp.Validated, true)
is.Equal(resp.ValidatedAt > 0, true)
}
func mustCreateSession(is *is.I, i *internal.UserInternalAPI) (resp *api.CreateSessionResponse, token string) {
resp = &api.CreateSessionResponse{}
i.Mail = mailer
err := i.CreateSession(ctx, testReq, resp)
is.NoErr(err)
is.Equal(len(resp.Sid), 43)
sub := <-mailer.c[api.Verification]
is.Equal(len(sub.Token), 64)
is.Equal(sub.To, testReq.ThreePid)
token = sub.Token
return
}
func mustValidateSesson(is *is.I, i *internal.UserInternalAPI, secret, token, sid string) {
err := i.ValidateSession(ctx, &api.ValidateSessionRequest{
SessionOwnership: api.SessionOwnership{
Sid: sid,
ClientSecret: secret,
},
Token: token,
},
struct{}{},
)
is.NoErr(err)
}