mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-26 00:03:09 -06:00
feat pt 1: db working, controller not yet
This commit is contained in:
parent
35ea55e70b
commit
6c9a8b96dd
7
clientapi/openid/openid.go
Normal file
7
clientapi/openid/openid.go
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
package openid
|
||||||
|
|
||||||
|
// TokenRequest represents the request defined at https://matrix.org/docs/spec/client_server/r0.6.1#id603
|
||||||
|
type TokenRequest struct {
|
||||||
|
UserID string `json:"userId"`
|
||||||
|
RelyingParty string `json:"relyingParty"`
|
||||||
|
}
|
||||||
54
clientapi/routing/openid.go
Normal file
54
clientapi/routing/openid.go
Normal file
|
|
@ -0,0 +1,54 @@
|
||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
type openIDTokenResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
MatrixServerName string `json:"matrix_server_name"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateOpenIDToken creates a new OpenID Connect token that a Matrix user
|
||||||
|
// can supply to an OpenID Relying Party
|
||||||
|
func CreateOpenIDToken(
|
||||||
|
req *http.Request,
|
||||||
|
userAPI api.UserInternalAPI,
|
||||||
|
device *api.Device,
|
||||||
|
userID, relyingParty string,
|
||||||
|
cfg *config.ClientAPI,
|
||||||
|
) util.JSONResponse {
|
||||||
|
if userID != device.UserID {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
JSON: jsonerror.Forbidden("userID does not match the current user"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
request := api.PerformOpenIDTokenCreationRequest{
|
||||||
|
UserID: userID,
|
||||||
|
RelyingParty: relyingParty}
|
||||||
|
response := api.PerformOpenIDTokenCreationResponse{}
|
||||||
|
|
||||||
|
err := userAPI.PerformOpenIDTokenCreation(req.Context(), &request, &response)
|
||||||
|
if err != nil {
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusOK,
|
||||||
|
JSON: openIDTokenResponse{
|
||||||
|
AccessToken: response.Token.Token,
|
||||||
|
TokenType: "Bearer",
|
||||||
|
MatrixServerName: string(cfg.Matrix.ServerName),
|
||||||
|
ExpiresIn: response.Token.ExpiresTS / 1000, // convert ms to s
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -661,6 +661,19 @@ func Setup(
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet)
|
).Methods(http.MethodGet)
|
||||||
|
|
||||||
|
r0mux.Handle("/user/{userId}/openid/request_token",
|
||||||
|
httputil.MakeAuthAPI("openid_request_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
if r := rateLimits.rateLimit(req); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return CreateOpenIDToken(req, userAPI, device, vars["userId"], vars["relyingParty"], cfg)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
r0mux.Handle("/user_directory/search",
|
r0mux.Handle("/user_directory/search",
|
||||||
httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
if r := rateLimits.rateLimit(req); r != nil {
|
if r := rateLimits.rateLimit(req); r != nil {
|
||||||
|
|
|
||||||
14
federationapi/routing/openid.go
Normal file
14
federationapi/routing/openid.go
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
package routing
|
||||||
|
|
||||||
|
/*
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetInvite(
|
||||||
|
httpReq *http.Request,
|
||||||
|
) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
@ -31,12 +31,14 @@ type UserInternalAPI interface {
|
||||||
PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error
|
PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error
|
||||||
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
|
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
|
||||||
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
|
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
|
||||||
|
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
|
||||||
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
|
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
|
||||||
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
|
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
|
||||||
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
||||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||||
QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
|
QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
|
||||||
QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error
|
QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error
|
||||||
|
QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// InputAccountDataRequest is the request for InputAccountData
|
// InputAccountDataRequest is the request for InputAccountData
|
||||||
|
|
@ -214,6 +216,27 @@ type PerformAccountDeactivationResponse struct {
|
||||||
AccountDeactivated bool
|
AccountDeactivated bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PerformOpenIDTokenCreationRequest is the request for PerformOpenIDTokenCreation
|
||||||
|
type PerformOpenIDTokenCreationRequest struct {
|
||||||
|
UserID string
|
||||||
|
RelyingParty string
|
||||||
|
}
|
||||||
|
|
||||||
|
// PerformOpenIDTokenCreationResponse is the response for PerformOpenIDTokenCreation
|
||||||
|
type PerformOpenIDTokenCreationResponse struct {
|
||||||
|
Token OpenIDToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryOpenIDTokenRequest is the request for QueryOpenIDToken
|
||||||
|
type QueryOpenIDTokenRequest struct {
|
||||||
|
Token string
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryOpenIDTokenResponse is the response for QueryOpenIDToken
|
||||||
|
type QueryOpenIDTokenResponse struct {
|
||||||
|
Token OpenIDToken
|
||||||
|
}
|
||||||
|
|
||||||
// Device represents a client's device (mobile, web, etc)
|
// Device represents a client's device (mobile, web, etc)
|
||||||
type Device struct {
|
type Device struct {
|
||||||
ID string
|
ID string
|
||||||
|
|
@ -241,6 +264,15 @@ type Account struct {
|
||||||
// TODO: Associations (e.g. with application services)
|
// TODO: Associations (e.g. with application services)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OpenIDToken represents an OpenID token
|
||||||
|
type OpenIDToken struct {
|
||||||
|
Token string
|
||||||
|
UserID string
|
||||||
|
CreatedTS int64
|
||||||
|
ExpiresTS int64
|
||||||
|
RelyingParty string
|
||||||
|
}
|
||||||
|
|
||||||
// ErrorForbidden is an error indicating that the supplied access token is forbidden
|
// ErrorForbidden is an error indicating that the supplied access token is forbidden
|
||||||
type ErrorForbidden struct {
|
type ErrorForbidden struct {
|
||||||
Message string
|
Message string
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/appservice/types"
|
"github.com/matrix-org/dendrite/appservice/types"
|
||||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
|
|
@ -395,3 +396,43 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
|
||||||
res.AccountDeactivated = err == nil
|
res.AccountDeactivated = err == nil
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PerformOpenIDTokenCreation creates a new token from an optional relying party.
|
||||||
|
func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error {
|
||||||
|
localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if domain != a.ServerName {
|
||||||
|
return fmt.Errorf("cannot create OpenID token for accounts not on this sercer: got %s want %s", domain, a.ServerName)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenString := util.RandomString(24)
|
||||||
|
createdMS := time.Now().UnixNano() / int64(time.Millisecond)
|
||||||
|
expiresMS := createdMS + (3600 * 1000) // 60 minutes
|
||||||
|
|
||||||
|
err = a.AccountDB.CreateOpenIDToken(ctx, tokenString, localpart, createdMS, expiresMS, req.RelyingParty)
|
||||||
|
|
||||||
|
res.Token = api.OpenIDToken{
|
||||||
|
Token: tokenString,
|
||||||
|
UserID: userutil.MakeUserID(localpart, a.ServerName),
|
||||||
|
CreatedTS: createdMS,
|
||||||
|
ExpiresTS: expiresMS,
|
||||||
|
RelyingParty: req.RelyingParty,
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryOpenIDToken returns the user information from the provided token string
|
||||||
|
func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
|
||||||
|
token, err := a.AccountDB.GetOpenIDToken(ctx, req.Token)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
res.Token = *token
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ const (
|
||||||
PerformDeviceDeletionPath = "/userapi/performDeviceDeletion"
|
PerformDeviceDeletionPath = "/userapi/performDeviceDeletion"
|
||||||
PerformDeviceUpdatePath = "/userapi/performDeviceUpdate"
|
PerformDeviceUpdatePath = "/userapi/performDeviceUpdate"
|
||||||
PerformAccountDeactivationPath = "/userapi/performAccountDeactivation"
|
PerformAccountDeactivationPath = "/userapi/performAccountDeactivation"
|
||||||
|
PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation"
|
||||||
|
|
||||||
QueryProfilePath = "/userapi/queryProfile"
|
QueryProfilePath = "/userapi/queryProfile"
|
||||||
QueryAccessTokenPath = "/userapi/queryAccessToken"
|
QueryAccessTokenPath = "/userapi/queryAccessToken"
|
||||||
|
|
@ -41,6 +42,7 @@ const (
|
||||||
QueryAccountDataPath = "/userapi/queryAccountData"
|
QueryAccountDataPath = "/userapi/queryAccountData"
|
||||||
QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
|
QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
|
||||||
QuerySearchProfilesPath = "/userapi/querySearchProfiles"
|
QuerySearchProfilesPath = "/userapi/querySearchProfiles"
|
||||||
|
QueryOpenIDTokenPath = "/userapi/queryOpenIDToken"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
||||||
|
|
@ -135,6 +137,14 @@ func (h *httpUserInternalAPI) PerformAccountDeactivation(ctx context.Context, re
|
||||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *httpUserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, request *api.PerformOpenIDTokenCreationRequest, response *api.PerformOpenIDTokenCreationResponse) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformOpenIDTokenCreation")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + PerformOpenIDTokenCreationPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *httpUserInternalAPI) QueryProfile(
|
func (h *httpUserInternalAPI) QueryProfile(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *api.QueryProfileRequest,
|
request *api.QueryProfileRequest,
|
||||||
|
|
@ -194,3 +204,11 @@ func (h *httpUserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.
|
||||||
apiURL := h.apiURL + QuerySearchProfilesPath
|
apiURL := h.apiURL + QuerySearchProfilesPath
|
||||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *httpUserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOpenIDToken")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + QueryOpenIDTokenPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -104,6 +104,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
internalAPIMux.Handle(PerformOpenIDTokenCreationPath,
|
||||||
|
httputil.MakeInternalAPI("performOpenIDTokenCreation", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.PerformOpenIDTokenCreationRequest{}
|
||||||
|
response := api.PerformOpenIDTokenCreationResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := s.PerformOpenIDTokenCreation(req.Context(), &request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
internalAPIMux.Handle(QueryProfilePath,
|
internalAPIMux.Handle(QueryProfilePath,
|
||||||
httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse {
|
httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse {
|
||||||
request := api.QueryProfileRequest{}
|
request := api.QueryProfileRequest{}
|
||||||
|
|
@ -182,6 +195,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
internalAPIMux.Handle(QueryOpenIDTokenPath,
|
||||||
|
httputil.MakeInternalAPI("queryOpenIDToken", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QueryOpenIDTokenRequest{}
|
||||||
|
response := api.QueryOpenIDTokenResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := s.QueryOpenIDToken(req.Context(), &request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
internalAPIMux.Handle(InputAccountDataPath,
|
internalAPIMux.Handle(InputAccountDataPath,
|
||||||
httputil.MakeInternalAPI("inputAccountDataPath", func(req *http.Request) util.JSONResponse {
|
httputil.MakeInternalAPI("inputAccountDataPath", func(req *http.Request) util.JSONResponse {
|
||||||
request := api.InputAccountDataRequest{}
|
request := api.InputAccountDataRequest{}
|
||||||
|
|
|
||||||
|
|
@ -52,6 +52,8 @@ type Database interface {
|
||||||
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
||||||
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
||||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
||||||
|
CreateOpenIDToken(ctx context.Context, token, localpart string, creationTS, expirationTS int64, rp string) (err error)
|
||||||
|
GetOpenIDToken(ctx context.Context, token string) (*api.OpenIDToken, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Err3PIDInUse is the error returned when trying to save an association involving
|
// Err3PIDInUse is the error returned when trying to save an association involving
|
||||||
|
|
|
||||||
95
userapi/storage/accounts/postgres/openid_table.go
Normal file
95
userapi/storage/accounts/postgres/openid_table.go
Normal file
|
|
@ -0,0 +1,95 @@
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
const openIDTokenSchema = `
|
||||||
|
-- Stores data about accounts.
|
||||||
|
CREATE TABLE IF NOT EXISTS account_openid (
|
||||||
|
-- This is the token value, empty by default
|
||||||
|
token TEXT NOT NULL PRIMARY KEY,
|
||||||
|
-- The Matrix user ID localpart for this account
|
||||||
|
localpart TEXT NOT NULL,
|
||||||
|
-- When this token was first created, as a unix timestamp (ms resolution).
|
||||||
|
token_created_ts BIGINT NOT NULL,
|
||||||
|
-- When the token expires, as a unix timestamp (ms resolution).
|
||||||
|
token_expires_ts BIGINT NOT NULL,
|
||||||
|
-- (optional) Relying Party the token was created for
|
||||||
|
token_rp TEXT,
|
||||||
|
);
|
||||||
|
-- Create sequence for autogenerated numeric usernames
|
||||||
|
-- CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
|
||||||
|
`
|
||||||
|
|
||||||
|
const insertTokenSQL = "" +
|
||||||
|
"INSERT INTO account_openid(token, localpart, token_created_ts, token_expires_ts, token_rp) VALUES ($1, $2, $3, $4, $5)"
|
||||||
|
|
||||||
|
const selectTokenSQL = "" +
|
||||||
|
"SELECT token, localpart, token_created_ts, token_expires_ts, token_rp FROM account_openid WHERE token = $1"
|
||||||
|
|
||||||
|
type tokenStatements struct {
|
||||||
|
insertTokenStmt *sql.Stmt
|
||||||
|
selectTokenStmt *sql.Stmt
|
||||||
|
serverName gomatrixserverlib.ServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
||||||
|
_, err = db.Exec(openIDTokenSchema)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.serverName = server
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *tokenStatements) insertToken(
|
||||||
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
token, localpart string,
|
||||||
|
createdTimeMS, expiresTimeMS int64,
|
||||||
|
tokenRP string,
|
||||||
|
) (err error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
||||||
|
|
||||||
|
if tokenRP == "" {
|
||||||
|
_, err = stmt.ExecContext(ctx, token, localpart, createdTimeMS, expiresTimeMS, nil)
|
||||||
|
} else {
|
||||||
|
_, err = stmt.ExecContext(ctx, token, localpart, createdTimeMS, expiresTimeMS, tokenRP)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *tokenStatements) selectToken(
|
||||||
|
ctx context.Context,
|
||||||
|
token string,
|
||||||
|
) (*api.OpenIDToken, error) {
|
||||||
|
var openIDToken api.OpenIDToken
|
||||||
|
var localpart string
|
||||||
|
|
||||||
|
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
||||||
|
&openIDToken.Token,
|
||||||
|
localpart,
|
||||||
|
&openIDToken.CreatedTS,
|
||||||
|
&openIDToken.ExpiresTS,
|
||||||
|
&openIDToken.RelyingParty,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
openIDToken.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||||
|
return &openIDToken, nil
|
||||||
|
}
|
||||||
|
|
@ -43,6 +43,7 @@ type Database struct {
|
||||||
profiles profilesStatements
|
profiles profilesStatements
|
||||||
accountDatas accountDataStatements
|
accountDatas accountDataStatements
|
||||||
threepids threepidStatements
|
threepids threepidStatements
|
||||||
|
openIDTokens tokenStatements
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -84,6 +85,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
||||||
if err = d.threepids.prepare(db); err != nil {
|
if err = d.threepids.prepare(db); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if err = d.openIDTokens.prepare(db, serverName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|
@ -337,3 +341,23 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi
|
||||||
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
|
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
|
||||||
return d.accounts.deactivateAccount(ctx, localpart)
|
return d.accounts.deactivateAccount(ctx, localpart)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateOpenIDToken creates a new token for OpenID Connect
|
||||||
|
func (d *Database) CreateOpenIDToken(
|
||||||
|
ctx context.Context,
|
||||||
|
token, localpart string,
|
||||||
|
createdTS, expirationTS int64,
|
||||||
|
rp string,
|
||||||
|
) error {
|
||||||
|
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
return d.openIDTokens.insertToken(ctx, txn, token, localpart, createdTS, expirationTS, rp)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpenIDToken gets a whole token
|
||||||
|
func (d *Database) GetOpenIDToken(
|
||||||
|
ctx context.Context,
|
||||||
|
token string,
|
||||||
|
) (*api.OpenIDToken, error) {
|
||||||
|
return d.openIDTokens.selectToken(ctx, token)
|
||||||
|
}
|
||||||
|
|
|
||||||
99
userapi/storage/accounts/sqlite3/openid_table.go
Normal file
99
userapi/storage/accounts/sqlite3/openid_table.go
Normal file
|
|
@ -0,0 +1,99 @@
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
const openIDTokenSchema = `
|
||||||
|
-- Stores data about accounts.
|
||||||
|
CREATE TABLE IF NOT EXISTS account_openid (
|
||||||
|
-- This is the token value, empty by default
|
||||||
|
token TEXT NOT NULL PRIMARY KEY,
|
||||||
|
-- The Matrix user ID localpart for this account
|
||||||
|
localpart TEXT NOT NULL,
|
||||||
|
-- When this token was first created, as a unix timestamp (ms resolution).
|
||||||
|
token_created_ts BIGINT NOT NULL,
|
||||||
|
-- When the token expires, as a unix timestamp (ms resolution).
|
||||||
|
token_expires_ts BIGINT NOT NULL,
|
||||||
|
-- (optional) Relying Party the token was created for
|
||||||
|
token_rp TEXT,
|
||||||
|
);
|
||||||
|
-- Create sequence for autogenerated numeric usernames
|
||||||
|
-- CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
|
||||||
|
`
|
||||||
|
|
||||||
|
const insertTokenSQL = "" +
|
||||||
|
"INSERT INTO account_openid(token, localpart, token_created_ts, token_expires_ts, token_rp) VALUES ($1, $2, $3, $4, $5)"
|
||||||
|
|
||||||
|
const selectTokenSQL = "" +
|
||||||
|
"SELECT token, localpart, token_created_ts, token_expires_ts, token_rp FROM account_openid WHERE token = $1"
|
||||||
|
|
||||||
|
type tokenStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
insertTokenStmt *sql.Stmt
|
||||||
|
selectTokenStmt *sql.Stmt
|
||||||
|
serverName gomatrixserverlib.ServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
||||||
|
s.db = db
|
||||||
|
_, err = db.Exec(openIDTokenSchema)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.serverName = server
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// insertToken inserts a new OpenID Connect token to the DB.
|
||||||
|
// tokenRP is the OpenID Relying Party; if not specified, it's left nil
|
||||||
|
// Returns new token, otherwise returns error if token already exists.
|
||||||
|
func (s *tokenStatements) insertToken(
|
||||||
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
token, localpart string,
|
||||||
|
createdTimeMS, expiresTimeMS int64,
|
||||||
|
tokenRP string,
|
||||||
|
) (err error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
||||||
|
|
||||||
|
if tokenRP == "" {
|
||||||
|
_, err = stmt.ExecContext(ctx, token, localpart, createdTimeMS, expiresTimeMS, nil)
|
||||||
|
} else {
|
||||||
|
_, err = stmt.ExecContext(ctx, token, localpart, createdTimeMS, expiresTimeMS, tokenRP)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *tokenStatements) selectToken(
|
||||||
|
ctx context.Context,
|
||||||
|
token string,
|
||||||
|
) (openIDToken *api.OpenIDToken, err error) {
|
||||||
|
var localpart string
|
||||||
|
|
||||||
|
err = s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
||||||
|
&openIDToken.Token,
|
||||||
|
localpart,
|
||||||
|
&openIDToken.CreatedTS,
|
||||||
|
&openIDToken.ExpiresTS,
|
||||||
|
&openIDToken.RelyingParty,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
openIDToken.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||||
|
return openIDToken, nil
|
||||||
|
}
|
||||||
|
|
@ -42,12 +42,14 @@ type Database struct {
|
||||||
profiles profilesStatements
|
profiles profilesStatements
|
||||||
accountDatas accountDataStatements
|
accountDatas accountDataStatements
|
||||||
threepids threepidStatements
|
threepids threepidStatements
|
||||||
|
openIDTokens tokenStatements
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
|
|
||||||
accountsMu sync.Mutex
|
accountsMu sync.Mutex
|
||||||
profilesMu sync.Mutex
|
profilesMu sync.Mutex
|
||||||
accountDatasMu sync.Mutex
|
accountDatasMu sync.Mutex
|
||||||
threepidsMu sync.Mutex
|
threepidsMu sync.Mutex
|
||||||
|
openIDsMu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabase creates a new accounts and profiles database
|
// NewDatabase creates a new accounts and profiles database
|
||||||
|
|
@ -89,6 +91,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
||||||
if err = d.threepids.prepare(db); err != nil {
|
if err = d.threepids.prepare(db); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if err = d.openIDTokens.prepare(db, serverName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|
@ -378,3 +383,25 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi
|
||||||
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
|
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
|
||||||
return d.accounts.deactivateAccount(ctx, localpart)
|
return d.accounts.deactivateAccount(ctx, localpart)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateOpenIDToken creates a new token for the
|
||||||
|
func (d *Database) CreateOpenIDToken(
|
||||||
|
ctx context.Context,
|
||||||
|
token, localpart string,
|
||||||
|
createdTS, expirationTS int64,
|
||||||
|
rp string,
|
||||||
|
) error {
|
||||||
|
d.openIDsMu.Lock()
|
||||||
|
defer d.openIDsMu.Unlock()
|
||||||
|
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.openIDTokens.insertToken(ctx, txn, token, localpart, createdTS, expirationTS, rp)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpenIDToken returns an OpenIDToken struct from the token string
|
||||||
|
func (d *Database) GetOpenIDToken(
|
||||||
|
ctx context.Context,
|
||||||
|
token string,
|
||||||
|
) (*api.OpenIDToken, error) {
|
||||||
|
return d.openIDTokens.selectToken(ctx, token)
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue