diff --git a/internal/random.go b/internal/random.go new file mode 100644 index 000000000..504197aea --- /dev/null +++ b/internal/random.go @@ -0,0 +1,15 @@ +package internal + +import ( + "crypto/rand" + "encoding/base64" +) + +func GenerateBlob(blobLen int) (string, error) { + b := make([]byte, blobLen) + if _, err := rand.Read(b); err != nil { + return "", err + } + // url-safe no padding + return base64.RawURLEncoding.EncodeToString(b), nil +} diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index 79aaebc0b..06600b44c 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -555,6 +555,22 @@ func (u *testUserAPI) QueryOpenIDToken(ctx context.Context, req *userapi.QueryOp return nil } +func (u *testUserAPI) CreateSession(context.Context, *userapi.CreateSessionRequest, *userapi.CreateSessionResponse) error { + return nil +} +func (u *testUserAPI) ValidateSession(context.Context, *userapi.ValidateSessionRequest, struct{}) error { + return nil +} +func (u *testUserAPI) GetThreePidForSession(context.Context, *userapi.SessionOwnership, *userapi.GetThreePidForSessionResponse) error { + return nil +} +func (u *testUserAPI) DeleteSession(context.Context, *userapi.SessionOwnership, struct{}) error { + return nil +} +func (u *testUserAPI) IsSessionValidated(context.Context, *userapi.SessionOwnership, *userapi.IsSessionValidatedResponse) error { + return nil +} + type testRoomserverAPI struct { // use a trace API as it implements method stubs so we don't need to have them here. // We'll override the functions we care about. diff --git a/setup/mscs/msc2946/msc2946_test.go b/setup/mscs/msc2946/msc2946_test.go index 96160c10d..2287d8b2c 100644 --- a/setup/mscs/msc2946/msc2946_test.go +++ b/setup/mscs/msc2946/msc2946_test.go @@ -398,6 +398,22 @@ func (u *testUserAPI) QueryOpenIDToken(ctx context.Context, req *userapi.QueryOp return nil } +func (u *testUserAPI) CreateSession(context.Context, *userapi.CreateSessionRequest, *userapi.CreateSessionResponse) error { + return nil +} +func (u *testUserAPI) ValidateSession(context.Context, *userapi.ValidateSessionRequest, struct{}) error { + return nil +} +func (u *testUserAPI) GetThreePidForSession(context.Context, *userapi.SessionOwnership, *userapi.GetThreePidForSessionResponse) error { + return nil +} +func (u *testUserAPI) DeleteSession(context.Context, *userapi.SessionOwnership, struct{}) error { + return nil +} +func (u *testUserAPI) IsSessionValidated(context.Context, *userapi.SessionOwnership, *userapi.IsSessionValidatedResponse) error { + return nil +} + type testRoomserverAPI struct { // use a trace API as it implements method stubs so we don't need to have them here. // We'll override the functions we care about. diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 21933c1c4..b8547775d 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -27,8 +27,10 @@ import ( keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/mail" "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/dendrite/userapi/storage/devices" + "github.com/matrix-org/dendrite/userapi/storage/threepid" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -37,10 +39,12 @@ import ( type UserInternalAPI struct { AccountDB accounts.Database DeviceDB devices.Database + ThreePidDB threepid.Database ServerName gomatrixserverlib.ServerName // AppServices is the list of all registered AS AppServices []config.ApplicationService KeyAPI keyapi.KeyInternalAPI + Mail mail.Mailer } func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { diff --git a/userapi/internal/threepid.go b/userapi/internal/threepid.go new file mode 100644 index 000000000..5b3ad3a9b --- /dev/null +++ b/userapi/internal/threepid.go @@ -0,0 +1,121 @@ +package internal + +import ( + "context" + "database/sql" + "errors" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/mail" + "github.com/matrix-org/dendrite/userapi/storage/threepid" +) + +const ( + sessionIdByteLength = 32 + tokenByteLength = 48 +) + +var ErrBadSession = errors.New("provided sid, client_secret and token does not point to valid session") + +func (a *UserInternalAPI) CreateSession(ctx context.Context, req *api.CreateSessionRequest, res *api.CreateSessionResponse) error { + s, err := a.ThreePidDB.GetSessionByThreePidAndSecret(ctx, req.ThreePid, req.ClientSecret) + if err != nil { + if err == sql.ErrNoRows { + token, err := internal.GenerateBlob(tokenByteLength) + if err != nil { + return err + } + sid, err := internal.GenerateBlob(sessionIdByteLength) + if err != nil { + return err + } + s = &api.Session{ + Sid: sid, + ClientSecret: req.ClientSecret, + ThreePid: req.ThreePid, + SendAttempt: req.SendAttempt, + Token: token, + NextLink: req.NextLink, + } + err = a.ThreePidDB.InsertSession(ctx, s) + if err != nil { + return err + } + } else { + return err + } + } else { + if req.SendAttempt > s.SendAttempt { + err = a.ThreePidDB.UpdateSendAttemptNextLink(ctx, s.Sid, req.NextLink) + if err != nil { + return err + } + } else { + res.Sid = s.Sid + return nil + } + } + res.Sid = s.Sid + // TODO - if we fail sending email, send_attempt for next requests must be bumped, + // otherwise we will just return nil from this function and not sent email + return a.Mail.Send(&mail.Mail{ + To: s.ThreePid, + Link: s.NextLink, + Token: s.Token, + Extra: req.Extra, + }, req.SessionType) +} + +func (a *UserInternalAPI) ValidateSession(ctx context.Context, req *api.ValidateSessionRequest, res struct{}) error { + s, err := getSessionByOwnership(ctx, &req.SessionOwnership, a.ThreePidDB) + if err != nil { + return err + } + if s.Token != req.Token { + return ErrBadSession + } + return a.ThreePidDB.ValidateSession(ctx, s.Sid, int(time.Now().Unix())) +} + +func (a *UserInternalAPI) GetThreePidForSession(ctx context.Context, req *api.SessionOwnership, res *api.GetThreePidForSessionResponse) error { + s, err := getSessionByOwnership(ctx, req, a.ThreePidDB) + if err != nil { + return err + } + res.ThreePid = s.ThreePid + return nil +} + +func (a *UserInternalAPI) DeleteSession(ctx context.Context, req *api.SessionOwnership, res struct{}) error { + s, err := getSessionByOwnership(ctx, req, a.ThreePidDB) + if err != nil { + return err + } + return a.ThreePidDB.DeleteSession(ctx, s.Sid) +} + +func (a *UserInternalAPI) IsSessionValidated(ctx context.Context, req *api.SessionOwnership, res *api.IsSessionValidatedResponse) error { + s, err := getSessionByOwnership(ctx, req, a.ThreePidDB) + if err != nil { + return err + } + res.Validated = s.Validated + res.ValidatedAt = s.ValidatedAt + return nil +} + +func getSessionByOwnership(ctx context.Context, so *api.SessionOwnership, d threepid.Database) (*api.Session, error) { + s, err := d.GetSession(ctx, so.Sid) + if err != nil { + if err == sql.ErrNoRows { + return nil, ErrBadSession + } + return nil, err + } + if s.ClientSecret != so.ClientSecret { + return nil, ErrBadSession + } + return s, err +} diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 1cb5ef0a8..25103f245 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -44,6 +44,12 @@ const ( QueryDeviceInfosPath = "/userapi/queryDeviceInfos" QuerySearchProfilesPath = "/userapi/querySearchProfiles" QueryOpenIDTokenPath = "/userapi/queryOpenIDToken" + + CreateSessionPath = "/userapi/createSession" + ValidateSessionPath = "/userapi/validateSession" + GetThreePidForSessionPath = "/userapi/getThreePidForSession" + DeleteSessionPath = "/userapi/deleteSession" + IsSessionValidatedPath = "/userapi/isSessionValidated" ) // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. @@ -225,3 +231,43 @@ func (h *httpUserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.Que apiURL := h.apiURL + QueryOpenIDTokenPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } + +func (h *httpUserInternalAPI) CreateSession(ctx context.Context, req *api.CreateSessionRequest, res *api.CreateSessionResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "CreateSession") + defer span.Finish() + + apiURL := h.apiURL + CreateSessionPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpUserInternalAPI) ValidateSession(ctx context.Context, req *api.ValidateSessionRequest, res struct{}) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "ValidateSession") + defer span.Finish() + + apiURL := h.apiURL + ValidateSessionPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpUserInternalAPI) GetThreePidForSession(ctx context.Context, req *api.SessionOwnership, res *api.GetThreePidForSessionResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "GetThreePidForSession") + defer span.Finish() + + apiURL := h.apiURL + GetThreePidForSessionPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpUserInternalAPI) DeleteSession(ctx context.Context, req *api.SessionOwnership, res struct{}) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "DeleteSession") + defer span.Finish() + + apiURL := h.apiURL + DeleteSessionPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpUserInternalAPI) IsSessionValidated(ctx context.Context, req *api.SessionOwnership, res *api.IsSessionValidatedResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "IsSessionValidated") + defer span.Finish() + + apiURL := h.apiURL + IsSessionValidatedPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} diff --git a/userapi/userapi.go b/userapi/userapi.go index 74702020a..6d1479497 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -21,8 +21,10 @@ import ( "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/inthttp" + "github.com/matrix-org/dendrite/userapi/mail" "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/dendrite/userapi/storage/devices" + "github.com/matrix-org/dendrite/userapi/storage/threepid" "github.com/sirupsen/logrus" ) @@ -36,17 +38,26 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( accountDB accounts.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, -) api.UserInternalAPI { +) *internal.UserInternalAPI { deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName) if err != nil { logrus.WithError(err).Panicf("failed to connect to device db") } - + threepidDb, err := threepid.NewDatabase(&cfg.ThreepidDatabase) + if err != nil { + logrus.WithError(err).Panicf("failed to connect to threepid db") + } + mailer, err := mail.NewMailer(cfg) + if err != nil { + logrus.WithError(err).Panicf("failed to crate Mailer") + } return &internal.UserInternalAPI{ AccountDB: accountDB, DeviceDB: deviceDB, + ThreePidDB: threepidDb, ServerName: cfg.Matrix.ServerName, AppServices: appServices, KeyAPI: keyAPI, + Mail: mailer, } } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 0141258e6..70c7868fe 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -13,9 +13,12 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/inthttp" + "github.com/matrix-org/dendrite/userapi/mail" "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" + "github.com/matryer/is" "golang.org/x/crypto/bcrypt" ) @@ -23,13 +26,37 @@ const ( serverName = gomatrixserverlib.ServerName("example.com") ) -func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database) { +var ( + testReq = &api.CreateSessionRequest{ + ClientSecret: "foobar", + NextLink: "http://foobar.com", + ThreePid: "foo@bar.com", + Extra: []string{}, + SendAttempt: 0, + } + ctx = context.Background() + mailer = &testMailer{ + c: map[api.ThreepidSessionType]chan *mail.Mail{ + api.Password: make(chan *mail.Mail, 3), + api.Verification: make(chan *mail.Mail, 3), + }, + } +) + +type testMailer struct { + c map[api.ThreepidSessionType]chan *mail.Mail +} + +func (tm *testMailer) Send(s *mail.Mail, t api.ThreepidSessionType) error { + tm.c[t] <- s + return nil +} + +func mustMakeInternalAPI(is *is.I) (*internal.UserInternalAPI, accounts.Database) { accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{ ConnectionString: "file::memory:", }, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS) - if err != nil { - t.Fatalf("failed to create account DB: %s", err) - } + is.NoErr(err) cfg := &config.UserAPI{ DeviceDatabase: config.DatabaseOptions{ ConnectionString: "file::memory:", @@ -39,15 +66,24 @@ func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database) Matrix: &config.Global{ ServerName: serverName, }, + ThreepidDatabase: config.DatabaseOptions{ + ConnectionString: "file::memory:", + MaxOpenConnections: 1, + MaxIdleConnections: 1, + }, + Email: config.EmailConf{ + TemplatesPath: "../res/default", + }, } return userapi.NewInternalAPI(accountDB, cfg, nil, nil), accountDB } func TestQueryProfile(t *testing.T) { + is := is.New(t) aliceAvatarURL := "mxc://example.com/alice" aliceDisplayName := "Alice" - userAPI, accountDB := MustMakeInternalAPI(t) + userAPI, accountDB := mustMakeInternalAPI(is) _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "") if err != nil { t.Fatalf("failed to make account: %s", err) @@ -119,3 +155,101 @@ func TestQueryProfile(t *testing.T) { runCases(userAPI) }) } + +func TestCreateSession(t *testing.T) { + is := is.New(t) + internalApi, _ := mustMakeInternalAPI(is) + mustCreateSession(is, internalApi) +} + +func TestCreateSession_Twice(t *testing.T) { + is := is.New(t) + internalApi, _ := mustMakeInternalAPI(is) + mustCreateSession(is, internalApi) + resp := api.CreateSessionResponse{} + err := internalApi.CreateSession(ctx, testReq, &resp) + is.NoErr(err) + is.Equal(len(resp.Sid), 43) + select { + case <-mailer.c[api.Verification]: + t.Fatal("email was received, but sent attempt was not increased") + default: + break + } +} + +func TestCreateSession_Twice_IncreaseSendAttempt(t *testing.T) { + is := is.New(t) + internalApi, _ := mustMakeInternalAPI(is) + mustCreateSession(is, internalApi) + resp := api.CreateSessionResponse{} + testReqBumped := *testReq + testReqBumped.SendAttempt = 1 + err := internalApi.CreateSession(ctx, &testReqBumped, &resp) + is.NoErr(err) + is.Equal(len(resp.Sid), 43) + sub := <-mailer.c[api.Verification] + is.Equal(len(sub.Token), 64) + is.Equal(sub.To, testReq.ThreePid) +} + +func TestValidateSession(t *testing.T) { + is := is.New(t) + internalApi, _ := mustMakeInternalAPI(is) + s, token := mustCreateSession(is, internalApi) + mustValidateSesson(is, internalApi, testReq.ClientSecret, token, s.Sid) +} + +func TestIsSessionValidated_InvalidatedSession(t *testing.T) { + is := is.New(t) + internalApi, _ := mustMakeInternalAPI(is) + s, _ := mustCreateSession(is, internalApi) + resp := api.IsSessionValidatedResponse{} + err := internalApi.IsSessionValidated(ctx, &api.SessionOwnership{ + Sid: s.Sid, + ClientSecret: testReq.ClientSecret, + }, &resp) + is.NoErr(err) + is.Equal(resp.Validated, false) +} + +func TestIsSessionValidated_ValidatedSession(t *testing.T) { + is := is.New(t) + internalApi, _ := mustMakeInternalAPI(is) + s, token := mustCreateSession(is, internalApi) + resp := api.IsSessionValidatedResponse{} + mustValidateSesson(is, internalApi, testReq.ClientSecret, token, s.Sid) + err := internalApi.IsSessionValidated(ctx, &api.SessionOwnership{ + Sid: s.Sid, + ClientSecret: testReq.ClientSecret, + }, &resp) + is.NoErr(err) + is.Equal(resp.Validated, true) + is.Equal(resp.ValidatedAt > 0, true) +} + +func mustCreateSession(is *is.I, i *internal.UserInternalAPI) (resp *api.CreateSessionResponse, token string) { + resp = &api.CreateSessionResponse{} + i.Mail = mailer + err := i.CreateSession(ctx, testReq, resp) + is.NoErr(err) + is.Equal(len(resp.Sid), 43) + sub := <-mailer.c[api.Verification] + is.Equal(len(sub.Token), 64) + is.Equal(sub.To, testReq.ThreePid) + token = sub.Token + return +} + +func mustValidateSesson(is *is.I, i *internal.UserInternalAPI, secret, token, sid string) { + err := i.ValidateSession(ctx, &api.ValidateSessionRequest{ + SessionOwnership: api.SessionOwnership{ + Sid: sid, + ClientSecret: secret, + }, + Token: token, + }, + struct{}{}, + ) + is.NoErr(err) +}