From 53b5888bcc2249c7f0345c330dd79272ef462873 Mon Sep 17 00:00:00 2001 From: Tommie Gannert Date: Sun, 26 Sep 2021 21:53:26 +0200 Subject: [PATCH] Support login tokens in User API. This adds full lifecycle functions for login tokens: create, query, delete. --- userapi/api/api.go | 2 + userapi/api/api_logintoken.go | 68 ++++++++ userapi/api/api_trace_logintoken.go | 39 +++++ userapi/internal/api_logintoken.go | 75 ++++++++ userapi/inthttp/client_logintoken.go | 65 +++++++ userapi/inthttp/server.go | 2 + userapi/inthttp/server_logintoken.go | 68 ++++++++ userapi/storage/devices/interface.go | 11 ++ .../devices/postgres/logintoken_table.go | 93 ++++++++++ userapi/storage/devices/postgres/storage.go | 72 +++++++- .../devices/sqlite3/logintoken_table.go | 93 ++++++++++ userapi/storage/devices/sqlite3/storage.go | 75 +++++++- userapi/storage/devices/storage.go | 10 +- userapi/storage/devices/storage_wasm.go | 4 +- userapi/userapi.go | 21 ++- userapi/userapi_test.go | 165 ++++++++++++++++-- 16 files changed, 828 insertions(+), 35 deletions(-) create mode 100644 userapi/api/api_logintoken.go create mode 100644 userapi/api/api_trace_logintoken.go create mode 100644 userapi/internal/api_logintoken.go create mode 100644 userapi/inthttp/client_logintoken.go create mode 100644 userapi/inthttp/server_logintoken.go create mode 100644 userapi/storage/devices/postgres/logintoken_table.go create mode 100644 userapi/storage/devices/sqlite3/logintoken_table.go diff --git a/userapi/api/api.go b/userapi/api/api.go index 75d06dd69..e80829e90 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -24,6 +24,8 @@ import ( // UserInternalAPI is the internal API for information about users and devices. type UserInternalAPI interface { + LoginTokenInternalAPI + InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error diff --git a/userapi/api/api_logintoken.go b/userapi/api/api_logintoken.go new file mode 100644 index 000000000..f5b7e3974 --- /dev/null +++ b/userapi/api/api_logintoken.go @@ -0,0 +1,68 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + "time" +) + +type LoginTokenInternalAPI interface { + // PerformLoginTokenCreation creates a new login token and associates it with the provided data. + PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error + + // PerformLoginTokenDeletion ensures the token doesn't exist. + PerformLoginTokenDeletion(ctx context.Context, req *PerformLoginTokenDeletionRequest, res *PerformLoginTokenDeletionResponse) error + + // QueryLoginToken returns the data associated with a login token. If + // the token is not valid, success is returned, but res.Data == nil. + QueryLoginToken(ctx context.Context, req *QueryLoginTokenRequest, res *QueryLoginTokenResponse) error +} + +// LoginTokenData is the data that can be retrieved given a login token. This is +// provided by the calling code. +type LoginTokenData struct { + // UserID is the full mxid of the user. + UserID string +} + +// LoginTokenMetadata contains metadata created and maintained by the User API. +type LoginTokenMetadata struct { + Token string + Expiration time.Time +} + +type PerformLoginTokenCreationRequest struct { + Data LoginTokenData +} + +type PerformLoginTokenCreationResponse struct { + Metadata LoginTokenMetadata +} + +type PerformLoginTokenDeletionRequest struct { + Token string +} + +type PerformLoginTokenDeletionResponse struct{} + +type QueryLoginTokenRequest struct { + Token string +} + +type QueryLoginTokenResponse struct { + // Data is nil if the token was invalid. + Data *LoginTokenData +} diff --git a/userapi/api/api_trace_logintoken.go b/userapi/api/api_trace_logintoken.go new file mode 100644 index 000000000..e60dae594 --- /dev/null +++ b/userapi/api/api_trace_logintoken.go @@ -0,0 +1,39 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + + "github.com/matrix-org/util" +) + +func (t *UserInternalAPITrace) PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error { + err := t.Impl.PerformLoginTokenCreation(ctx, req, res) + util.GetLogger(ctx).Infof("PerformLoginTokenCreation req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *UserInternalAPITrace) PerformLoginTokenDeletion(ctx context.Context, req *PerformLoginTokenDeletionRequest, res *PerformLoginTokenDeletionResponse) error { + err := t.Impl.PerformLoginTokenDeletion(ctx, req, res) + util.GetLogger(ctx).Infof("PerformLoginTokenDeletion req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *UserInternalAPITrace) QueryLoginToken(ctx context.Context, req *QueryLoginTokenRequest, res *QueryLoginTokenResponse) error { + err := t.Impl.QueryLoginToken(ctx, req, res) + util.GetLogger(ctx).Infof("QueryLoginToken req=%+v res=%+v", js(req), js(res)) + return err +} diff --git a/userapi/internal/api_logintoken.go b/userapi/internal/api_logintoken.go new file mode 100644 index 000000000..b79ff054c --- /dev/null +++ b/userapi/internal/api_logintoken.go @@ -0,0 +1,75 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// PerformLoginTokenCreation creates a new login token and associates it with the provided data. +func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *api.PerformLoginTokenCreationRequest, res *api.PerformLoginTokenCreationResponse) error { + util.GetLogger(ctx).WithField("user_id", req.Data.UserID).Info("PerformLoginTokenCreation") + _, domain, err := gomatrixserverlib.SplitID('@', req.Data.UserID) + if err != nil { + return err + } + if domain != a.ServerName { + return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName) + } + tokenMeta, err := a.DeviceDB.CreateLoginToken(ctx, &req.Data) + if err != nil { + return err + } + res.Metadata = *tokenMeta + return nil +} + +// PerformLoginTokenDeletion ensures the token doesn't exist. +func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *api.PerformLoginTokenDeletionRequest, res *api.PerformLoginTokenDeletionResponse) error { + util.GetLogger(ctx).Info("PerformLoginTokenDeletion") + return a.DeviceDB.RemoveLoginToken(ctx, req.Token) +} + +// QueryLoginToken returns the data associated with a login token. If +// the token is not valid, success is returned, but res.Data == nil. +func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error { + tokenData, err := a.DeviceDB.GetLoginTokenByToken(ctx, req.Token) + if err != nil { + res.Data = nil + if err == sql.ErrNoRows { + return nil + } + return err + } + localpart, _, err := gomatrixserverlib.SplitID('@', tokenData.UserID) + if err != nil { + return err + } + if _, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart); err != nil { + res.Data = nil + if err == sql.ErrNoRows { + return nil + } + return err + } + res.Data = tokenData + return nil +} diff --git a/userapi/inthttp/client_logintoken.go b/userapi/inthttp/client_logintoken.go new file mode 100644 index 000000000..366a97099 --- /dev/null +++ b/userapi/inthttp/client_logintoken.go @@ -0,0 +1,65 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inthttp + +import ( + "context" + + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/opentracing/opentracing-go" +) + +const ( + PerformLoginTokenCreationPath = "/userapi/performLoginTokenCreation" + PerformLoginTokenDeletionPath = "/userapi/performLoginTokenDeletion" + QueryLoginTokenPath = "/userapi/queryLoginToken" +) + +func (h *httpUserInternalAPI) PerformLoginTokenCreation( + ctx context.Context, + request *api.PerformLoginTokenCreationRequest, + response *api.PerformLoginTokenCreationResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenCreation") + defer span.Finish() + + apiURL := h.apiURL + PerformLoginTokenCreationPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) PerformLoginTokenDeletion( + ctx context.Context, + request *api.PerformLoginTokenDeletionRequest, + response *api.PerformLoginTokenDeletionResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenDeletion") + defer span.Finish() + + apiURL := h.apiURL + PerformLoginTokenDeletionPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) QueryLoginToken( + ctx context.Context, + request *api.QueryLoginTokenRequest, + response *api.QueryLoginTokenResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryLoginToken") + defer span.Finish() + + apiURL := h.apiURL + QueryLoginTokenPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 1c1cfdcd1..a808aea11 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -26,6 +26,8 @@ import ( // nolint: gocyclo func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { + addRoutesLoginToken(internalAPIMux, s) + internalAPIMux.Handle(PerformAccountCreationPath, httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse { request := api.PerformAccountCreationRequest{} diff --git a/userapi/inthttp/server_logintoken.go b/userapi/inthttp/server_logintoken.go new file mode 100644 index 000000000..1f2eb34b9 --- /dev/null +++ b/userapi/inthttp/server_logintoken.go @@ -0,0 +1,68 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inthttp + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +// addRoutesLoginToken adds routes for all login token API calls. +func addRoutesLoginToken(internalAPIMux *mux.Router, s api.UserInternalAPI) { + internalAPIMux.Handle(PerformLoginTokenCreationPath, + httputil.MakeInternalAPI("performLoginTokenCreation", func(req *http.Request) util.JSONResponse { + request := api.PerformLoginTokenCreationRequest{} + response := api.PerformLoginTokenCreationResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformLoginTokenCreation(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(PerformLoginTokenDeletionPath, + httputil.MakeInternalAPI("performLoginTokenDeletion", func(req *http.Request) util.JSONResponse { + request := api.PerformLoginTokenDeletionRequest{} + response := api.PerformLoginTokenDeletionResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformLoginTokenDeletion(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(QueryLoginTokenPath, + httputil.MakeInternalAPI("queryLoginToken", func(req *http.Request) util.JSONResponse { + request := api.QueryLoginTokenRequest{} + response := api.QueryLoginTokenResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryLoginToken(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) +} diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go index 95fe99f33..cb7602f0a 100644 --- a/userapi/storage/devices/interface.go +++ b/userapi/storage/devices/interface.go @@ -38,4 +38,15 @@ type Database interface { RemoveDevices(ctx context.Context, localpart string, devices []string) error // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) + + // CreateLoginToken generates a token, stores and returns it. The lifetime is + // determined by the loginTokenLifetime given to the Database constructor. + CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) + + // RemoveLoginToken removes the named token (and may clean up other expired tokens). + RemoveLoginToken(ctx context.Context, token string) error + + // GetLoginTokenByToken returns the data associated with the given token. + // May return sql.ErrNoRows. + GetLoginTokenByToken(ctx context.Context, token string) (*api.LoginTokenData, error) } diff --git a/userapi/storage/devices/postgres/logintoken_table.go b/userapi/storage/devices/postgres/logintoken_table.go new file mode 100644 index 000000000..b695b5b99 --- /dev/null +++ b/userapi/storage/devices/postgres/logintoken_table.go @@ -0,0 +1,93 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +type loginTokenStatements struct { + insertStmt *sql.Stmt + deleteStmt *sql.Stmt + selectStmt *sql.Stmt +} + +// execSchema ensures tables and indices exist. +func (s *loginTokenStatements) execSchema(db *sql.DB) error { + _, err := db.Exec(` +CREATE TABLE IF NOT EXISTS login_tokens ( + -- The random value of the token issued to a user + token TEXT NOT NULL PRIMARY KEY, + -- When the token expires + token_expires_at TIMESTAMP NOT NULL, + + -- The mxid for this account + user_id TEXT NOT NULL +); + +-- This index allows efficient garbage collection of expired tokens. +CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); +`) + return err +} + +// prepare runs statement preparation. +func (s *loginTokenStatements) prepare(db *sql.DB) error { + return sqlutil.StatementList{ + {&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, + {&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"}, + {&s.selectStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"}, + }.Prepare(db) +} + +// insert adds an already generated token to the database. +func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { + stmt := sqlutil.TxStmt(txn, s.insertStmt) + _, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID) + return err +} + +// deleteByToken removes the named token. +// +// As a simple way to garbage-collect stale tokens, we also remove all expired tokens. +// The login_tokens_expiration_idx index should make that efficient. +func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error { + stmt := sqlutil.TxStmt(txn, s.deleteStmt) + res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) + if err != nil { + return err + } + if n, err := res.RowsAffected(); err == nil && n > 1 { + util.GetLogger(ctx).WithField("num_deleted", n).Infof("Deleted %d login tokens (%d likely garbage collected)", n, n-1) + } + return nil +} + +// selectByToken returns the data associated with the given token. May return sql.ErrNoRows. +func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + var data api.LoginTokenData + err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) + if err != nil { + return nil, err + } + + return &data, nil +} diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go index 485234331..881788066 100644 --- a/userapi/storage/devices/postgres/storage.go +++ b/userapi/storage/devices/postgres/storage.go @@ -19,6 +19,7 @@ import ( "crypto/rand" "database/sql" "encoding/base64" + "time" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" @@ -27,28 +28,38 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) -// The length of generated device IDs -var deviceIDByteLength = 6 +const ( + // The length of generated device IDs + deviceIDByteLength = 6 + loginTokenByteLength = 32 +) // Database represents a device database. type Database struct { - db *sql.DB - devices devicesStatements + db *sql.DB + devices devicesStatements + loginTokens loginTokenStatements + loginTokenLifetime time.Duration } // NewDatabase creates a new device database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) { db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err } - d := devicesStatements{} + var d devicesStatements + var lt loginTokenStatements // Create tables before executing migrations so we don't fail if the table is missing, // and THEN prepare statements so we don't fail due to referencing new columns if err = d.execSchema(db); err != nil { return nil, err } + if err = lt.execSchema(db); err != nil { + return nil, err + } + m := sqlutil.NewMigrations() deltas.LoadLastSeenTSIP(m) if err = m.RunDeltas(db, dbProperties); err != nil { @@ -58,8 +69,11 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.prepare(db, serverName); err != nil { return nil, err } + if err = lt.prepare(db); err != nil { + return nil, err + } - return &Database{db, d}, nil + return &Database{db, d, lt, loginTokenLifetime}, nil } // GetDeviceByAccessToken returns the device matching the given access token. @@ -210,3 +224,47 @@ func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) }) } + +// CreateLoginToken generates a token, stores and returns it. The lifetime is +// determined by the loginTokenLifetime given to the Database constructor. +func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { + tok, err := generateLoginToken() + if err != nil { + return nil, err + } + meta := &api.LoginTokenMetadata{ + Token: tok, + Expiration: time.Now().Add(d.loginTokenLifetime), + } + + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.loginTokens.insert(ctx, txn, meta, data) + }) + if err != nil { + return nil, err + } + + return meta, nil +} + +func generateLoginToken() (string, error) { + b := make([]byte, loginTokenByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// RemoveLoginToken removes the named token (and may clean up other expired tokens). +func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.loginTokens.deleteByToken(ctx, txn, token) + }) +} + +// GetLoginTokenByToken returns the data associated with the given token. +// May return sql.ErrNoRows. +func (d *Database) GetLoginTokenByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + return d.loginTokens.selectByToken(ctx, token) +} diff --git a/userapi/storage/devices/sqlite3/logintoken_table.go b/userapi/storage/devices/sqlite3/logintoken_table.go new file mode 100644 index 000000000..96f4ee416 --- /dev/null +++ b/userapi/storage/devices/sqlite3/logintoken_table.go @@ -0,0 +1,93 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +type loginTokenStatements struct { + insertStmt *sql.Stmt + deleteStmt *sql.Stmt + selectStmt *sql.Stmt +} + +// execSchema ensures tables and indices exist. +func (s *loginTokenStatements) execSchema(db *sql.DB) error { + _, err := db.Exec(` +CREATE TABLE IF NOT EXISTS login_tokens ( + -- The random value of the token issued to a user + token TEXT NOT NULL PRIMARY KEY, + -- When the token expires + token_expires_at TIMESTAMP NOT NULL, + + -- The mxid for this account + user_id TEXT NOT NULL +); + +-- This index allows efficient garbage collection of expired tokens. +CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); +`) + return err +} + +// prepare runs statement preparation. +func (s *loginTokenStatements) prepare(db *sql.DB) error { + return sqlutil.StatementList{ + {&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, + {&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"}, + {&s.selectStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"}, + }.Prepare(db) +} + +// insert adds an already generated token to the database. +func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { + stmt := sqlutil.TxStmt(txn, s.insertStmt) + _, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID) + return err +} + +// deleteByToken removes the named token. +// +// As a simple way to garbage-collect stale tokens, we also remove all expired tokens. +// The login_tokens_expiration_idx index should make that efficient. +func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error { + stmt := sqlutil.TxStmt(txn, s.deleteStmt) + res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) + if err != nil { + return err + } + if n, err := res.RowsAffected(); err == nil && n > 1 { + util.GetLogger(ctx).WithField("num_deleted", n).Infof("Deleted %d login tokens (%d likely garbage collected)", n, n-1) + } + return nil +} + +// selectByToken returns the data associated with the given token. May return sql.ErrNoRows. +func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + var data api.LoginTokenData + err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) + if err != nil { + return nil, err + } + + return &data, nil +} diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go index 538644837..7f76166b9 100644 --- a/userapi/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -19,6 +19,7 @@ import ( "crypto/rand" "database/sql" "encoding/base64" + "time" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" @@ -27,30 +28,41 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) -// The length of generated device IDs -var deviceIDByteLength = 6 +const ( + // The length of generated device IDs + deviceIDByteLength = 6 + + loginTokenByteLength = 32 +) // Database represents a device database. type Database struct { - db *sql.DB - writer sqlutil.Writer - devices devicesStatements + db *sql.DB + writer sqlutil.Writer + devices devicesStatements + loginTokens loginTokenStatements + loginTokenLifetime time.Duration } // NewDatabase creates a new device database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) { db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err } writer := sqlutil.NewExclusiveWriter() - d := devicesStatements{} + var d devicesStatements + var lt loginTokenStatements // Create tables before executing migrations so we don't fail if the table is missing, // and THEN prepare statements so we don't fail due to referencing new columns if err = d.execSchema(db); err != nil { return nil, err } + if err = lt.execSchema(db); err != nil { + return nil, err + } + m := sqlutil.NewMigrations() deltas.LoadLastSeenTSIP(m) if err = m.RunDeltas(db, dbProperties); err != nil { @@ -59,7 +71,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.prepare(db, writer, serverName); err != nil { return nil, err } - return &Database{db, writer, d}, nil + if err = lt.prepare(db); err != nil { + return nil, err + } + return &Database{db, writer, d, lt, loginTokenLifetime}, nil } // GetDeviceByAccessToken returns the device matching the given access token. @@ -210,3 +225,47 @@ func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) }) } + +// CreateLoginToken generates a token, stores and returns it. The lifetime is +// determined by the loginTokenLifetime given to the Database constructor. +func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { + tok, err := generateLoginToken() + if err != nil { + return nil, err + } + meta := &api.LoginTokenMetadata{ + Token: tok, + Expiration: time.Now().Add(d.loginTokenLifetime), + } + + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.loginTokens.insert(ctx, txn, meta, data) + }) + if err != nil { + return nil, err + } + + return meta, nil +} + +func generateLoginToken() (string, error) { + b := make([]byte, loginTokenByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// RemoveLoginToken removes the named token (and may clean up other expired tokens). +func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.loginTokens.deleteByToken(ctx, txn, token) + }) +} + +// GetLoginTokenByToken returns the data associated with the given token. +// May return sql.ErrNoRows. +func (d *Database) GetLoginTokenByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + return d.loginTokens.selectByToken(ctx, token) +} diff --git a/userapi/storage/devices/storage.go b/userapi/storage/devices/storage.go index 3c2034300..15cf8150c 100644 --- a/userapi/storage/devices/storage.go +++ b/userapi/storage/devices/storage.go @@ -19,6 +19,7 @@ package devices import ( "fmt" + "time" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/devices/postgres" @@ -27,13 +28,14 @@ import ( ) // NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) -// and sets postgres connection parameters -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (Database, error) { +// and sets postgres connection parameters. loginTokenLifetime determines how long a +// login token from CreateLoginToken is valid. +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName) + return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(dbProperties, serverName) + return postgres.NewDatabase(dbProperties, serverName, loginTokenLifetime) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/userapi/storage/devices/storage_wasm.go b/userapi/storage/devices/storage_wasm.go index f360f9857..3de7880b9 100644 --- a/userapi/storage/devices/storage_wasm.go +++ b/userapi/storage/devices/storage_wasm.go @@ -16,6 +16,7 @@ package devices import ( "fmt" + "time" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3" @@ -25,10 +26,11 @@ import ( func NewDatabase( dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, + loginTokenLifetime time.Duration, ) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName) + return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/userapi/userapi.go b/userapi/userapi.go index 74702020a..c7e1f6674 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -15,6 +15,8 @@ package userapi import ( + "time" + "github.com/gorilla/mux" keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" @@ -26,6 +28,13 @@ import ( "github.com/sirupsen/logrus" ) +// defaultLoginTokenLifetime determines how old a valid token may be. +// +// NOTSPEC: The current spec says "SHOULD be limited to around five +// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low. +// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325). +const defaultLoginTokenLifetime = 2 * time.Minute + // AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions // on the given input API. func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { @@ -37,11 +46,21 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { func NewInternalAPI( accountDB accounts.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, ) api.UserInternalAPI { - deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName) + deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName, defaultLoginTokenLifetime) if err != nil { logrus.WithError(err).Panicf("failed to connect to device db") } + return newInternalAPI(accountDB, deviceDB, cfg, appServices, keyAPI) +} + +func newInternalAPI( + accountDB accounts.Database, + deviceDB devices.Database, + cfg *config.UserAPI, + appServices []config.ApplicationService, + keyAPI keyapi.KeyInternalAPI, +) api.UserInternalAPI { return &internal.UserInternalAPI{ AccountDB: accountDB, DeviceDB: deviceDB, diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 0141258e6..266f5ed58 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -1,4 +1,18 @@ -package userapi_test +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package userapi import ( "context" @@ -6,15 +20,16 @@ import ( "net/http" "reflect" "testing" + "time" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" "golang.org/x/crypto/bcrypt" ) @@ -23,31 +38,41 @@ const ( serverName = gomatrixserverlib.ServerName("example.com") ) -func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database) { - accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{ - ConnectionString: "file::memory:", - }, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS) +type apiTestOpts struct { + loginTokenLifetime time.Duration +} + +func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, accounts.Database) { + if opts.loginTokenLifetime == 0 { + opts.loginTokenLifetime = defaultLoginTokenLifetime + } + dbopts := &config.DatabaseOptions{ + ConnectionString: "file::memory:", + MaxOpenConnections: 1, + MaxIdleConnections: 1, + } + accountDB, err := accounts.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS) if err != nil { t.Fatalf("failed to create account DB: %s", err) } + deviceDB, err := devices.NewDatabase(dbopts, serverName, opts.loginTokenLifetime) + if err != nil { + t.Fatalf("failed to create device DB: %s", err) + } + cfg := &config.UserAPI{ - DeviceDatabase: config.DatabaseOptions{ - ConnectionString: "file::memory:", - MaxOpenConnections: 1, - MaxIdleConnections: 1, - }, Matrix: &config.Global{ ServerName: serverName, }, } - return userapi.NewInternalAPI(accountDB, cfg, nil, nil), accountDB + return newInternalAPI(accountDB, deviceDB, cfg, nil, nil), accountDB } func TestQueryProfile(t *testing.T) { aliceAvatarURL := "mxc://example.com/alice" aliceDisplayName := "Alice" - userAPI, accountDB := MustMakeInternalAPI(t) + userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{}) _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "") if err != nil { t.Fatalf("failed to make account: %s", err) @@ -106,7 +131,7 @@ func TestQueryProfile(t *testing.T) { t.Run("HTTP API", func(t *testing.T) { router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() - userapi.AddInternalRoutes(router, userAPI) + AddInternalRoutes(router, userAPI) apiURL, cancel := test.ListenAndServe(t, router, false) defer cancel() httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{}) @@ -119,3 +144,115 @@ func TestQueryProfile(t *testing.T) { runCases(userAPI) }) } + +func TestLoginToken(t *testing.T) { + ctx := context.Background() + + t.Run("tokenLoginFlow", func(t *testing.T) { + userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{}) + + _, err := accountDB.CreateAccount(ctx, "auser", "apassword", "") + if err != nil { + t.Fatalf("failed to make account: %s", err) + } + + t.Log("Creating a login token like the SSO callback would...") + + creq := api.PerformLoginTokenCreationRequest{ + Data: api.LoginTokenData{UserID: "@auser:example.com"}, + } + var cresp api.PerformLoginTokenCreationResponse + if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil { + t.Fatalf("PerformLoginTokenCreation failed: %v", err) + } + + if cresp.Metadata.Token == "" { + t.Errorf("PerformLoginTokenCreation Token: got %q, want non-empty", cresp.Metadata.Token) + } + if cresp.Metadata.Expiration.Before(time.Now()) { + t.Errorf("PerformLoginTokenCreation Expiration: got %v, want non-expired", cresp.Metadata.Expiration) + } + + t.Log("Querying the login token like /login with m.login.token would...") + + qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token} + var qresp api.QueryLoginTokenResponse + if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil { + t.Fatalf("QueryLoginToken failed: %v", err) + } + + if qresp.Data == nil { + t.Errorf("QueryLoginToken Data: got %v, want non-nil", qresp.Data) + } else if want := "@auser:example.com"; qresp.Data.UserID != want { + t.Errorf("QueryLoginToken UserID: got %q, want %q", qresp.Data.UserID, want) + } + + t.Log("Deleting the login token like /login with m.login.token would...") + + dreq := api.PerformLoginTokenDeletionRequest{Token: cresp.Metadata.Token} + var dresp api.PerformLoginTokenDeletionResponse + if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil { + t.Fatalf("PerformLoginTokenDeletion failed: %v", err) + } + }) + + t.Run("expiredTokenIsNotReturned", func(t *testing.T) { + userAPI, _ := MustMakeInternalAPI(t, apiTestOpts{loginTokenLifetime: -1 * time.Second}) + + creq := api.PerformLoginTokenCreationRequest{ + Data: api.LoginTokenData{UserID: "@auser:example.com"}, + } + var cresp api.PerformLoginTokenCreationResponse + if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil { + t.Fatalf("PerformLoginTokenCreation failed: %v", err) + } + + qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token} + var qresp api.QueryLoginTokenResponse + if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil { + t.Fatalf("QueryLoginToken failed: %v", err) + } + + if qresp.Data != nil { + t.Errorf("QueryLoginToken Data: got %v, want nil", qresp.Data) + } + }) + + t.Run("deleteWorks", func(t *testing.T) { + userAPI, _ := MustMakeInternalAPI(t, apiTestOpts{}) + + creq := api.PerformLoginTokenCreationRequest{ + Data: api.LoginTokenData{UserID: "@auser:example.com"}, + } + var cresp api.PerformLoginTokenCreationResponse + if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil { + t.Fatalf("PerformLoginTokenCreation failed: %v", err) + } + + dreq := api.PerformLoginTokenDeletionRequest{Token: cresp.Metadata.Token} + var dresp api.PerformLoginTokenDeletionResponse + if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil { + t.Fatalf("PerformLoginTokenDeletion failed: %v", err) + } + + qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token} + var qresp api.QueryLoginTokenResponse + if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil { + t.Fatalf("QueryLoginToken failed: %v", err) + } + + if qresp.Data != nil { + t.Errorf("QueryLoginToken Data: got %v, want nil", qresp.Data) + } + }) + + t.Run("deleteUnknownIsNoOp", func(t *testing.T) { + userAPI, _ := MustMakeInternalAPI(t, apiTestOpts{}) + + dreq := api.PerformLoginTokenDeletionRequest{Token: "non-existent token"} + var dresp api.PerformLoginTokenDeletionResponse + if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil { + t.Fatalf("PerformLoginTokenDeletion failed: %v", err) + } + }) +}