feat pt 1: db working, controller not yet

This commit is contained in:
Jake Hemmerle 2020-11-04 23:51:06 -05:00
parent 35ea55e70b
commit 6c9a8b96dd
13 changed files with 452 additions and 0 deletions

View file

@ -0,0 +1,7 @@
package openid
// TokenRequest represents the request defined at https://matrix.org/docs/spec/client_server/r0.6.1#id603
type TokenRequest struct {
UserID string `json:"userId"`
RelyingParty string `json:"relyingParty"`
}

View file

@ -0,0 +1,54 @@
package routing
import (
"net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
)
type openIDTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
MatrixServerName string `json:"matrix_server_name"`
ExpiresIn int64 `json:"expires_in"`
}
// CreateOpenIDToken creates a new OpenID Connect token that a Matrix user
// can supply to an OpenID Relying Party
func CreateOpenIDToken(
req *http.Request,
userAPI api.UserInternalAPI,
device *api.Device,
userID, relyingParty string,
cfg *config.ClientAPI,
) util.JSONResponse {
if userID != device.UserID {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("userID does not match the current user"),
}
}
request := api.PerformOpenIDTokenCreationRequest{
UserID: userID,
RelyingParty: relyingParty}
response := api.PerformOpenIDTokenCreationResponse{}
err := userAPI.PerformOpenIDTokenCreation(req.Context(), &request, &response)
if err != nil {
return jsonerror.InternalServerError()
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: openIDTokenResponse{
AccessToken: response.Token.Token,
TokenType: "Bearer",
MatrixServerName: string(cfg.Matrix.ServerName),
ExpiresIn: response.Token.ExpiresTS / 1000, // convert ms to s
},
}
}

View file

@ -661,6 +661,19 @@ func Setup(
}), }),
).Methods(http.MethodGet) ).Methods(http.MethodGet)
r0mux.Handle("/user/{userId}/openid/request_token",
httputil.MakeAuthAPI("openid_request_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.rateLimit(req); r != nil {
return *r
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return CreateOpenIDToken(req, userAPI, device, vars["userId"], vars["relyingParty"], cfg)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/user_directory/search", r0mux.Handle("/user_directory/search",
httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.rateLimit(req); r != nil { if r := rateLimits.rateLimit(req); r != nil {

View file

@ -0,0 +1,14 @@
package routing
/*
import (
"net/http"
)
func GetInvite(
httpReq *http.Request,
) {
return
}
*/

View file

@ -31,12 +31,14 @@ type UserInternalAPI interface {
PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error
QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error
} }
// InputAccountDataRequest is the request for InputAccountData // InputAccountDataRequest is the request for InputAccountData
@ -214,6 +216,27 @@ type PerformAccountDeactivationResponse struct {
AccountDeactivated bool AccountDeactivated bool
} }
// PerformOpenIDTokenCreationRequest is the request for PerformOpenIDTokenCreation
type PerformOpenIDTokenCreationRequest struct {
UserID string
RelyingParty string
}
// PerformOpenIDTokenCreationResponse is the response for PerformOpenIDTokenCreation
type PerformOpenIDTokenCreationResponse struct {
Token OpenIDToken
}
// QueryOpenIDTokenRequest is the request for QueryOpenIDToken
type QueryOpenIDTokenRequest struct {
Token string
}
// QueryOpenIDTokenResponse is the response for QueryOpenIDToken
type QueryOpenIDTokenResponse struct {
Token OpenIDToken
}
// Device represents a client's device (mobile, web, etc) // Device represents a client's device (mobile, web, etc)
type Device struct { type Device struct {
ID string ID string
@ -241,6 +264,15 @@ type Account struct {
// TODO: Associations (e.g. with application services) // TODO: Associations (e.g. with application services)
} }
// OpenIDToken represents an OpenID token
type OpenIDToken struct {
Token string
UserID string
CreatedTS int64
ExpiresTS int64
RelyingParty string
}
// ErrorForbidden is an error indicating that the supplied access token is forbidden // ErrorForbidden is an error indicating that the supplied access token is forbidden
type ErrorForbidden struct { type ErrorForbidden struct {
Message string Message string

View file

@ -20,6 +20,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"time"
"github.com/matrix-org/dendrite/appservice/types" "github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
@ -395,3 +396,43 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
res.AccountDeactivated = err == nil res.AccountDeactivated = err == nil
return err return err
} }
// PerformOpenIDTokenCreation creates a new token from an optional relying party.
func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error {
localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
return err
}
if domain != a.ServerName {
return fmt.Errorf("cannot create OpenID token for accounts not on this sercer: got %s want %s", domain, a.ServerName)
}
tokenString := util.RandomString(24)
createdMS := time.Now().UnixNano() / int64(time.Millisecond)
expiresMS := createdMS + (3600 * 1000) // 60 minutes
err = a.AccountDB.CreateOpenIDToken(ctx, tokenString, localpart, createdMS, expiresMS, req.RelyingParty)
res.Token = api.OpenIDToken{
Token: tokenString,
UserID: userutil.MakeUserID(localpart, a.ServerName),
CreatedTS: createdMS,
ExpiresTS: expiresMS,
RelyingParty: req.RelyingParty,
}
return err
}
// QueryOpenIDToken returns the user information from the provided token string
func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
token, err := a.AccountDB.GetOpenIDToken(ctx, req.Token)
if err != nil {
return err
}
res.Token = *token
return nil
}

View file

@ -34,6 +34,7 @@ const (
PerformDeviceDeletionPath = "/userapi/performDeviceDeletion" PerformDeviceDeletionPath = "/userapi/performDeviceDeletion"
PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" PerformDeviceUpdatePath = "/userapi/performDeviceUpdate"
PerformAccountDeactivationPath = "/userapi/performAccountDeactivation" PerformAccountDeactivationPath = "/userapi/performAccountDeactivation"
PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation"
QueryProfilePath = "/userapi/queryProfile" QueryProfilePath = "/userapi/queryProfile"
QueryAccessTokenPath = "/userapi/queryAccessToken" QueryAccessTokenPath = "/userapi/queryAccessToken"
@ -41,6 +42,7 @@ const (
QueryAccountDataPath = "/userapi/queryAccountData" QueryAccountDataPath = "/userapi/queryAccountData"
QueryDeviceInfosPath = "/userapi/queryDeviceInfos" QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
QuerySearchProfilesPath = "/userapi/querySearchProfiles" QuerySearchProfilesPath = "/userapi/querySearchProfiles"
QueryOpenIDTokenPath = "/userapi/queryOpenIDToken"
) )
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
@ -135,6 +137,14 @@ func (h *httpUserInternalAPI) PerformAccountDeactivation(ctx context.Context, re
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }
func (h *httpUserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, request *api.PerformOpenIDTokenCreationRequest, response *api.PerformOpenIDTokenCreationResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformOpenIDTokenCreation")
defer span.Finish()
apiURL := h.apiURL + PerformOpenIDTokenCreationPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) QueryProfile( func (h *httpUserInternalAPI) QueryProfile(
ctx context.Context, ctx context.Context,
request *api.QueryProfileRequest, request *api.QueryProfileRequest,
@ -194,3 +204,11 @@ func (h *httpUserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.
apiURL := h.apiURL + QuerySearchProfilesPath apiURL := h.apiURL + QuerySearchProfilesPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }
func (h *httpUserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOpenIDToken")
defer span.Finish()
apiURL := h.apiURL + QueryOpenIDTokenPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}

View file

@ -104,6 +104,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(PerformOpenIDTokenCreationPath,
httputil.MakeInternalAPI("performOpenIDTokenCreation", func(req *http.Request) util.JSONResponse {
request := api.PerformOpenIDTokenCreationRequest{}
response := api.PerformOpenIDTokenCreationResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformOpenIDTokenCreation(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryProfilePath, internalAPIMux.Handle(QueryProfilePath,
httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse { httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse {
request := api.QueryProfileRequest{} request := api.QueryProfileRequest{}
@ -182,6 +195,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(QueryOpenIDTokenPath,
httputil.MakeInternalAPI("queryOpenIDToken", func(req *http.Request) util.JSONResponse {
request := api.QueryOpenIDTokenRequest{}
response := api.QueryOpenIDTokenResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryOpenIDToken(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(InputAccountDataPath, internalAPIMux.Handle(InputAccountDataPath,
httputil.MakeInternalAPI("inputAccountDataPath", func(req *http.Request) util.JSONResponse { httputil.MakeInternalAPI("inputAccountDataPath", func(req *http.Request) util.JSONResponse {
request := api.InputAccountDataRequest{} request := api.InputAccountDataRequest{}

View file

@ -52,6 +52,8 @@ type Database interface {
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
DeactivateAccount(ctx context.Context, localpart string) (err error) DeactivateAccount(ctx context.Context, localpart string) (err error)
CreateOpenIDToken(ctx context.Context, token, localpart string, creationTS, expirationTS int64, rp string) (err error)
GetOpenIDToken(ctx context.Context, token string) (*api.OpenIDToken, error)
} }
// Err3PIDInUse is the error returned when trying to save an association involving // Err3PIDInUse is the error returned when trying to save an association involving

View file

@ -0,0 +1,95 @@
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
const openIDTokenSchema = `
-- Stores data about accounts.
CREATE TABLE IF NOT EXISTS account_openid (
-- This is the token value, empty by default
token TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL,
-- When this token was first created, as a unix timestamp (ms resolution).
token_created_ts BIGINT NOT NULL,
-- When the token expires, as a unix timestamp (ms resolution).
token_expires_ts BIGINT NOT NULL,
-- (optional) Relying Party the token was created for
token_rp TEXT,
);
-- Create sequence for autogenerated numeric usernames
-- CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
`
const insertTokenSQL = "" +
"INSERT INTO account_openid(token, localpart, token_created_ts, token_expires_ts, token_rp) VALUES ($1, $2, $3, $4, $5)"
const selectTokenSQL = "" +
"SELECT token, localpart, token_created_ts, token_expires_ts, token_rp FROM account_openid WHERE token = $1"
type tokenStatements struct {
insertTokenStmt *sql.Stmt
selectTokenStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
_, err = db.Exec(openIDTokenSchema)
if err != nil {
return
}
if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil {
return
}
if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil {
return
}
s.serverName = server
return
}
func (s *tokenStatements) insertToken(
ctx context.Context,
txn *sql.Tx,
token, localpart string,
createdTimeMS, expiresTimeMS int64,
tokenRP string,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
if tokenRP == "" {
_, err = stmt.ExecContext(ctx, token, localpart, createdTimeMS, expiresTimeMS, nil)
} else {
_, err = stmt.ExecContext(ctx, token, localpart, createdTimeMS, expiresTimeMS, tokenRP)
}
return
}
func (s *tokenStatements) selectToken(
ctx context.Context,
token string,
) (*api.OpenIDToken, error) {
var openIDToken api.OpenIDToken
var localpart string
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
&openIDToken.Token,
localpart,
&openIDToken.CreatedTS,
&openIDToken.ExpiresTS,
&openIDToken.RelyingParty,
)
if err != nil {
return nil, err
}
openIDToken.UserID = userutil.MakeUserID(localpart, s.serverName)
return &openIDToken, nil
}

View file

@ -43,6 +43,7 @@ type Database struct {
profiles profilesStatements profiles profilesStatements
accountDatas accountDataStatements accountDatas accountDataStatements
threepids threepidStatements threepids threepidStatements
openIDTokens tokenStatements
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
@ -84,6 +85,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.threepids.prepare(db); err != nil { if err = d.threepids.prepare(db); err != nil {
return nil, err return nil, err
} }
if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err
}
return d, nil return d, nil
} }
@ -337,3 +341,23 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
return d.accounts.deactivateAccount(ctx, localpart) return d.accounts.deactivateAccount(ctx, localpart)
} }
// CreateOpenIDToken creates a new token for OpenID Connect
func (d *Database) CreateOpenIDToken(
ctx context.Context,
token, localpart string,
createdTS, expirationTS int64,
rp string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, token, localpart, createdTS, expirationTS, rp)
})
}
// GetOpenIDToken gets a whole token
func (d *Database) GetOpenIDToken(
ctx context.Context,
token string,
) (*api.OpenIDToken, error) {
return d.openIDTokens.selectToken(ctx, token)
}

View file

@ -0,0 +1,99 @@
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
const openIDTokenSchema = `
-- Stores data about accounts.
CREATE TABLE IF NOT EXISTS account_openid (
-- This is the token value, empty by default
token TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL,
-- When this token was first created, as a unix timestamp (ms resolution).
token_created_ts BIGINT NOT NULL,
-- When the token expires, as a unix timestamp (ms resolution).
token_expires_ts BIGINT NOT NULL,
-- (optional) Relying Party the token was created for
token_rp TEXT,
);
-- Create sequence for autogenerated numeric usernames
-- CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
`
const insertTokenSQL = "" +
"INSERT INTO account_openid(token, localpart, token_created_ts, token_expires_ts, token_rp) VALUES ($1, $2, $3, $4, $5)"
const selectTokenSQL = "" +
"SELECT token, localpart, token_created_ts, token_expires_ts, token_rp FROM account_openid WHERE token = $1"
type tokenStatements struct {
db *sql.DB
insertTokenStmt *sql.Stmt
selectTokenStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.db = db
_, err = db.Exec(openIDTokenSchema)
if err != nil {
return err
}
if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil {
return
}
if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil {
return
}
s.serverName = server
return
}
// insertToken inserts a new OpenID Connect token to the DB.
// tokenRP is the OpenID Relying Party; if not specified, it's left nil
// Returns new token, otherwise returns error if token already exists.
func (s *tokenStatements) insertToken(
ctx context.Context,
txn *sql.Tx,
token, localpart string,
createdTimeMS, expiresTimeMS int64,
tokenRP string,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
if tokenRP == "" {
_, err = stmt.ExecContext(ctx, token, localpart, createdTimeMS, expiresTimeMS, nil)
} else {
_, err = stmt.ExecContext(ctx, token, localpart, createdTimeMS, expiresTimeMS, tokenRP)
}
return
}
func (s *tokenStatements) selectToken(
ctx context.Context,
token string,
) (openIDToken *api.OpenIDToken, err error) {
var localpart string
err = s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
&openIDToken.Token,
localpart,
&openIDToken.CreatedTS,
&openIDToken.ExpiresTS,
&openIDToken.RelyingParty,
)
if err != nil {
return nil, err
}
openIDToken.UserID = userutil.MakeUserID(localpart, s.serverName)
return openIDToken, nil
}

View file

@ -42,12 +42,14 @@ type Database struct {
profiles profilesStatements profiles profilesStatements
accountDatas accountDataStatements accountDatas accountDataStatements
threepids threepidStatements threepids threepidStatements
openIDTokens tokenStatements
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
accountsMu sync.Mutex accountsMu sync.Mutex
profilesMu sync.Mutex profilesMu sync.Mutex
accountDatasMu sync.Mutex accountDatasMu sync.Mutex
threepidsMu sync.Mutex threepidsMu sync.Mutex
openIDsMu sync.Mutex
} }
// NewDatabase creates a new accounts and profiles database // NewDatabase creates a new accounts and profiles database
@ -89,6 +91,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.threepids.prepare(db); err != nil { if err = d.threepids.prepare(db); err != nil {
return nil, err return nil, err
} }
if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err
}
return d, nil return d, nil
} }
@ -378,3 +383,25 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
return d.accounts.deactivateAccount(ctx, localpart) return d.accounts.deactivateAccount(ctx, localpart)
} }
// CreateOpenIDToken creates a new token for the
func (d *Database) CreateOpenIDToken(
ctx context.Context,
token, localpart string,
createdTS, expirationTS int64,
rp string,
) error {
d.openIDsMu.Lock()
defer d.openIDsMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, token, localpart, createdTS, expirationTS, rp)
})
}
// GetOpenIDToken returns an OpenIDToken struct from the token string
func (d *Database) GetOpenIDToken(
ctx context.Context,
token string,
) (*api.OpenIDToken, error) {
return d.openIDTokens.selectToken(ctx, token)
}