Support for m.login.token (#2014)

* Add GOPATH to PATH in find-lint.sh.

The user doesn't necessarily have it in PATH.

* Refactor LoginTypePassword and Type to support m.login.token and m.login.sso.

For login token:

* m.login.token will require deleting the token after completeAuth has
  generated an access token, so a cleanup function is returned by
  Type.Login.
* Allowing different login types will require parsing the /login body
  twice: first to extract the "type" and then the type-specific parsing.
  Thus, we will have to buffer the request JSON in /login, like
  UserInteractive already does.

For SSO:

* NewUserInteractive will have to also use GetAccountByLocalpart. It
  makes more sense to just pass a (narrowed-down) accountDB interface
  to it than adding more function pointers.

Code quality:

* Passing around (and down-casting) interface{} for login request types
  has drawbacks in terms of type-safety, and no inherent benefits. We
  always decode JSON anyway. Hence renaming to Type.LoginFromJSON. Code
  that directly uses LoginTypePassword with parsed data can still use
  Login.
* Removed a TODO for SSO. This is already tracked in #1297.
* httputil.UnmarshalJSON is useful because it returns a JSONResponse.

This change is intended to have no functional changes.

* Support login tokens in User API.

This adds full lifecycle functions for login tokens: create, query, delete.

* Support m.login.token in /login.

* Fixes for PR review.

* Set @matrix-org/dendrite-core as repository code owner

* Return event NID from `StoreEvent`, match PSQL vs SQLite behaviour, tweak backfill persistence (#2071)

Co-authored-by: kegsay <kegan@matrix.org>
Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
This commit is contained in:
tommie 2022-02-10 11:27:26 +01:00 committed by GitHub
parent 432c35a307
commit c36e4546c3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
28 changed files with 1244 additions and 80 deletions

View file

@ -33,7 +33,7 @@ echo "Looking for lint..."
# Capture exit code to ensure go.{mod,sum} is restored before exiting # Capture exit code to ensure go.{mod,sum} is restored before exiting
exit_code=0 exit_code=0
golangci-lint run $args || exit_code=1 PATH="$PATH:${GOPATH:-~/go}/bin" golangci-lint run $args || exit_code=1
# Restore go.{mod,sum} # Restore go.{mod,sum}
mv go.mod.bak go.mod && mv go.sum.bak go.sum mv go.mod.bak go.mod && mv go.sum.bak go.sum

View file

@ -42,6 +42,7 @@ type DeviceDatabase interface {
type AccountDatabase interface { type AccountDatabase interface {
// Look up the account matching the given localpart. // Look up the account matching the given localpart.
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
GetAccountByPassword(ctx context.Context, localpart, password string) (*api.Account, error)
} }
// VerifyUserFromRequest authenticates the HTTP request, // VerifyUserFromRequest authenticates the HTTP request,

View file

@ -10,4 +10,5 @@ const (
LoginTypeSharedSecret = "org.matrix.login.shared_secret" LoginTypeSharedSecret = "org.matrix.login.shared_secret"
LoginTypeRecaptcha = "m.login.recaptcha" LoginTypeRecaptcha = "m.login.recaptcha"
LoginTypeApplicationService = "m.login.application_service" LoginTypeApplicationService = "m.login.application_service"
LoginTypeToken = "m.login.token"
) )

83
clientapi/auth/login.go Normal file
View file

@ -0,0 +1,83 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package auth
import (
"context"
"encoding/json"
"io"
"io/ioutil"
"net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config"
uapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
)
// LoginFromJSONReader performs authentication given a login request body reader and
// some context. It returns the basic login information and a cleanup function to be
// called after authorization has completed, with the result of the authorization.
// If the final return value is non-nil, an error occurred and the cleanup function
// is nil.
func LoginFromJSONReader(ctx context.Context, r io.Reader, accountDB AccountDatabase, userAPI UserInternalAPIForLogin, cfg *config.ClientAPI) (*Login, LoginCleanupFunc, *util.JSONResponse) {
reqBytes, err := ioutil.ReadAll(r)
if err != nil {
err := &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("Reading request body failed: " + err.Error()),
}
return nil, nil, err
}
var header struct {
Type string `json:"type"`
}
if err := json.Unmarshal(reqBytes, &header); err != nil {
err := &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("Reading request body failed: " + err.Error()),
}
return nil, nil, err
}
var typ Type
switch header.Type {
case authtypes.LoginTypePassword:
typ = &LoginTypePassword{
GetAccountByPassword: accountDB.GetAccountByPassword,
Config: cfg,
}
case authtypes.LoginTypeToken:
typ = &LoginTypeToken{
UserAPI: userAPI,
Config: cfg,
}
default:
err := util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("unhandled login type: " + header.Type),
}
return nil, nil, &err
}
return typ.LoginFromJSON(ctx, reqBytes)
}
// UserInternalAPIForLogin contains the aspects of UserAPI required for logging in.
type UserInternalAPIForLogin interface {
uapi.LoginTokenInternalAPI
}

View file

@ -0,0 +1,194 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package auth
import (
"context"
"database/sql"
"net/http"
"reflect"
"strings"
"testing"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config"
uapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
)
func TestLoginFromJSONReader(t *testing.T) {
ctx := context.Background()
tsts := []struct {
Name string
Body string
WantUsername string
WantDeviceID string
WantDeletedTokens []string
}{
{
Name: "passwordWorks",
Body: `{
"type": "m.login.password",
"identifier": { "type": "m.id.user", "user": "alice" },
"password": "herpassword",
"device_id": "adevice"
}`,
WantUsername: "alice",
WantDeviceID: "adevice",
},
{
Name: "tokenWorks",
Body: `{
"type": "m.login.token",
"token": "atoken",
"device_id": "adevice"
}`,
WantUsername: "@auser:example.com",
WantDeviceID: "adevice",
WantDeletedTokens: []string{"atoken"},
},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
var accountDB fakeAccountDB
var userAPI fakeUserInternalAPI
cfg := &config.ClientAPI{
Matrix: &config.Global{
ServerName: serverName,
},
}
login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &accountDB, &userAPI, cfg)
if err != nil {
t.Fatalf("LoginFromJSONReader failed: %+v", err)
}
cleanup(ctx, &util.JSONResponse{Code: http.StatusOK})
if login.Username() != tst.WantUsername {
t.Errorf("Username: got %q, want %q", login.Username(), tst.WantUsername)
}
if login.DeviceID == nil {
if tst.WantDeviceID != "" {
t.Errorf("DeviceID: got %v, want %q", login.DeviceID, tst.WantDeviceID)
}
} else {
if *login.DeviceID != tst.WantDeviceID {
t.Errorf("DeviceID: got %q, want %q", *login.DeviceID, tst.WantDeviceID)
}
}
if !reflect.DeepEqual(userAPI.DeletedTokens, tst.WantDeletedTokens) {
t.Errorf("DeletedTokens: got %+v, want %+v", userAPI.DeletedTokens, tst.WantDeletedTokens)
}
})
}
}
func TestBadLoginFromJSONReader(t *testing.T) {
ctx := context.Background()
tsts := []struct {
Name string
Body string
WantErrCode string
}{
{Name: "empty", WantErrCode: "M_BAD_JSON"},
{
Name: "badUnmarshal",
Body: `badsyntaxJSON`,
WantErrCode: "M_BAD_JSON",
},
{
Name: "badPassword",
Body: `{
"type": "m.login.password",
"identifier": { "type": "m.id.user", "user": "alice" },
"password": "invalidpassword",
"device_id": "adevice"
}`,
WantErrCode: "M_FORBIDDEN",
},
{
Name: "badToken",
Body: `{
"type": "m.login.token",
"token": "invalidtoken",
"device_id": "adevice"
}`,
WantErrCode: "M_FORBIDDEN",
},
{
Name: "badType",
Body: `{
"type": "m.login.invalid",
"device_id": "adevice"
}`,
WantErrCode: "M_INVALID_ARGUMENT_VALUE",
},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
var accountDB fakeAccountDB
var userAPI fakeUserInternalAPI
cfg := &config.ClientAPI{
Matrix: &config.Global{
ServerName: serverName,
},
}
_, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &accountDB, &userAPI, cfg)
if errRes == nil {
cleanup(ctx, nil)
t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode)
} else if merr, ok := errRes.JSON.(*jsonerror.MatrixError); ok && merr.ErrCode != tst.WantErrCode {
t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode)
}
})
}
}
type fakeAccountDB struct {
AccountDatabase
}
func (*fakeAccountDB) GetAccountByPassword(ctx context.Context, localpart, password string) (*uapi.Account, error) {
if password == "invalidpassword" {
return nil, sql.ErrNoRows
}
return &uapi.Account{}, nil
}
type fakeUserInternalAPI struct {
UserInternalAPIForLogin
DeletedTokens []string
}
func (ua *fakeUserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *uapi.PerformLoginTokenDeletionRequest, res *uapi.PerformLoginTokenDeletionResponse) error {
ua.DeletedTokens = append(ua.DeletedTokens, req.Token)
return nil
}
func (*fakeUserInternalAPI) QueryLoginToken(ctx context.Context, req *uapi.QueryLoginTokenRequest, res *uapi.QueryLoginTokenResponse) error {
if req.Token == "invalidtoken" {
return nil
}
res.Data = &uapi.LoginTokenData{UserID: "@auser:example.com"}
return nil
}

View file

@ -0,0 +1,83 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package auth
import (
"context"
"net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config"
uapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
)
// LoginTypeToken describes how to authenticate with a login token.
type LoginTypeToken struct {
UserAPI uapi.LoginTokenInternalAPI
Config *config.ClientAPI
}
// Name implements Type.
func (t *LoginTypeToken) Name() string {
return authtypes.LoginTypeToken
}
// LoginFromJSON implements Type. The cleanup function deletes the token from
// the database on success.
func (t *LoginTypeToken) LoginFromJSON(ctx context.Context, reqBytes []byte) (*Login, LoginCleanupFunc, *util.JSONResponse) {
var r loginTokenRequest
if err := httputil.UnmarshalJSON(reqBytes, &r); err != nil {
return nil, nil, err
}
var res uapi.QueryLoginTokenResponse
if err := t.UserAPI.QueryLoginToken(ctx, &uapi.QueryLoginTokenRequest{Token: r.Token}, &res); err != nil {
util.GetLogger(ctx).WithError(err).Error("UserAPI.QueryLoginToken failed")
jsonErr := jsonerror.InternalServerError()
return nil, nil, &jsonErr
}
if res.Data == nil {
return nil, nil, &util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("invalid login token"),
}
}
r.Login.Identifier.Type = "m.id.user"
r.Login.Identifier.User = res.Data.UserID
cleanup := func(ctx context.Context, authRes *util.JSONResponse) {
if authRes == nil {
util.GetLogger(ctx).Error("No JSONResponse provided to LoginTokenType cleanup function")
return
}
if authRes.Code == http.StatusOK {
var res uapi.PerformLoginTokenDeletionResponse
if err := t.UserAPI.PerformLoginTokenDeletion(ctx, &uapi.PerformLoginTokenDeletionRequest{Token: r.Token}, &res); err != nil {
util.GetLogger(ctx).WithError(err).Error("UserAPI.PerformLoginTokenDeletion failed")
}
}
}
return &r.Login, cleanup, nil
}
// loginTokenRequest struct to hold the possible parameters from an HTTP request.
type loginTokenRequest struct {
Login
Token string `json:"token"`
}

View file

@ -20,6 +20,8 @@ import (
"net/http" "net/http"
"strings" "strings"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
@ -41,16 +43,26 @@ type LoginTypePassword struct {
} }
func (t *LoginTypePassword) Name() string { func (t *LoginTypePassword) Name() string {
return "m.login.password" return authtypes.LoginTypePassword
} }
func (t *LoginTypePassword) Request() interface{} { func (t *LoginTypePassword) LoginFromJSON(ctx context.Context, reqBytes []byte) (*Login, LoginCleanupFunc, *util.JSONResponse) {
return &PasswordRequest{} var r PasswordRequest
if err := httputil.UnmarshalJSON(reqBytes, &r); err != nil {
return nil, nil, err
}
login, err := t.Login(ctx, &r)
if err != nil {
return nil, nil, err
}
return login, func(context.Context, *util.JSONResponse) {}, nil
} }
func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) { func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) {
r := req.(*PasswordRequest) r := req.(*PasswordRequest)
username := r.Username() username := strings.ToLower(r.Username())
if username == "" { if username == "" {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusUnauthorized, Code: http.StatusUnauthorized,

View file

@ -32,22 +32,24 @@ import (
type Type interface { type Type interface {
// Name returns the name of the auth type e.g `m.login.password` // Name returns the name of the auth type e.g `m.login.password`
Name() string Name() string
// Request returns a pointer to a new request body struct to unmarshal into.
Request() interface{}
// Login with the auth type, returning an error response on failure. // Login with the auth type, returning an error response on failure.
// Not all types support login, only m.login.password and m.login.token // Not all types support login, only m.login.password and m.login.token
// See https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-login // See https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-login
// `req` is guaranteed to be the type returned from Request()
// This function will be called when doing login and when doing 'sudo' style // This function will be called when doing login and when doing 'sudo' style
// actions e.g deleting devices. The response must be a 401 as per: // actions e.g deleting devices. The response must be a 401 as per:
// "If the homeserver decides that an attempt on a stage was unsuccessful, but the // "If the homeserver decides that an attempt on a stage was unsuccessful, but the
// client may make a second attempt, it returns the same HTTP status 401 response as above, // client may make a second attempt, it returns the same HTTP status 401 response as above,
// with the addition of the standard errcode and error fields describing the error." // with the addition of the standard errcode and error fields describing the error."
Login(ctx context.Context, req interface{}) (login *Login, errRes *util.JSONResponse) //
// The returned cleanup function must be non-nil on success, and will be called after
// authorization has been completed. Its argument is the final result of authorization.
LoginFromJSON(ctx context.Context, reqBytes []byte) (login *Login, cleanup LoginCleanupFunc, errRes *util.JSONResponse)
// TODO: Extend to support Register() flow // TODO: Extend to support Register() flow
// Register(ctx context.Context, sessionID string, req interface{}) // Register(ctx context.Context, sessionID string, req interface{})
} }
type LoginCleanupFunc func(context.Context, *util.JSONResponse)
// LoginIdentifier represents identifier types // LoginIdentifier represents identifier types
// https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types // https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types
type LoginIdentifier struct { type LoginIdentifier struct {
@ -61,11 +63,8 @@ type LoginIdentifier struct {
// Login represents the shared fields used in all forms of login/sudo endpoints. // Login represents the shared fields used in all forms of login/sudo endpoints.
type Login struct { type Login struct {
Type string `json:"type"` LoginIdentifier // Flat fields deprecated in favour of `identifier`.
Identifier LoginIdentifier `json:"identifier"` Identifier LoginIdentifier `json:"identifier"`
User string `json:"user"` // deprecated in favour of identifier
Medium string `json:"medium"` // deprecated in favour of identifier
Address string `json:"address"` // deprecated in favour of identifier
// Both DeviceID and InitialDisplayName can be omitted, or empty strings ("") // Both DeviceID and InitialDisplayName can be omitted, or empty strings ("")
// Thus a pointer is needed to differentiate between the two // Thus a pointer is needed to differentiate between the two
@ -111,12 +110,11 @@ type UserInteractive struct {
Sessions map[string][]string Sessions map[string][]string
} }
func NewUserInteractive(getAccByPass GetAccountByPassword, cfg *config.ClientAPI) *UserInteractive { func NewUserInteractive(accountDB AccountDatabase, cfg *config.ClientAPI) *UserInteractive {
typePassword := &LoginTypePassword{ typePassword := &LoginTypePassword{
GetAccountByPassword: getAccByPass, GetAccountByPassword: accountDB.GetAccountByPassword,
Config: cfg, Config: cfg,
} }
// TODO: Add SSO login
return &UserInteractive{ return &UserInteractive{
Completed: []string{}, Completed: []string{},
Flows: []userInteractiveFlow{ Flows: []userInteractiveFlow{
@ -236,18 +234,13 @@ func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte, device *
} }
} }
r := loginType.Request() login, cleanup, resErr := loginType.LoginFromJSON(ctx, []byte(gjson.GetBytes(bodyBytes, "auth").Raw))
if err := json.Unmarshal([]byte(gjson.GetBytes(bodyBytes, "auth").Raw), r); err != nil { if resErr != nil {
return nil, &util.JSONResponse{ return nil, u.ResponseWithChallenge(sessionID, resErr.JSON)
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()),
}
} }
login, resErr := loginType.Login(ctx, r)
if resErr == nil { u.AddCompletedStage(sessionID, authType)
u.AddCompletedStage(sessionID, authType) cleanup(ctx, nil)
// TODO: Check if there's more stages to go and return an error // TODO: Check if there's more stages to go and return an error
return login, nil return login, nil
}
return nil, u.ResponseWithChallenge(sessionID, resErr.JSON)
} }

View file

@ -24,7 +24,11 @@ var (
} }
) )
func getAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) { type fakeAccountDatabase struct {
AccountDatabase
}
func (*fakeAccountDatabase) GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) {
acc, ok := lookup[localpart+" "+plaintextPassword] acc, ok := lookup[localpart+" "+plaintextPassword]
if !ok { if !ok {
return nil, fmt.Errorf("unknown user/password") return nil, fmt.Errorf("unknown user/password")
@ -38,7 +42,7 @@ func setup() *UserInteractive {
ServerName: serverName, ServerName: serverName,
}, },
} }
return NewUserInteractive(getAccountByPassword, cfg) return NewUserInteractive(&fakeAccountDatabase{}, cfg)
} }
func TestUserInteractiveChallenge(t *testing.T) { func TestUserInteractiveChallenge(t *testing.T) {

View file

@ -36,6 +36,10 @@ func UnmarshalJSONRequest(req *http.Request, iface interface{}) *util.JSONRespon
return &resp return &resp
} }
return UnmarshalJSON(body, iface)
}
func UnmarshalJSON(body []byte, iface interface{}) *util.JSONResponse {
if !utf8.Valid(body) { if !utf8.Valid(body) {
return &util.JSONResponse{ return &util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,

View file

@ -19,7 +19,6 @@ import (
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
@ -65,21 +64,14 @@ func Login(
JSON: passwordLogin(), JSON: passwordLogin(),
} }
} else if req.Method == http.MethodPost { } else if req.Method == http.MethodPost {
typePassword := auth.LoginTypePassword{ login, cleanup, authErr := auth.LoginFromJSONReader(req.Context(), req.Body, accountDB, userAPI, cfg)
GetAccountByPassword: accountDB.GetAccountByPassword,
Config: cfg,
}
r := typePassword.Request()
resErr := httputil.UnmarshalJSONRequest(req, r)
if resErr != nil {
return *resErr
}
login, authErr := typePassword.Login(req.Context(), r)
if authErr != nil { if authErr != nil {
return *authErr return *authErr
} }
// make a device/access token // make a device/access token
return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent()) authErr2 := completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent())
cleanup(req.Context(), &authErr2)
return authErr2
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusMethodNotAllowed, Code: http.StatusMethodNotAllowed,

View file

@ -62,7 +62,7 @@ func Setup(
mscCfg *config.MSCs, mscCfg *config.MSCs,
) { ) {
rateLimits := httputil.NewRateLimits(&cfg.RateLimiting) rateLimits := httputil.NewRateLimits(&cfg.RateLimiting)
userInteractiveAuth := auth.NewUserInteractive(accountDB.GetAccountByPassword, cfg) userInteractiveAuth := auth.NewUserInteractive(accountDB, cfg)
unstableFeatures := map[string]bool{ unstableFeatures := map[string]bool{
"org.matrix.e2e_cross_signing": true, "org.matrix.e2e_cross_signing": true,

View file

@ -24,6 +24,8 @@ import (
// UserInternalAPI is the internal API for information about users and devices. // UserInternalAPI is the internal API for information about users and devices.
type UserInternalAPI interface { type UserInternalAPI interface {
LoginTokenInternalAPI
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error

View file

@ -0,0 +1,69 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"context"
"time"
)
type LoginTokenInternalAPI interface {
// PerformLoginTokenCreation creates a new login token and associates it with the provided data.
PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error
// PerformLoginTokenDeletion ensures the token doesn't exist. Success
// is returned even if the token didn't exist, or was already deleted.
PerformLoginTokenDeletion(ctx context.Context, req *PerformLoginTokenDeletionRequest, res *PerformLoginTokenDeletionResponse) error
// QueryLoginToken returns the data associated with a login token. If
// the token is not valid, success is returned, but res.Data == nil.
QueryLoginToken(ctx context.Context, req *QueryLoginTokenRequest, res *QueryLoginTokenResponse) error
}
// LoginTokenData is the data that can be retrieved given a login token. This is
// provided by the calling code.
type LoginTokenData struct {
// UserID is the full mxid of the user.
UserID string
}
// LoginTokenMetadata contains metadata created and maintained by the User API.
type LoginTokenMetadata struct {
Token string
Expiration time.Time
}
type PerformLoginTokenCreationRequest struct {
Data LoginTokenData
}
type PerformLoginTokenCreationResponse struct {
Metadata LoginTokenMetadata
}
type PerformLoginTokenDeletionRequest struct {
Token string
}
type PerformLoginTokenDeletionResponse struct{}
type QueryLoginTokenRequest struct {
Token string
}
type QueryLoginTokenResponse struct {
// Data is nil if the token was invalid.
Data *LoginTokenData
}

View file

@ -0,0 +1,39 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"context"
"github.com/matrix-org/util"
)
func (t *UserInternalAPITrace) PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error {
err := t.Impl.PerformLoginTokenCreation(ctx, req, res)
util.GetLogger(ctx).Infof("PerformLoginTokenCreation req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) PerformLoginTokenDeletion(ctx context.Context, req *PerformLoginTokenDeletionRequest, res *PerformLoginTokenDeletionResponse) error {
err := t.Impl.PerformLoginTokenDeletion(ctx, req, res)
util.GetLogger(ctx).Infof("PerformLoginTokenDeletion req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) QueryLoginToken(ctx context.Context, req *QueryLoginTokenRequest, res *QueryLoginTokenResponse) error {
err := t.Impl.QueryLoginToken(ctx, req, res)
util.GetLogger(ctx).Infof("QueryLoginToken req=%+v res=%+v", js(req), js(res))
return err
}

View file

@ -0,0 +1,78 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// PerformLoginTokenCreation creates a new login token and associates it with the provided data.
func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *api.PerformLoginTokenCreationRequest, res *api.PerformLoginTokenCreationResponse) error {
util.GetLogger(ctx).WithField("user_id", req.Data.UserID).Info("PerformLoginTokenCreation")
_, domain, err := gomatrixserverlib.SplitID('@', req.Data.UserID)
if err != nil {
return err
}
if domain != a.ServerName {
return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName)
}
tokenMeta, err := a.DeviceDB.CreateLoginToken(ctx, &req.Data)
if err != nil {
return err
}
res.Metadata = *tokenMeta
return nil
}
// PerformLoginTokenDeletion ensures the token doesn't exist.
func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *api.PerformLoginTokenDeletionRequest, res *api.PerformLoginTokenDeletionResponse) error {
util.GetLogger(ctx).Info("PerformLoginTokenDeletion")
return a.DeviceDB.RemoveLoginToken(ctx, req.Token)
}
// QueryLoginToken returns the data associated with a login token. If
// the token is not valid, success is returned, but res.Data == nil.
func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error {
tokenData, err := a.DeviceDB.GetLoginTokenDataByToken(ctx, req.Token)
if err != nil {
res.Data = nil
if err == sql.ErrNoRows {
return nil
}
return err
}
localpart, domain, err := gomatrixserverlib.SplitID('@', tokenData.UserID)
if err != nil {
return err
}
if domain != a.ServerName {
return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName)
}
if _, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart); err != nil {
res.Data = nil
if err == sql.ErrNoRows {
return nil
}
return err
}
res.Data = tokenData
return nil
}

View file

@ -0,0 +1,65 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package inthttp
import (
"context"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/opentracing/opentracing-go"
)
const (
PerformLoginTokenCreationPath = "/userapi/performLoginTokenCreation"
PerformLoginTokenDeletionPath = "/userapi/performLoginTokenDeletion"
QueryLoginTokenPath = "/userapi/queryLoginToken"
)
func (h *httpUserInternalAPI) PerformLoginTokenCreation(
ctx context.Context,
request *api.PerformLoginTokenCreationRequest,
response *api.PerformLoginTokenCreationResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenCreation")
defer span.Finish()
apiURL := h.apiURL + PerformLoginTokenCreationPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) PerformLoginTokenDeletion(
ctx context.Context,
request *api.PerformLoginTokenDeletionRequest,
response *api.PerformLoginTokenDeletionResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenDeletion")
defer span.Finish()
apiURL := h.apiURL + PerformLoginTokenDeletionPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) QueryLoginToken(
ctx context.Context,
request *api.QueryLoginTokenRequest,
response *api.QueryLoginTokenResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryLoginToken")
defer span.Finish()
apiURL := h.apiURL + QueryLoginTokenPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}

View file

@ -27,6 +27,8 @@ import (
// nolint: gocyclo // nolint: gocyclo
func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
addRoutesLoginToken(internalAPIMux, s)
internalAPIMux.Handle(PerformAccountCreationPath, internalAPIMux.Handle(PerformAccountCreationPath,
httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse { httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse {
request := api.PerformAccountCreationRequest{} request := api.PerformAccountCreationRequest{}

View file

@ -0,0 +1,68 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package inthttp
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
)
// addRoutesLoginToken adds routes for all login token API calls.
func addRoutesLoginToken(internalAPIMux *mux.Router, s api.UserInternalAPI) {
internalAPIMux.Handle(PerformLoginTokenCreationPath,
httputil.MakeInternalAPI("performLoginTokenCreation", func(req *http.Request) util.JSONResponse {
request := api.PerformLoginTokenCreationRequest{}
response := api.PerformLoginTokenCreationResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformLoginTokenCreation(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformLoginTokenDeletionPath,
httputil.MakeInternalAPI("performLoginTokenDeletion", func(req *http.Request) util.JSONResponse {
request := api.PerformLoginTokenDeletionRequest{}
response := api.PerformLoginTokenDeletionResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformLoginTokenDeletion(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryLoginTokenPath,
httputil.MakeInternalAPI("queryLoginToken", func(req *http.Request) util.JSONResponse {
request := api.QueryLoginTokenRequest{}
response := api.QueryLoginTokenResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryLoginToken(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
}

View file

@ -38,4 +38,15 @@ type Database interface {
RemoveDevices(ctx context.Context, localpart string, devices []string) error RemoveDevices(ctx context.Context, localpart string, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted. // RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
RemoveLoginToken(ctx context.Context, token string) error
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
} }

View file

@ -0,0 +1,93 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
)
type loginTokenStatements struct {
insertStmt *sql.Stmt
deleteStmt *sql.Stmt
selectByTokenStmt *sql.Stmt
}
// execSchema ensures tables and indices exist.
func (s *loginTokenStatements) execSchema(db *sql.DB) error {
_, err := db.Exec(`
CREATE TABLE IF NOT EXISTS login_tokens (
-- The random value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY,
-- When the token expires
token_expires_at TIMESTAMP NOT NULL,
-- The mxid for this account
user_id TEXT NOT NULL
);
-- This index allows efficient garbage collection of expired tokens.
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at);
`)
return err
}
// prepare runs statement preparation.
func (s *loginTokenStatements) prepare(db *sql.DB) error {
return sqlutil.StatementList{
{&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"},
{&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"},
{&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"},
}.Prepare(db)
}
// insert adds an already generated token to the database.
func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error {
stmt := sqlutil.TxStmt(txn, s.insertStmt)
_, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID)
return err
}
// deleteByToken removes the named token.
//
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
// The login_tokens_expiration_idx index should make that efficient.
func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error {
stmt := sqlutil.TxStmt(txn, s.deleteStmt)
res, err := stmt.ExecContext(ctx, token, time.Now().UTC())
if err != nil {
return err
}
if n, err := res.RowsAffected(); err == nil && n > 1 {
util.GetLogger(ctx).WithField("num_deleted", n).Infof("Deleted %d login tokens (%d likely additional expired token)", n, n-1)
}
return nil
}
// selectByToken returns the data associated with the given token. May return sql.ErrNoRows.
func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
var data api.LoginTokenData
err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID)
if err != nil {
return nil, err
}
return &data, nil
}

View file

@ -19,6 +19,7 @@ import (
"crypto/rand" "crypto/rand"
"database/sql" "database/sql"
"encoding/base64" "encoding/base64"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
@ -27,28 +28,38 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
// The length of generated device IDs const (
var deviceIDByteLength = 6 // The length of generated device IDs
deviceIDByteLength = 6
loginTokenByteLength = 32
)
// Database represents a device database. // Database represents a device database.
type Database struct { type Database struct {
db *sql.DB db *sql.DB
devices devicesStatements devices devicesStatements
loginTokens loginTokenStatements
loginTokenLifetime time.Duration
} }
// NewDatabase creates a new device database // NewDatabase creates a new device database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) {
db, err := sqlutil.Open(dbProperties) db, err := sqlutil.Open(dbProperties)
if err != nil { if err != nil {
return nil, err return nil, err
} }
d := devicesStatements{} var d devicesStatements
var lt loginTokenStatements
// Create tables before executing migrations so we don't fail if the table is missing, // Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns // and THEN prepare statements so we don't fail due to referencing new columns
if err = d.execSchema(db); err != nil { if err = d.execSchema(db); err != nil {
return nil, err return nil, err
} }
if err = lt.execSchema(db); err != nil {
return nil, err
}
m := sqlutil.NewMigrations() m := sqlutil.NewMigrations()
deltas.LoadLastSeenTSIP(m) deltas.LoadLastSeenTSIP(m)
if err = m.RunDeltas(db, dbProperties); err != nil { if err = m.RunDeltas(db, dbProperties); err != nil {
@ -58,8 +69,11 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.prepare(db, serverName); err != nil { if err = d.prepare(db, serverName); err != nil {
return nil, err return nil, err
} }
if err = lt.prepare(db); err != nil {
return nil, err
}
return &Database{db, d}, nil return &Database{db, d, lt, loginTokenLifetime}, nil
} }
// GetDeviceByAccessToken returns the device matching the given access token. // GetDeviceByAccessToken returns the device matching the given access token.
@ -210,3 +224,47 @@ func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
}) })
} }
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
tok, err := generateLoginToken()
if err != nil {
return nil, err
}
meta := &api.LoginTokenMetadata{
Token: tok,
Expiration: time.Now().Add(d.loginTokenLifetime),
}
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.loginTokens.insert(ctx, txn, meta, data)
})
if err != nil {
return nil, err
}
return meta, nil
}
func generateLoginToken() (string, error) {
b := make([]byte, loginTokenByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.loginTokens.deleteByToken(ctx, txn, token)
})
}
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.loginTokens.selectByToken(ctx, token)
}

View file

@ -0,0 +1,93 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
)
type loginTokenStatements struct {
insertStmt *sql.Stmt
deleteStmt *sql.Stmt
selectByTokenStmt *sql.Stmt
}
// execSchema ensures tables and indices exist.
func (s *loginTokenStatements) execSchema(db *sql.DB) error {
_, err := db.Exec(`
CREATE TABLE IF NOT EXISTS login_tokens (
-- The random value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY,
-- When the token expires
token_expires_at TIMESTAMP NOT NULL,
-- The mxid for this account
user_id TEXT NOT NULL
);
-- This index allows efficient garbage collection of expired tokens.
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at);
`)
return err
}
// prepare runs statement preparation.
func (s *loginTokenStatements) prepare(db *sql.DB) error {
return sqlutil.StatementList{
{&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"},
{&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"},
{&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"},
}.Prepare(db)
}
// insert adds an already generated token to the database.
func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error {
stmt := sqlutil.TxStmt(txn, s.insertStmt)
_, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID)
return err
}
// deleteByToken removes the named token.
//
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
// The login_tokens_expiration_idx index should make that efficient.
func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error {
stmt := sqlutil.TxStmt(txn, s.deleteStmt)
res, err := stmt.ExecContext(ctx, token, time.Now().UTC())
if err != nil {
return err
}
if n, err := res.RowsAffected(); err == nil && n > 1 {
util.GetLogger(ctx).WithField("num_deleted", n).Infof("Deleted %d login tokens (%d likely additional expired token)", n, n-1)
}
return nil
}
// selectByToken returns the data associated with the given token. May return sql.ErrNoRows.
func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
var data api.LoginTokenData
err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID)
if err != nil {
return nil, err
}
return &data, nil
}

View file

@ -19,6 +19,7 @@ import (
"crypto/rand" "crypto/rand"
"database/sql" "database/sql"
"encoding/base64" "encoding/base64"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
@ -27,30 +28,41 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
// The length of generated device IDs const (
var deviceIDByteLength = 6 // The length of generated device IDs
deviceIDByteLength = 6
loginTokenByteLength = 32
)
// Database represents a device database. // Database represents a device database.
type Database struct { type Database struct {
db *sql.DB db *sql.DB
writer sqlutil.Writer writer sqlutil.Writer
devices devicesStatements devices devicesStatements
loginTokens loginTokenStatements
loginTokenLifetime time.Duration
} }
// NewDatabase creates a new device database // NewDatabase creates a new device database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) {
db, err := sqlutil.Open(dbProperties) db, err := sqlutil.Open(dbProperties)
if err != nil { if err != nil {
return nil, err return nil, err
} }
writer := sqlutil.NewExclusiveWriter() writer := sqlutil.NewExclusiveWriter()
d := devicesStatements{} var d devicesStatements
var lt loginTokenStatements
// Create tables before executing migrations so we don't fail if the table is missing, // Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns // and THEN prepare statements so we don't fail due to referencing new columns
if err = d.execSchema(db); err != nil { if err = d.execSchema(db); err != nil {
return nil, err return nil, err
} }
if err = lt.execSchema(db); err != nil {
return nil, err
}
m := sqlutil.NewMigrations() m := sqlutil.NewMigrations()
deltas.LoadLastSeenTSIP(m) deltas.LoadLastSeenTSIP(m)
if err = m.RunDeltas(db, dbProperties); err != nil { if err = m.RunDeltas(db, dbProperties); err != nil {
@ -59,7 +71,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.prepare(db, writer, serverName); err != nil { if err = d.prepare(db, writer, serverName); err != nil {
return nil, err return nil, err
} }
return &Database{db, writer, d}, nil if err = lt.prepare(db); err != nil {
return nil, err
}
return &Database{db, writer, d, lt, loginTokenLifetime}, nil
} }
// GetDeviceByAccessToken returns the device matching the given access token. // GetDeviceByAccessToken returns the device matching the given access token.
@ -210,3 +225,47 @@ func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
}) })
} }
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
tok, err := generateLoginToken()
if err != nil {
return nil, err
}
meta := &api.LoginTokenMetadata{
Token: tok,
Expiration: time.Now().Add(d.loginTokenLifetime),
}
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.loginTokens.insert(ctx, txn, meta, data)
})
if err != nil {
return nil, err
}
return meta, nil
}
func generateLoginToken() (string, error) {
b := make([]byte, loginTokenByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.loginTokens.deleteByToken(ctx, txn, token)
})
}
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.loginTokens.selectByToken(ctx, token)
}

View file

@ -19,6 +19,7 @@ package devices
import ( import (
"fmt" "fmt"
"time"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/devices/postgres" "github.com/matrix-org/dendrite/userapi/storage/devices/postgres"
@ -27,13 +28,14 @@ import (
) )
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters // and sets postgres connection parameters. loginTokenLifetime determines how long a
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (Database, error) { // login token from CreateLoginToken is valid.
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName) return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime)
case dbProperties.ConnectionString.IsPostgres(): case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(dbProperties, serverName) return postgres.NewDatabase(dbProperties, serverName, loginTokenLifetime)
default: default:
return nil, fmt.Errorf("unexpected database type") return nil, fmt.Errorf("unexpected database type")
} }

View file

@ -16,6 +16,7 @@ package devices
import ( import (
"fmt" "fmt"
"time"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3" "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3"
@ -25,10 +26,11 @@ import (
func NewDatabase( func NewDatabase(
dbProperties *config.DatabaseOptions, dbProperties *config.DatabaseOptions,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
loginTokenLifetime time.Duration,
) (Database, error) { ) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName) return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime)
case dbProperties.ConnectionString.IsPostgres(): case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation") return nil, fmt.Errorf("can't use Postgres implementation")
default: default:

View file

@ -15,6 +15,8 @@
package userapi package userapi
import ( import (
"time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
keyapi "github.com/matrix-org/dendrite/keyserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
@ -26,6 +28,13 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// defaultLoginTokenLifetime determines how old a valid token may be.
//
// NOTSPEC: The current spec says "SHOULD be limited to around five
// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low.
// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325).
const defaultLoginTokenLifetime = 2 * time.Minute
// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions // AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions
// on the given input API. // on the given input API.
func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) {
@ -37,11 +46,21 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) {
func NewInternalAPI( func NewInternalAPI(
accountDB accounts.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, accountDB accounts.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI,
) api.UserInternalAPI { ) api.UserInternalAPI {
deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName) deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName, defaultLoginTokenLifetime)
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to connect to device db") logrus.WithError(err).Panicf("failed to connect to device db")
} }
return newInternalAPI(accountDB, deviceDB, cfg, appServices, keyAPI)
}
func newInternalAPI(
accountDB accounts.Database,
deviceDB devices.Database,
cfg *config.UserAPI,
appServices []config.ApplicationService,
keyAPI keyapi.KeyInternalAPI,
) api.UserInternalAPI {
return &internal.UserInternalAPI{ return &internal.UserInternalAPI{
AccountDB: accountDB, AccountDB: accountDB,
DeviceDB: deviceDB, DeviceDB: deviceDB,

View file

@ -1,4 +1,18 @@
package userapi_test // Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package userapi
import ( import (
"context" "context"
@ -6,15 +20,16 @@ import (
"net/http" "net/http"
"reflect" "reflect"
"testing" "testing"
"time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/internal/test"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/inthttp"
"github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/dendrite/userapi/storage/devices"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@ -23,31 +38,41 @@ const (
serverName = gomatrixserverlib.ServerName("example.com") serverName = gomatrixserverlib.ServerName("example.com")
) )
func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database) { type apiTestOpts struct {
accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{ loginTokenLifetime time.Duration
ConnectionString: "file::memory:", }
}, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS)
func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, accounts.Database) {
if opts.loginTokenLifetime == 0 {
opts.loginTokenLifetime = defaultLoginTokenLifetime
}
dbopts := &config.DatabaseOptions{
ConnectionString: "file::memory:",
MaxOpenConnections: 1,
MaxIdleConnections: 1,
}
accountDB, err := accounts.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS)
if err != nil { if err != nil {
t.Fatalf("failed to create account DB: %s", err) t.Fatalf("failed to create account DB: %s", err)
} }
deviceDB, err := devices.NewDatabase(dbopts, serverName, opts.loginTokenLifetime)
if err != nil {
t.Fatalf("failed to create device DB: %s", err)
}
cfg := &config.UserAPI{ cfg := &config.UserAPI{
DeviceDatabase: config.DatabaseOptions{
ConnectionString: "file::memory:",
MaxOpenConnections: 1,
MaxIdleConnections: 1,
},
Matrix: &config.Global{ Matrix: &config.Global{
ServerName: serverName, ServerName: serverName,
}, },
} }
return userapi.NewInternalAPI(accountDB, cfg, nil, nil), accountDB return newInternalAPI(accountDB, deviceDB, cfg, nil, nil), accountDB
} }
func TestQueryProfile(t *testing.T) { func TestQueryProfile(t *testing.T) {
aliceAvatarURL := "mxc://example.com/alice" aliceAvatarURL := "mxc://example.com/alice"
aliceDisplayName := "Alice" aliceDisplayName := "Alice"
userAPI, accountDB := MustMakeInternalAPI(t) userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{})
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "") _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "")
if err != nil { if err != nil {
t.Fatalf("failed to make account: %s", err) t.Fatalf("failed to make account: %s", err)
@ -106,7 +131,7 @@ func TestQueryProfile(t *testing.T) {
t.Run("HTTP API", func(t *testing.T) { t.Run("HTTP API", func(t *testing.T) {
router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
userapi.AddInternalRoutes(router, userAPI) AddInternalRoutes(router, userAPI)
apiURL, cancel := test.ListenAndServe(t, router, false) apiURL, cancel := test.ListenAndServe(t, router, false)
defer cancel() defer cancel()
httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{}) httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{})
@ -119,3 +144,115 @@ func TestQueryProfile(t *testing.T) {
runCases(userAPI) runCases(userAPI)
}) })
} }
func TestLoginToken(t *testing.T) {
ctx := context.Background()
t.Run("tokenLoginFlow", func(t *testing.T) {
userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{})
_, err := accountDB.CreateAccount(ctx, "auser", "apassword", "")
if err != nil {
t.Fatalf("failed to make account: %s", err)
}
t.Log("Creating a login token like the SSO callback would...")
creq := api.PerformLoginTokenCreationRequest{
Data: api.LoginTokenData{UserID: "@auser:example.com"},
}
var cresp api.PerformLoginTokenCreationResponse
if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil {
t.Fatalf("PerformLoginTokenCreation failed: %v", err)
}
if cresp.Metadata.Token == "" {
t.Errorf("PerformLoginTokenCreation Token: got %q, want non-empty", cresp.Metadata.Token)
}
if cresp.Metadata.Expiration.Before(time.Now()) {
t.Errorf("PerformLoginTokenCreation Expiration: got %v, want non-expired", cresp.Metadata.Expiration)
}
t.Log("Querying the login token like /login with m.login.token would...")
qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token}
var qresp api.QueryLoginTokenResponse
if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil {
t.Fatalf("QueryLoginToken failed: %v", err)
}
if qresp.Data == nil {
t.Errorf("QueryLoginToken Data: got %v, want non-nil", qresp.Data)
} else if want := "@auser:example.com"; qresp.Data.UserID != want {
t.Errorf("QueryLoginToken UserID: got %q, want %q", qresp.Data.UserID, want)
}
t.Log("Deleting the login token like /login with m.login.token would...")
dreq := api.PerformLoginTokenDeletionRequest{Token: cresp.Metadata.Token}
var dresp api.PerformLoginTokenDeletionResponse
if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil {
t.Fatalf("PerformLoginTokenDeletion failed: %v", err)
}
})
t.Run("expiredTokenIsNotReturned", func(t *testing.T) {
userAPI, _ := MustMakeInternalAPI(t, apiTestOpts{loginTokenLifetime: -1 * time.Second})
creq := api.PerformLoginTokenCreationRequest{
Data: api.LoginTokenData{UserID: "@auser:example.com"},
}
var cresp api.PerformLoginTokenCreationResponse
if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil {
t.Fatalf("PerformLoginTokenCreation failed: %v", err)
}
qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token}
var qresp api.QueryLoginTokenResponse
if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil {
t.Fatalf("QueryLoginToken failed: %v", err)
}
if qresp.Data != nil {
t.Errorf("QueryLoginToken Data: got %v, want nil", qresp.Data)
}
})
t.Run("deleteWorks", func(t *testing.T) {
userAPI, _ := MustMakeInternalAPI(t, apiTestOpts{})
creq := api.PerformLoginTokenCreationRequest{
Data: api.LoginTokenData{UserID: "@auser:example.com"},
}
var cresp api.PerformLoginTokenCreationResponse
if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil {
t.Fatalf("PerformLoginTokenCreation failed: %v", err)
}
dreq := api.PerformLoginTokenDeletionRequest{Token: cresp.Metadata.Token}
var dresp api.PerformLoginTokenDeletionResponse
if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil {
t.Fatalf("PerformLoginTokenDeletion failed: %v", err)
}
qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token}
var qresp api.QueryLoginTokenResponse
if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil {
t.Fatalf("QueryLoginToken failed: %v", err)
}
if qresp.Data != nil {
t.Errorf("QueryLoginToken Data: got %v, want nil", qresp.Data)
}
})
t.Run("deleteUnknownIsNoOp", func(t *testing.T) {
userAPI, _ := MustMakeInternalAPI(t, apiTestOpts{})
dreq := api.PerformLoginTokenDeletionRequest{Token: "non-existent token"}
var dresp api.PerformLoginTokenDeletionResponse
if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil {
t.Fatalf("PerformLoginTokenDeletion failed: %v", err)
}
})
}