mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-08 22:53:10 -06:00
Add SSO UserAPI endpoints.
This is mostly copied from the ThirdPID, but with a primary key that matches OpenID Connect nomenclature. There's a namspace to ensure other SSO solutions can be supported, but there's only one namespace defined for now.
This commit is contained in:
parent
c9ad7206c8
commit
c3f7945284
|
|
@ -77,6 +77,7 @@ type SyncUserAPI interface {
|
||||||
type ClientUserAPI interface {
|
type ClientUserAPI interface {
|
||||||
QueryAcccessTokenAPI
|
QueryAcccessTokenAPI
|
||||||
LoginTokenInternalAPI
|
LoginTokenInternalAPI
|
||||||
|
SSOAPI
|
||||||
UserLoginAPI
|
UserLoginAPI
|
||||||
QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error
|
QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error
|
||||||
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
||||||
|
|
|
||||||
53
userapi/api/api_sso.go
Normal file
53
userapi/api/api_sso.go
Normal file
|
|
@ -0,0 +1,53 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SSOAPI interface {
|
||||||
|
QueryLocalpartForSSO(ctx context.Context, req *QueryLocalpartForSSORequest, res *QueryLocalpartForSSOResponse) error
|
||||||
|
PerformForgetSSO(ctx context.Context, req *PerformForgetSSORequest, res *struct{}) error
|
||||||
|
PerformSaveSSOAssociation(ctx context.Context, req *PerformSaveSSOAssociationRequest, res *struct{}) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryLocalpartForSSORequest struct {
|
||||||
|
Namespace SSOIssuerNamespace
|
||||||
|
Issuer, Subject string
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryLocalpartForSSOResponse struct {
|
||||||
|
Localpart string
|
||||||
|
}
|
||||||
|
|
||||||
|
type PerformForgetSSORequest QueryLocalpartForSSORequest
|
||||||
|
|
||||||
|
type PerformSaveSSOAssociationRequest struct {
|
||||||
|
Namespace SSOIssuerNamespace
|
||||||
|
Issuer, Subject string
|
||||||
|
Localpart string
|
||||||
|
}
|
||||||
|
|
||||||
|
// An SSOIssuerNamespace defines the interpretation of an issuer.
|
||||||
|
type SSOIssuerNamespace string
|
||||||
|
|
||||||
|
const (
|
||||||
|
UnknownNamespace SSOIssuerNamespace = ""
|
||||||
|
|
||||||
|
// OIDCNamespace indicates the issuer is a full URL, as defined in
|
||||||
|
// https://openid.net/specs/openid-connect-core-1_0.html#Terminology.
|
||||||
|
OIDCNamespace SSOIssuerNamespace = "oidc"
|
||||||
|
)
|
||||||
39
userapi/api/api_trace_sso.go
Normal file
39
userapi/api/api_trace_sso.go
Normal file
|
|
@ -0,0 +1,39 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (t *UserInternalAPITrace) QueryLocalpartForSSO(ctx context.Context, req *QueryLocalpartForSSORequest, res *QueryLocalpartForSSOResponse) error {
|
||||||
|
err := t.Impl.QueryLocalpartForSSO(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).Infof("QueryLocalpartForSSO req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *UserInternalAPITrace) PerformForgetSSO(ctx context.Context, req *PerformForgetSSORequest, res *struct{}) error {
|
||||||
|
err := t.Impl.PerformForgetSSO(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).Infof("PerformForgetSSO req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *UserInternalAPITrace) PerformSaveSSOAssociation(ctx context.Context, req *PerformSaveSSOAssociationRequest, res *struct{}) error {
|
||||||
|
err := t.Impl.PerformSaveSSOAssociation(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).Infof("PerformSaveSSOAssociation req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
50
userapi/internal/api_sso.go
Normal file
50
userapi/internal/api_sso.go
Normal file
|
|
@ -0,0 +1,50 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) QueryLocalpartForSSO(ctx context.Context, req *api.QueryLocalpartForSSORequest, res *api.QueryLocalpartForSSOResponse) error {
|
||||||
|
var err error
|
||||||
|
res.Localpart, err = a.DB.GetLocalpartForSSO(ctx, string(req.Namespace), req.Issuer, req.Subject)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) PerformForgetSSO(ctx context.Context, req *api.PerformForgetSSORequest, res *struct{}) error {
|
||||||
|
return a.DB.RemoveSSOAssociation(ctx, string(req.Namespace), req.Issuer, req.Subject)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) PerformSaveSSOAssociation(ctx context.Context, req *api.PerformSaveSSOAssociationRequest, res *struct{}) error {
|
||||||
|
ns, err := validateSSOIssuerNamespace(req.Namespace)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return a.DB.SaveSSOAssociation(ctx, ns, req.Issuer, req.Subject, req.Localpart)
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateSSOIssuerNamespace(ns api.SSOIssuerNamespace) (string, error) {
|
||||||
|
switch ns {
|
||||||
|
case api.OIDCNamespace:
|
||||||
|
return string(ns), nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("invalid SSO issuer namespace: %s", ns)
|
||||||
|
}
|
||||||
|
}
|
||||||
53
userapi/inthttp/client_sso.go
Normal file
53
userapi/inthttp/client_sso.go
Normal file
|
|
@ -0,0 +1,53 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package inthttp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/httputil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/opentracing/opentracing-go"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
PerformForgetSSOPath = "/userapi/performForgetSSO"
|
||||||
|
PerformSaveSSOAssociationPath = "/userapi/performSaveSSOAssociation"
|
||||||
|
QueryLocalpartForSSOPath = "/userapi/queryLocalpartForSSO"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (h *httpUserInternalAPI) QueryLocalpartForSSO(ctx context.Context, req *api.QueryLocalpartForSSORequest, res *api.QueryLocalpartForSSOResponse) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, QueryLocalpartForSSOPath)
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + QueryLocalpartForSSOPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpUserInternalAPI) PerformForgetSSO(ctx context.Context, req *api.PerformForgetSSORequest, res *struct{}) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, PerformForgetSSOPath)
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + PerformForgetSSOPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpUserInternalAPI) PerformSaveSSOAssociation(ctx context.Context, req *api.PerformSaveSSOAssociationRequest, res *struct{}) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, PerformSaveSSOAssociationPath)
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + PerformSaveSSOAssociationPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
}
|
||||||
|
|
@ -28,6 +28,7 @@ import (
|
||||||
// nolint: gocyclo
|
// nolint: gocyclo
|
||||||
func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
||||||
addRoutesLoginToken(internalAPIMux, s)
|
addRoutesLoginToken(internalAPIMux, s)
|
||||||
|
addRoutesSSO(internalAPIMux, s)
|
||||||
|
|
||||||
internalAPIMux.Handle(PerformAccountCreationPath,
|
internalAPIMux.Handle(PerformAccountCreationPath,
|
||||||
httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse {
|
httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse {
|
||||||
|
|
|
||||||
66
userapi/inthttp/server_sso.go
Normal file
66
userapi/inthttp/server_sso.go
Normal file
|
|
@ -0,0 +1,66 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package inthttp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/matrix-org/dendrite/internal/httputil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// addRoutesSSO adds routes for all SSO API calls.
|
||||||
|
func addRoutesSSO(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
||||||
|
internalAPIMux.Handle(QueryLocalpartForSSOPath,
|
||||||
|
httputil.MakeInternalAPI("queryLocalpartForSSO", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QueryLocalpartForSSORequest{}
|
||||||
|
response := api.QueryLocalpartForSSOResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := s.QueryLocalpartForSSO(req.Context(), &request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
internalAPIMux.Handle(PerformForgetSSOPath,
|
||||||
|
httputil.MakeInternalAPI("performForgetSSO", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.PerformForgetSSORequest{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := s.PerformForgetSSO(req.Context(), &request, &struct{}{}); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
internalAPIMux.Handle(PerformSaveSSOAssociationPath,
|
||||||
|
httputil.MakeInternalAPI("performSaveSSOAssociation", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.PerformSaveSSOAssociationRequest{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := s.PerformSaveSSOAssociation(req.Context(), &request, &struct{}{}); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
@ -109,6 +109,12 @@ type Pusher interface {
|
||||||
RemovePushers(ctx context.Context, appid, pushkey string) error
|
RemovePushers(ctx context.Context, appid, pushkey string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type SSO interface {
|
||||||
|
SaveSSOAssociation(ctx context.Context, namespace, iss, sub, localpart string) error
|
||||||
|
RemoveSSOAssociation(ctx context.Context, namespace, iss, sub string) error
|
||||||
|
GetLocalpartForSSO(ctx context.Context, namespace, iss, sub string) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
type ThreePID interface {
|
type ThreePID interface {
|
||||||
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
|
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
|
||||||
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
|
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
|
||||||
|
|
@ -136,6 +142,7 @@ type Database interface {
|
||||||
OpenID
|
OpenID
|
||||||
Profile
|
Profile
|
||||||
Pusher
|
Pusher
|
||||||
|
SSO
|
||||||
Statistics
|
Statistics
|
||||||
ThreePID
|
ThreePID
|
||||||
}
|
}
|
||||||
|
|
|
||||||
93
userapi/storage/postgres/sso_table.go
Normal file
93
userapi/storage/postgres/sso_table.go
Normal file
|
|
@ -0,0 +1,93 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
)
|
||||||
|
|
||||||
|
const ssoSchema = `
|
||||||
|
-- Stores data about SSO associations.
|
||||||
|
CREATE TABLE IF NOT EXISTS account_sso (
|
||||||
|
-- The "iss" namespace. Must be "oidc".
|
||||||
|
namespace TEXT NOT NULL,
|
||||||
|
-- The issuer; for "oidc", a URL.
|
||||||
|
iss TEXT NOT NULL,
|
||||||
|
-- The subject (user ID).
|
||||||
|
sub TEXT NOT NULL,
|
||||||
|
-- The localpart of the Matrix user ID associated to this 3PID
|
||||||
|
localpart TEXT NOT NULL,
|
||||||
|
|
||||||
|
PRIMARY KEY(namespace, iss, sub)
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
const selectLocalpartForSSOSQL = "" +
|
||||||
|
"SELECT localpart FROM account_sso WHERE namespace = $1 AND iss = $2 AND sub = $3"
|
||||||
|
|
||||||
|
const insertSSOSQL = "" +
|
||||||
|
"INSERT INTO account_sso (namespace, iss, sub, localpart) VALUES ($1, $2, $3, $4)"
|
||||||
|
|
||||||
|
const deleteSSOSQL = "" +
|
||||||
|
"DELETE FROM account_sso WHERE namespace = $1 AND iss = $2 AND sub = $3"
|
||||||
|
|
||||||
|
type ssoStatements struct {
|
||||||
|
selectLocalpartForSSOStmt *sql.Stmt
|
||||||
|
insertSSOStmt *sql.Stmt
|
||||||
|
deleteSSOStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPostgresSSOTable(db *sql.DB) (tables.SSOTable, error) {
|
||||||
|
s := &ssoStatements{}
|
||||||
|
_, err := db.Exec(ssoSchema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, sqlutil.StatementList{
|
||||||
|
{&s.selectLocalpartForSSOStmt, selectLocalpartForSSOSQL},
|
||||||
|
{&s.insertSSOStmt, insertSSOSQL},
|
||||||
|
{&s.deleteSSOStmt, deleteSSOSQL},
|
||||||
|
}.Prepare(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ssoStatements) SelectLocalpartForSSO(
|
||||||
|
ctx context.Context, txn *sql.Tx, namespace, iss, sub string,
|
||||||
|
) (localpart string, err error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForSSOStmt)
|
||||||
|
err = stmt.QueryRowContext(ctx, namespace, iss, sub).Scan(&localpart)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ssoStatements) InsertSSO(
|
||||||
|
ctx context.Context, txn *sql.Tx, namespace, iss, sub, localpart string,
|
||||||
|
) (err error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.insertSSOStmt)
|
||||||
|
_, err = stmt.ExecContext(ctx, namespace, iss, sub, localpart)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ssoStatements) DeleteSSO(
|
||||||
|
ctx context.Context, txn *sql.Tx, namespace, iss, sub string) (err error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.deleteSSOStmt)
|
||||||
|
_, err = stmt.ExecContext(ctx, namespace, iss, sub)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
@ -52,6 +52,7 @@ type Database struct {
|
||||||
LoginTokens tables.LoginTokenTable
|
LoginTokens tables.LoginTokenTable
|
||||||
Notifications tables.NotificationTable
|
Notifications tables.NotificationTable
|
||||||
Pushers tables.PusherTable
|
Pushers tables.PusherTable
|
||||||
|
SSOs tables.SSOTable
|
||||||
Stats tables.StatsTable
|
Stats tables.StatsTable
|
||||||
LoginTokenLifetime time.Duration
|
LoginTokenLifetime time.Duration
|
||||||
ServerName gomatrixserverlib.ServerName
|
ServerName gomatrixserverlib.ServerName
|
||||||
|
|
@ -225,6 +226,35 @@ func (d *Database) hashPassword(plaintext string) (hash string, err error) {
|
||||||
return string(hashBytes), err
|
return string(hashBytes), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ErrSSOInUse = errors.New("this SSO account is already in use")
|
||||||
|
|
||||||
|
func (d *Database) SaveSSOAssociation(ctx context.Context, namespace, iss, sub, localpart string) error {
|
||||||
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
user, err := d.SSOs.SelectLocalpartForSSO(
|
||||||
|
ctx, txn, namespace, iss, sub,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(user) > 0 {
|
||||||
|
return Err3PIDInUse
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.SSOs.InsertSSO(ctx, txn, namespace, iss, sub, localpart)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) RemoveSSOAssociation(ctx context.Context, namespace, iss, sub string) error {
|
||||||
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.SSOs.DeleteSSO(ctx, txn, namespace, iss, sub)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetLocalpartForSSO(ctx context.Context, namespace, iss, sub string) (string, error) {
|
||||||
|
return d.SSOs.SelectLocalpartForSSO(ctx, nil, namespace, iss, sub)
|
||||||
|
}
|
||||||
|
|
||||||
// Err3PIDInUse is the error returned when trying to save an association involving
|
// Err3PIDInUse is the error returned when trying to save an association involving
|
||||||
// a third-party identifier which is already associated to a local user.
|
// a third-party identifier which is already associated to a local user.
|
||||||
var Err3PIDInUse = errors.New("this third-party identifier is already in use")
|
var Err3PIDInUse = errors.New("this third-party identifier is already in use")
|
||||||
|
|
|
||||||
93
userapi/storage/sqlite3/sso_table.go
Normal file
93
userapi/storage/sqlite3/sso_table.go
Normal file
|
|
@ -0,0 +1,93 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
)
|
||||||
|
|
||||||
|
const ssoSchema = `
|
||||||
|
-- Stores data about SSO associations.
|
||||||
|
CREATE TABLE IF NOT EXISTS account_sso (
|
||||||
|
-- The "iss" namespace. Must be "oidc".
|
||||||
|
namespace TEXT NOT NULL,
|
||||||
|
-- The issuer; for "oidc", a URL.
|
||||||
|
iss TEXT NOT NULL,
|
||||||
|
-- The subject (user ID).
|
||||||
|
sub TEXT NOT NULL,
|
||||||
|
-- The localpart of the Matrix user ID associated to this 3PID
|
||||||
|
localpart TEXT NOT NULL,
|
||||||
|
|
||||||
|
PRIMARY KEY(namespace, iss, sub)
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
const selectLocalpartForSSOSQL = "" +
|
||||||
|
"SELECT localpart FROM account_sso WHERE namespace = $1 AND iss = $2 AND sub = $3"
|
||||||
|
|
||||||
|
const insertSSOSQL = "" +
|
||||||
|
"INSERT INTO account_sso (namespace, iss, sub, localpart) VALUES ($1, $2, $3, $4)"
|
||||||
|
|
||||||
|
const deleteSSOSQL = "" +
|
||||||
|
"DELETE FROM account_sso WHERE namespace = $1 AND iss = $2 AND sub = $3"
|
||||||
|
|
||||||
|
type ssoStatements struct {
|
||||||
|
selectLocalpartForSSOStmt *sql.Stmt
|
||||||
|
insertSSOStmt *sql.Stmt
|
||||||
|
deleteSSOStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSQLiteSSOTable(db *sql.DB) (tables.SSOTable, error) {
|
||||||
|
s := &ssoStatements{}
|
||||||
|
_, err := db.Exec(ssoSchema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, sqlutil.StatementList{
|
||||||
|
{&s.selectLocalpartForSSOStmt, selectLocalpartForSSOSQL},
|
||||||
|
{&s.insertSSOStmt, insertSSOSQL},
|
||||||
|
{&s.deleteSSOStmt, deleteSSOSQL},
|
||||||
|
}.Prepare(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ssoStatements) SelectLocalpartForSSO(
|
||||||
|
ctx context.Context, txn *sql.Tx, namespace, iss, sub string,
|
||||||
|
) (localpart string, err error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForSSOStmt)
|
||||||
|
err = stmt.QueryRowContext(ctx, namespace, iss, sub).Scan(&localpart)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ssoStatements) InsertSSO(
|
||||||
|
ctx context.Context, txn *sql.Tx, namespace, iss, sub, localpart string,
|
||||||
|
) (err error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.insertSSOStmt)
|
||||||
|
_, err = stmt.ExecContext(ctx, namespace, iss, sub, localpart)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ssoStatements) DeleteSSO(
|
||||||
|
ctx context.Context, txn *sql.Tx, namespace, iss, sub string) (err error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.deleteSSOStmt)
|
||||||
|
_, err = stmt.ExecContext(ctx, namespace, iss, sub)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
@ -113,6 +113,12 @@ type NotificationTable interface {
|
||||||
SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error)
|
SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type SSOTable interface {
|
||||||
|
SelectLocalpartForSSO(ctx context.Context, txn *sql.Tx, namespace, iss, sub string) (string, error)
|
||||||
|
InsertSSO(ctx context.Context, txn *sql.Tx, namespace, iss, sub, localpart string) error
|
||||||
|
DeleteSSO(ctx context.Context, txn *sql.Tx, namespace, iss, sub string) error
|
||||||
|
}
|
||||||
|
|
||||||
type StatsTable interface {
|
type StatsTable interface {
|
||||||
UserStatistics(ctx context.Context, txn *sql.Tx) (*types.UserStatistics, *types.DatabaseEngine, error)
|
UserStatistics(ctx context.Context, txn *sql.Tx) (*types.UserStatistics, *types.DatabaseEngine, error)
|
||||||
UpdateUserDailyVisits(ctx context.Context, txn *sql.Tx, startTime, lastUpdate time.Time) error
|
UpdateUserDailyVisits(ctx context.Context, txn *sql.Tx, startTime, lastUpdate time.Time) error
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue