Implement OpenID module (#599)

- Unrelated: change Riot references to Element in client API routing

Signed-off-by: Bruce MacDonald <contact@bruce-macdonald.com>
This commit is contained in:
Bruce MacDonald 2021-03-21 18:40:38 -07:00
parent 01267a34b9
commit 4a88a99dbc
16 changed files with 520 additions and 3 deletions

View file

@ -0,0 +1,70 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
)
type openIDTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
MatrixServerName string `json:"matrix_server_name"`
ExpiresIn int64 `json:"expires_in"`
}
// CreateOpenIDToken creates a new OpenID Connect (OIDC) token that a Matrix user
// can supply to an OpenID Relying Party to verify their identity
func CreateOpenIDToken(
req *http.Request,
userAPI api.UserInternalAPI,
device *api.Device,
userID string,
cfg *config.ClientAPI,
) util.JSONResponse {
// does the incoming user ID match the user that the token was issued for?
if userID != device.UserID {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Cannot request tokens for other users"),
}
}
request := api.PerformOpenIDTokenCreationRequest{
UserID: userID, // this is the user ID from the incoming path
}
response := api.PerformOpenIDTokenCreationResponse{}
err := userAPI.PerformOpenIDTokenCreation(req.Context(), &request, &response)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.CreateOpenIDToken failed")
return jsonerror.InternalServerError()
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: openIDTokenResponse{
AccessToken: response.Token.Token,
TokenType: "Bearer",
MatrixServerName: string(cfg.Matrix.ServerName),
ExpiresIn: response.Token.ExpiresTS / 1000, // convert ms to s
},
}
}

View file

@ -469,7 +469,7 @@ func Setup(
}),
).Methods(http.MethodPost, http.MethodOptions)
// Stub endpoints required by Riot
// Stub endpoints required by Element
r0mux.Handle("/login",
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
@ -506,7 +506,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
// Riot user settings
// Element user settings
r0mux.Handle("/profile/{userID}",
httputil.MakeExternalAPI("profile", func(req *http.Request) util.JSONResponse {
@ -592,7 +592,7 @@ func Setup(
}),
).Methods(http.MethodPost, http.MethodOptions)
// Riot logs get flooded unless this is handled
// Element logs get flooded unless this is handled
r0mux.Handle("/presence/{userID}/status",
httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse {
if r := rateLimits.rateLimit(req); r != nil {
@ -685,6 +685,19 @@ func Setup(
}),
).Methods(http.MethodGet)
r0mux.Handle("/user/{userId}/openid/request_token",
httputil.MakeAuthAPI("openid_request_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.rateLimit(req); r != nil {
return *r
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return CreateOpenIDToken(req, userAPI, device, vars["userId"], cfg)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/user_directory/search",
httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.rateLimit(req); r != nil {

View file

@ -0,0 +1,65 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"net/http"
"time"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
)
type openIDUserInfoResponse struct {
Sub string `json:"sub"`
}
// GetOpenIDUserInfo implements GET /_matrix/federation/v1/openid/userinfo
func GetOpenIDUserInfo(
httpReq *http.Request,
userAPI userapi.UserInternalAPI,
) util.JSONResponse {
token := httpReq.URL.Query().Get("access_token")
if len(token) == 0 {
return util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: jsonerror.MissingArgument("access_token is missing"),
}
}
req := userapi.QueryOpenIDTokenRequest{
Token: token,
}
var openIDTokenAttrResponse userapi.QueryOpenIDTokenResponse
err := userAPI.QueryOpenIDToken(httpReq.Context(), &req, &openIDTokenAttrResponse)
if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("userAPI.QueryOpenIDToken failed")
}
var res interface{} = openIDUserInfoResponse{Sub: openIDTokenAttrResponse.Sub}
code := http.StatusOK
nowMS := time.Now().UnixNano() / int64(time.Millisecond)
if openIDTokenAttrResponse.Sub == "" || nowMS > openIDTokenAttrResponse.ExpiresTS {
code = http.StatusUnauthorized
res = jsonerror.UnknownToken("Access Token unknown or expired")
}
return util.JSONResponse{
Code: code,
JSON: res,
}
}

View file

@ -460,4 +460,10 @@ func Setup(
return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName)
},
)).Methods(http.MethodPost)
v1fedmux.Handle("/openid/userinfo",
httputil.MakeExternalAPI("federation_openid_userinfo", func(req *http.Request) util.JSONResponse {
return GetOpenIDUserInfo(req, userAPI)
}),
).Methods(http.MethodGet)
}

View file

@ -524,6 +524,9 @@ func (u *testUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.Pe
func (u *testUserAPI) PerformAccountDeactivation(ctx context.Context, req *userapi.PerformAccountDeactivationRequest, res *userapi.PerformAccountDeactivationResponse) error {
return nil
}
func (u *testUserAPI) PerformOpenIDTokenCreation(ctx context.Context, req *userapi.PerformOpenIDTokenCreationRequest, res *userapi.PerformOpenIDTokenCreationResponse) error {
return nil
}
func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error {
return nil
}
@ -548,6 +551,9 @@ func (u *testUserAPI) QueryDeviceInfos(ctx context.Context, req *userapi.QueryDe
func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.QuerySearchProfilesRequest, res *userapi.QuerySearchProfilesResponse) error {
return nil
}
func (u *testUserAPI) QueryOpenIDToken(ctx context.Context, req *userapi.QueryOpenIDTokenRequest, res *userapi.QueryOpenIDTokenResponse) error {
return nil
}
type testRoomserverAPI struct {
// use a trace API as it implements method stubs so we don't need to have them here.

View file

@ -367,6 +367,9 @@ func (u *testUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.Pe
func (u *testUserAPI) PerformAccountDeactivation(ctx context.Context, req *userapi.PerformAccountDeactivationRequest, res *userapi.PerformAccountDeactivationResponse) error {
return nil
}
func (u *testUserAPI) PerformOpenIDTokenCreation(ctx context.Context, req *userapi.PerformOpenIDTokenCreationRequest, res *userapi.PerformOpenIDTokenCreationResponse) error {
return nil
}
func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error {
return nil
}
@ -391,6 +394,9 @@ func (u *testUserAPI) QueryDeviceInfos(ctx context.Context, req *userapi.QueryDe
func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.QuerySearchProfilesRequest, res *userapi.QuerySearchProfilesResponse) error {
return nil
}
func (u *testUserAPI) QueryOpenIDToken(ctx context.Context, req *userapi.QueryOpenIDTokenRequest, res *userapi.QueryOpenIDTokenResponse) error {
return nil
}
type testRoomserverAPI struct {
// use a trace API as it implements method stubs so we don't need to have them here.

View file

@ -517,3 +517,6 @@ AS can set avatar for ghosted users
AS can set displayname for ghosted users
Ghost user must register before joining room
Inviting an AS-hosted user asks the AS server
Can generate a openid access_token that can be exchanged for information about a user
Invalid openid access tokens are rejected
Requests to userinfo without access tokens are rejected

View file

@ -32,12 +32,14 @@ type UserInternalAPI interface {
PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error
QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error
}
// InputAccountDataRequest is the request for InputAccountData
@ -226,6 +228,27 @@ type PerformAccountDeactivationResponse struct {
AccountDeactivated bool
}
// PerformOpenIDTokenCreationRequest is the request for PerformOpenIDTokenCreation
type PerformOpenIDTokenCreationRequest struct {
UserID string
}
// PerformOpenIDTokenCreationResponse is the response for PerformOpenIDTokenCreation
type PerformOpenIDTokenCreationResponse struct {
Token OpenIDToken
}
// QueryOpenIDTokenRequest is the request for QueryOpenIDToken
type QueryOpenIDTokenRequest struct {
Token string
}
// QueryOpenIDTokenResponse is the response for QueryOpenIDToken
type QueryOpenIDTokenResponse struct {
Sub string // The Matrix User ID that generated the token
ExpiresTS int64
}
// Device represents a client's device (mobile, web, etc)
type Device struct {
ID string
@ -256,6 +279,24 @@ type Account struct {
// TODO: Associations (e.g. with application services)
}
// OpenIDToken represents an OpenID token
type OpenIDToken struct {
Token string
UserID string
ExpiresTS int64
}
// OpenIDTokenInfo represents the attributes associated with an issued OpenID token
type OpenIDTokenAttributes struct {
UserID string
ExpiresTS int64
}
// UserInfo is for returning information about the user an OpenID token was issued for
type UserInfo struct {
Sub string // The Matrix user's ID who generated the token
}
// ErrorForbidden is an error indicating that the supplied access token is forbidden
type ErrorForbidden struct {
Message string

View file

@ -20,6 +20,9 @@ import (
"encoding/json"
"errors"
"fmt"
"hash"
"hash/fnv"
"time"
"github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/clientapi/userutil"
@ -414,3 +417,42 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
res.AccountDeactivated = err == nil
return err
}
// PerformOpenIDTokenCreation creates a new token that a relying party uses to authenticate a user
func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error {
token := util.RandomString(24)
tokenHash := getTokenHash(token)
nowMS := time.Now().UnixNano() / int64(time.Millisecond)
expiresMS := nowMS + (3600 * 1000) // 60 minutes
err := a.AccountDB.CreateOpenIDToken(ctx, fmt.Sprint(tokenHash.Sum32()), req.UserID, expiresMS)
res.Token = api.OpenIDToken{
Token: token,
UserID: req.UserID,
ExpiresTS: expiresMS,
}
return err
}
// QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation
func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
tokenHash := getTokenHash(req.Token)
openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, fmt.Sprint(tokenHash.Sum32()))
if err != nil {
return err
}
res.Sub = openIDTokenAttrs.UserID
res.ExpiresTS = openIDTokenAttrs.ExpiresTS
return nil
}
// used for getting the format in which the token is stored to prevent plaintext storage of sensitive info
func getTokenHash(token string) hash.Hash32 {
tokenHash := fnv.New32a()
_, _ = tokenHash.Write([]byte(token))
return tokenHash
}

View file

@ -35,6 +35,7 @@ const (
PerformLastSeenUpdatePath = "/userapi/performLastSeenUpdate"
PerformDeviceUpdatePath = "/userapi/performDeviceUpdate"
PerformAccountDeactivationPath = "/userapi/performAccountDeactivation"
PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation"
QueryProfilePath = "/userapi/queryProfile"
QueryAccessTokenPath = "/userapi/queryAccessToken"
@ -42,6 +43,7 @@ const (
QueryAccountDataPath = "/userapi/queryAccountData"
QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
QuerySearchProfilesPath = "/userapi/querySearchProfiles"
QueryOpenIDTokenPath = "/userapi/queryOpenIDToken"
)
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
@ -148,6 +150,14 @@ func (h *httpUserInternalAPI) PerformAccountDeactivation(ctx context.Context, re
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, request *api.PerformOpenIDTokenCreationRequest, response *api.PerformOpenIDTokenCreationResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformOpenIDTokenCreation")
defer span.Finish()
apiURL := h.apiURL + PerformOpenIDTokenCreationPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) QueryProfile(
ctx context.Context,
request *api.QueryProfileRequest,
@ -207,3 +217,11 @@ func (h *httpUserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.
apiURL := h.apiURL + QuerySearchProfilesPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOpenIDToken")
defer span.Finish()
apiURL := h.apiURL + QueryOpenIDTokenPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}

View file

@ -117,6 +117,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformOpenIDTokenCreationPath,
httputil.MakeInternalAPI("performOpenIDTokenCreation", func(req *http.Request) util.JSONResponse {
request := api.PerformOpenIDTokenCreationRequest{}
response := api.PerformOpenIDTokenCreationResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformOpenIDTokenCreation(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryProfilePath,
httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse {
request := api.QueryProfileRequest{}
@ -195,6 +208,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryOpenIDTokenPath,
httputil.MakeInternalAPI("queryOpenIDToken", func(req *http.Request) util.JSONResponse {
request := api.QueryOpenIDTokenRequest{}
response := api.QueryOpenIDTokenResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryOpenIDToken(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(InputAccountDataPath,
httputil.MakeInternalAPI("inputAccountDataPath", func(req *http.Request) util.JSONResponse {
request := api.InputAccountDataRequest{}

View file

@ -52,6 +52,8 @@ type Database interface {
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
DeactivateAccount(ctx context.Context, localpart string) (err error)
CreateOpenIDToken(ctx context.Context, tokenHash, localpart string, expirationTS int64) (err error)
GetOpenIDTokenAttributes(ctx context.Context, tokenHash string) (*api.OpenIDTokenAttributes, error)
}
// Err3PIDInUse is the error returned when trying to save an association involving

View file

@ -0,0 +1,84 @@
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
const openIDTokenSchema = `
-- Stores data about openid tokens issued for accounts.
CREATE TABLE IF NOT EXISTS open_id_tokens (
-- This is a hash the token value
token_hash TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account
localpart TEXT NOT NULL,
-- When the token expires, as a unix timestamp (ms resolution).
token_expires_ts BIGINT NOT NULL
);
`
const insertTokenSQL = "" +
"INSERT INTO open_id_tokens(token_hash, localpart, token_expires_ts) VALUES ($1, $2, $3)"
const selectTokenSQL = "" +
"SELECT localpart, token_expires_ts FROM open_id_tokens WHERE token_hash = $1"
type tokenStatements struct {
insertTokenStmt *sql.Stmt
selectTokenStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
_, err = db.Exec(openIDTokenSchema)
if err != nil {
return
}
if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil {
return
}
if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil {
return
}
s.serverName = server
return
}
// insertToken inserts a new OpenID Connect token to the DB.
// Returns new token, otherwise returns error if token hash already exists.
func (s *tokenStatements) insertToken(
ctx context.Context,
txn *sql.Tx,
tokenHash, localpart string,
expiresTimeMS int64,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
_, err = stmt.ExecContext(ctx, tokenHash, localpart, expiresTimeMS)
return
}
// selectOpenIDTokenAtrributesByTokenHash gets the attributes associated with an OpenID token from the DB
// Returns the existing token's attributes, or err if no token is found
func (s *tokenStatements) selectOpenIDTokenAtrributesByTokenHash(
ctx context.Context,
tokenHash string,
) (*api.OpenIDTokenAttributes, error) {
var openIDTokenAttrs api.OpenIDTokenAttributes
err := s.selectTokenStmt.QueryRowContext(ctx, tokenHash).Scan(
&openIDTokenAttrs.UserID,
&openIDTokenAttrs.ExpiresTS,
)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve token from the db")
}
return nil, err
}
return &openIDTokenAttrs, nil
}

View file

@ -43,6 +43,7 @@ type Database struct {
profiles profilesStatements
accountDatas accountDataStatements
threepids threepidStatements
openIDTokens tokenStatements
serverName gomatrixserverlib.ServerName
bcryptCost int
}
@ -86,6 +87,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.threepids.prepare(db); err != nil {
return nil, err
}
if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err
}
return d, nil
}
@ -341,3 +345,22 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
return d.accounts.deactivateAccount(ctx, localpart)
}
// CreateOpenIDToken persists a new token that was issued through OpenID Connect
func (d *Database) CreateOpenIDToken(
ctx context.Context,
tokenHash, localpart string,
expirationTS int64,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, tokenHash, localpart, expirationTS)
})
}
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
func (d *Database) GetOpenIDTokenAttributes(
ctx context.Context,
tokenHash string,
) (*api.OpenIDTokenAttributes, error) {
return d.openIDTokens.selectOpenIDTokenAtrributesByTokenHash(ctx, tokenHash)
}

View file

@ -0,0 +1,86 @@
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
const openIDTokenSchema = `
-- Stores data about accounts.
CREATE TABLE IF NOT EXISTS open_id_tokens (
-- This is the hash of the token value
token_hash TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account
localpart TEXT NOT NULL,
-- When the token expires, as a unix timestamp (ms resolution).
token_expires_ts BIGINT NOT NULL
);
`
const insertTokenSQL = "" +
"INSERT INTO open_id_tokens(token_hash, localpart, token_expires_ts) VALUES ($1, $2, $3)"
const selectTokenSQL = "" +
"SELECT localpart, token_expires_ts FROM open_id_tokens WHERE token_hash = $1"
type tokenStatements struct {
db *sql.DB
insertTokenStmt *sql.Stmt
selectTokenStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.db = db
_, err = db.Exec(openIDTokenSchema)
if err != nil {
return err
}
if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil {
return
}
if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil {
return
}
s.serverName = server
return
}
// insertToken inserts a new OpenID Connect token to the DB.
// Returns new token, otherwise returns error if token hash already exists.
func (s *tokenStatements) insertToken(
ctx context.Context,
txn *sql.Tx,
token_hash, localpart string,
expiresTimeMS int64,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
_, err = stmt.ExecContext(ctx, token_hash, localpart, expiresTimeMS)
return
}
// selectOpenIDTokenAtrributesByTokenHash gets the attributes associated with an OpenID token from the DB
// Returns the existing token's attributes, or err if no token is found
func (s *tokenStatements) selectOpenIDTokenAtrributesByTokenHash(
ctx context.Context,
tokenHash string,
) (*api.OpenIDTokenAttributes, error) {
var openIDTokenAttrs api.OpenIDTokenAttributes
err := s.selectTokenStmt.QueryRowContext(ctx, tokenHash).Scan(
&openIDTokenAttrs.UserID,
&openIDTokenAttrs.ExpiresTS,
)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve token from the db")
}
return nil, err
}
return &openIDTokenAttrs, nil
}

View file

@ -41,6 +41,7 @@ type Database struct {
profiles profilesStatements
accountDatas accountDataStatements
threepids threepidStatements
openIDTokens tokenStatements
serverName gomatrixserverlib.ServerName
bcryptCost int
@ -48,6 +49,7 @@ type Database struct {
profilesMu sync.Mutex
accountDatasMu sync.Mutex
threepidsMu sync.Mutex
openIDsMu sync.Mutex
}
// NewDatabase creates a new accounts and profiles database
@ -90,6 +92,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.threepids.prepare(db); err != nil {
return nil, err
}
if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err
}
return d, nil
}
@ -379,3 +384,24 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
return d.accounts.deactivateAccount(ctx, localpart)
}
// CreateOpenIDToken persists a new token that was issued for OpenID Connect
func (d *Database) CreateOpenIDToken(
ctx context.Context,
tokenHash, localpart string,
expirationTS int64,
) error {
d.openIDsMu.Lock()
defer d.openIDsMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, tokenHash, localpart, expirationTS)
})
}
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
func (d *Database) GetOpenIDTokenAttributes(
ctx context.Context,
tokenHash string,
) (*api.OpenIDTokenAttributes, error) {
return d.openIDTokens.selectOpenIDTokenAtrributesByTokenHash(ctx, tokenHash)
}