Add SSO UserAPI endpoints.

This is mostly copied from the ThirdPID, but with a primary key that
matches OpenID Connect nomenclature. There's a namspace to ensure
other SSO solutions can be supported, but there's only one namespace
defined for now.
This commit is contained in:
Tommie Gannert 2022-05-23 17:38:30 +02:00
parent c9ad7206c8
commit c3f7945284
12 changed files with 492 additions and 0 deletions

View file

@ -77,6 +77,7 @@ type SyncUserAPI interface {
type ClientUserAPI interface { type ClientUserAPI interface {
QueryAcccessTokenAPI QueryAcccessTokenAPI
LoginTokenInternalAPI LoginTokenInternalAPI
SSOAPI
UserLoginAPI UserLoginAPI
QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error

53
userapi/api/api_sso.go Normal file
View file

@ -0,0 +1,53 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"context"
)
type SSOAPI interface {
QueryLocalpartForSSO(ctx context.Context, req *QueryLocalpartForSSORequest, res *QueryLocalpartForSSOResponse) error
PerformForgetSSO(ctx context.Context, req *PerformForgetSSORequest, res *struct{}) error
PerformSaveSSOAssociation(ctx context.Context, req *PerformSaveSSOAssociationRequest, res *struct{}) error
}
type QueryLocalpartForSSORequest struct {
Namespace SSOIssuerNamespace
Issuer, Subject string
}
type QueryLocalpartForSSOResponse struct {
Localpart string
}
type PerformForgetSSORequest QueryLocalpartForSSORequest
type PerformSaveSSOAssociationRequest struct {
Namespace SSOIssuerNamespace
Issuer, Subject string
Localpart string
}
// An SSOIssuerNamespace defines the interpretation of an issuer.
type SSOIssuerNamespace string
const (
UnknownNamespace SSOIssuerNamespace = ""
// OIDCNamespace indicates the issuer is a full URL, as defined in
// https://openid.net/specs/openid-connect-core-1_0.html#Terminology.
OIDCNamespace SSOIssuerNamespace = "oidc"
)

View file

@ -0,0 +1,39 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"context"
"github.com/matrix-org/util"
)
func (t *UserInternalAPITrace) QueryLocalpartForSSO(ctx context.Context, req *QueryLocalpartForSSORequest, res *QueryLocalpartForSSOResponse) error {
err := t.Impl.QueryLocalpartForSSO(ctx, req, res)
util.GetLogger(ctx).Infof("QueryLocalpartForSSO req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) PerformForgetSSO(ctx context.Context, req *PerformForgetSSORequest, res *struct{}) error {
err := t.Impl.PerformForgetSSO(ctx, req, res)
util.GetLogger(ctx).Infof("PerformForgetSSO req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) PerformSaveSSOAssociation(ctx context.Context, req *PerformSaveSSOAssociationRequest, res *struct{}) error {
err := t.Impl.PerformSaveSSOAssociation(ctx, req, res)
util.GetLogger(ctx).Infof("PerformSaveSSOAssociation req=%+v res=%+v", js(req), js(res))
return err
}

View file

@ -0,0 +1,50 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"fmt"
"github.com/matrix-org/dendrite/userapi/api"
)
func (a *UserInternalAPI) QueryLocalpartForSSO(ctx context.Context, req *api.QueryLocalpartForSSORequest, res *api.QueryLocalpartForSSOResponse) error {
var err error
res.Localpart, err = a.DB.GetLocalpartForSSO(ctx, string(req.Namespace), req.Issuer, req.Subject)
return err
}
func (a *UserInternalAPI) PerformForgetSSO(ctx context.Context, req *api.PerformForgetSSORequest, res *struct{}) error {
return a.DB.RemoveSSOAssociation(ctx, string(req.Namespace), req.Issuer, req.Subject)
}
func (a *UserInternalAPI) PerformSaveSSOAssociation(ctx context.Context, req *api.PerformSaveSSOAssociationRequest, res *struct{}) error {
ns, err := validateSSOIssuerNamespace(req.Namespace)
if err != nil {
return err
}
return a.DB.SaveSSOAssociation(ctx, ns, req.Issuer, req.Subject, req.Localpart)
}
func validateSSOIssuerNamespace(ns api.SSOIssuerNamespace) (string, error) {
switch ns {
case api.OIDCNamespace:
return string(ns), nil
default:
return "", fmt.Errorf("invalid SSO issuer namespace: %s", ns)
}
}

View file

@ -0,0 +1,53 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package inthttp
import (
"context"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/opentracing/opentracing-go"
)
const (
PerformForgetSSOPath = "/userapi/performForgetSSO"
PerformSaveSSOAssociationPath = "/userapi/performSaveSSOAssociation"
QueryLocalpartForSSOPath = "/userapi/queryLocalpartForSSO"
)
func (h *httpUserInternalAPI) QueryLocalpartForSSO(ctx context.Context, req *api.QueryLocalpartForSSORequest, res *api.QueryLocalpartForSSOResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, QueryLocalpartForSSOPath)
defer span.Finish()
apiURL := h.apiURL + QueryLocalpartForSSOPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) PerformForgetSSO(ctx context.Context, req *api.PerformForgetSSORequest, res *struct{}) error {
span, ctx := opentracing.StartSpanFromContext(ctx, PerformForgetSSOPath)
defer span.Finish()
apiURL := h.apiURL + PerformForgetSSOPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) PerformSaveSSOAssociation(ctx context.Context, req *api.PerformSaveSSOAssociationRequest, res *struct{}) error {
span, ctx := opentracing.StartSpanFromContext(ctx, PerformSaveSSOAssociationPath)
defer span.Finish()
apiURL := h.apiURL + PerformSaveSSOAssociationPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}

View file

@ -28,6 +28,7 @@ import (
// nolint: gocyclo // nolint: gocyclo
func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
addRoutesLoginToken(internalAPIMux, s) addRoutesLoginToken(internalAPIMux, s)
addRoutesSSO(internalAPIMux, s)
internalAPIMux.Handle(PerformAccountCreationPath, internalAPIMux.Handle(PerformAccountCreationPath,
httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse { httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse {

View file

@ -0,0 +1,66 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package inthttp
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
)
// addRoutesSSO adds routes for all SSO API calls.
func addRoutesSSO(internalAPIMux *mux.Router, s api.UserInternalAPI) {
internalAPIMux.Handle(QueryLocalpartForSSOPath,
httputil.MakeInternalAPI("queryLocalpartForSSO", func(req *http.Request) util.JSONResponse {
request := api.QueryLocalpartForSSORequest{}
response := api.QueryLocalpartForSSOResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryLocalpartForSSO(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformForgetSSOPath,
httputil.MakeInternalAPI("performForgetSSO", func(req *http.Request) util.JSONResponse {
request := api.PerformForgetSSORequest{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformForgetSSO(req.Context(), &request, &struct{}{}); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}}
}),
)
internalAPIMux.Handle(PerformSaveSSOAssociationPath,
httputil.MakeInternalAPI("performSaveSSOAssociation", func(req *http.Request) util.JSONResponse {
request := api.PerformSaveSSOAssociationRequest{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformSaveSSOAssociation(req.Context(), &request, &struct{}{}); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}}
}),
)
}

View file

@ -109,6 +109,12 @@ type Pusher interface {
RemovePushers(ctx context.Context, appid, pushkey string) error RemovePushers(ctx context.Context, appid, pushkey string) error
} }
type SSO interface {
SaveSSOAssociation(ctx context.Context, namespace, iss, sub, localpart string) error
RemoveSSOAssociation(ctx context.Context, namespace, iss, sub string) error
GetLocalpartForSSO(ctx context.Context, namespace, iss, sub string) (string, error)
}
type ThreePID interface { type ThreePID interface {
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error) SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
@ -136,6 +142,7 @@ type Database interface {
OpenID OpenID
Profile Profile
Pusher Pusher
SSO
Statistics Statistics
ThreePID ThreePID
} }

View file

@ -0,0 +1,93 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
const ssoSchema = `
-- Stores data about SSO associations.
CREATE TABLE IF NOT EXISTS account_sso (
-- The "iss" namespace. Must be "oidc".
namespace TEXT NOT NULL,
-- The issuer; for "oidc", a URL.
iss TEXT NOT NULL,
-- The subject (user ID).
sub TEXT NOT NULL,
-- The localpart of the Matrix user ID associated to this 3PID
localpart TEXT NOT NULL,
PRIMARY KEY(namespace, iss, sub)
);
`
const selectLocalpartForSSOSQL = "" +
"SELECT localpart FROM account_sso WHERE namespace = $1 AND iss = $2 AND sub = $3"
const insertSSOSQL = "" +
"INSERT INTO account_sso (namespace, iss, sub, localpart) VALUES ($1, $2, $3, $4)"
const deleteSSOSQL = "" +
"DELETE FROM account_sso WHERE namespace = $1 AND iss = $2 AND sub = $3"
type ssoStatements struct {
selectLocalpartForSSOStmt *sql.Stmt
insertSSOStmt *sql.Stmt
deleteSSOStmt *sql.Stmt
}
func NewPostgresSSOTable(db *sql.DB) (tables.SSOTable, error) {
s := &ssoStatements{}
_, err := db.Exec(ssoSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.selectLocalpartForSSOStmt, selectLocalpartForSSOSQL},
{&s.insertSSOStmt, insertSSOSQL},
{&s.deleteSSOStmt, deleteSSOSQL},
}.Prepare(db)
}
func (s *ssoStatements) SelectLocalpartForSSO(
ctx context.Context, txn *sql.Tx, namespace, iss, sub string,
) (localpart string, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForSSOStmt)
err = stmt.QueryRowContext(ctx, namespace, iss, sub).Scan(&localpart)
if err == sql.ErrNoRows {
return "", nil
}
return
}
func (s *ssoStatements) InsertSSO(
ctx context.Context, txn *sql.Tx, namespace, iss, sub, localpart string,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertSSOStmt)
_, err = stmt.ExecContext(ctx, namespace, iss, sub, localpart)
return
}
func (s *ssoStatements) DeleteSSO(
ctx context.Context, txn *sql.Tx, namespace, iss, sub string) (err error) {
stmt := sqlutil.TxStmt(txn, s.deleteSSOStmt)
_, err = stmt.ExecContext(ctx, namespace, iss, sub)
return
}

View file

@ -52,6 +52,7 @@ type Database struct {
LoginTokens tables.LoginTokenTable LoginTokens tables.LoginTokenTable
Notifications tables.NotificationTable Notifications tables.NotificationTable
Pushers tables.PusherTable Pushers tables.PusherTable
SSOs tables.SSOTable
Stats tables.StatsTable Stats tables.StatsTable
LoginTokenLifetime time.Duration LoginTokenLifetime time.Duration
ServerName gomatrixserverlib.ServerName ServerName gomatrixserverlib.ServerName
@ -225,6 +226,35 @@ func (d *Database) hashPassword(plaintext string) (hash string, err error) {
return string(hashBytes), err return string(hashBytes), err
} }
var ErrSSOInUse = errors.New("this SSO account is already in use")
func (d *Database) SaveSSOAssociation(ctx context.Context, namespace, iss, sub, localpart string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
user, err := d.SSOs.SelectLocalpartForSSO(
ctx, txn, namespace, iss, sub,
)
if err != nil {
return err
}
if len(user) > 0 {
return Err3PIDInUse
}
return d.SSOs.InsertSSO(ctx, txn, namespace, iss, sub, localpart)
})
}
func (d *Database) RemoveSSOAssociation(ctx context.Context, namespace, iss, sub string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.SSOs.DeleteSSO(ctx, txn, namespace, iss, sub)
})
}
func (d *Database) GetLocalpartForSSO(ctx context.Context, namespace, iss, sub string) (string, error) {
return d.SSOs.SelectLocalpartForSSO(ctx, nil, namespace, iss, sub)
}
// Err3PIDInUse is the error returned when trying to save an association involving // Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user. // a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("this third-party identifier is already in use") var Err3PIDInUse = errors.New("this third-party identifier is already in use")

View file

@ -0,0 +1,93 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
const ssoSchema = `
-- Stores data about SSO associations.
CREATE TABLE IF NOT EXISTS account_sso (
-- The "iss" namespace. Must be "oidc".
namespace TEXT NOT NULL,
-- The issuer; for "oidc", a URL.
iss TEXT NOT NULL,
-- The subject (user ID).
sub TEXT NOT NULL,
-- The localpart of the Matrix user ID associated to this 3PID
localpart TEXT NOT NULL,
PRIMARY KEY(namespace, iss, sub)
);
`
const selectLocalpartForSSOSQL = "" +
"SELECT localpart FROM account_sso WHERE namespace = $1 AND iss = $2 AND sub = $3"
const insertSSOSQL = "" +
"INSERT INTO account_sso (namespace, iss, sub, localpart) VALUES ($1, $2, $3, $4)"
const deleteSSOSQL = "" +
"DELETE FROM account_sso WHERE namespace = $1 AND iss = $2 AND sub = $3"
type ssoStatements struct {
selectLocalpartForSSOStmt *sql.Stmt
insertSSOStmt *sql.Stmt
deleteSSOStmt *sql.Stmt
}
func NewSQLiteSSOTable(db *sql.DB) (tables.SSOTable, error) {
s := &ssoStatements{}
_, err := db.Exec(ssoSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.selectLocalpartForSSOStmt, selectLocalpartForSSOSQL},
{&s.insertSSOStmt, insertSSOSQL},
{&s.deleteSSOStmt, deleteSSOSQL},
}.Prepare(db)
}
func (s *ssoStatements) SelectLocalpartForSSO(
ctx context.Context, txn *sql.Tx, namespace, iss, sub string,
) (localpart string, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForSSOStmt)
err = stmt.QueryRowContext(ctx, namespace, iss, sub).Scan(&localpart)
if err == sql.ErrNoRows {
return "", nil
}
return
}
func (s *ssoStatements) InsertSSO(
ctx context.Context, txn *sql.Tx, namespace, iss, sub, localpart string,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertSSOStmt)
_, err = stmt.ExecContext(ctx, namespace, iss, sub, localpart)
return
}
func (s *ssoStatements) DeleteSSO(
ctx context.Context, txn *sql.Tx, namespace, iss, sub string) (err error) {
stmt := sqlutil.TxStmt(txn, s.deleteSSOStmt)
_, err = stmt.ExecContext(ctx, namespace, iss, sub)
return
}

View file

@ -113,6 +113,12 @@ type NotificationTable interface {
SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error)
} }
type SSOTable interface {
SelectLocalpartForSSO(ctx context.Context, txn *sql.Tx, namespace, iss, sub string) (string, error)
InsertSSO(ctx context.Context, txn *sql.Tx, namespace, iss, sub, localpart string) error
DeleteSSO(ctx context.Context, txn *sql.Tx, namespace, iss, sub string) error
}
type StatsTable interface { type StatsTable interface {
UserStatistics(ctx context.Context, txn *sql.Tx) (*types.UserStatistics, *types.DatabaseEngine, error) UserStatistics(ctx context.Context, txn *sql.Tx) (*types.UserStatistics, *types.DatabaseEngine, error)
UpdateUserDailyVisits(ctx context.Context, txn *sql.Tx, startTime, lastUpdate time.Time) error UpdateUserDailyVisits(ctx context.Context, txn *sql.Tx, startTime, lastUpdate time.Time) error