Add SSO UserAPI endpoints.

This is mostly copied from the ThirdPID, but with a primary key that
matches OpenID Connect nomenclature. There's a namspace to ensure
other SSO solutions can be supported, but there's only one namespace
defined for now.
This commit is contained in:
Tommie Gannert 2022-05-23 17:38:30 +02:00
parent c9ad7206c8
commit c3f7945284
12 changed files with 492 additions and 0 deletions

View file

@ -77,6 +77,7 @@ type SyncUserAPI interface {
type ClientUserAPI interface {
QueryAcccessTokenAPI
LoginTokenInternalAPI
SSOAPI
UserLoginAPI
QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error

53
userapi/api/api_sso.go Normal file
View file

@ -0,0 +1,53 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"context"
)
type SSOAPI interface {
QueryLocalpartForSSO(ctx context.Context, req *QueryLocalpartForSSORequest, res *QueryLocalpartForSSOResponse) error
PerformForgetSSO(ctx context.Context, req *PerformForgetSSORequest, res *struct{}) error
PerformSaveSSOAssociation(ctx context.Context, req *PerformSaveSSOAssociationRequest, res *struct{}) error
}
type QueryLocalpartForSSORequest struct {
Namespace SSOIssuerNamespace
Issuer, Subject string
}
type QueryLocalpartForSSOResponse struct {
Localpart string
}
type PerformForgetSSORequest QueryLocalpartForSSORequest
type PerformSaveSSOAssociationRequest struct {
Namespace SSOIssuerNamespace
Issuer, Subject string
Localpart string
}
// An SSOIssuerNamespace defines the interpretation of an issuer.
type SSOIssuerNamespace string
const (
UnknownNamespace SSOIssuerNamespace = ""
// OIDCNamespace indicates the issuer is a full URL, as defined in
// https://openid.net/specs/openid-connect-core-1_0.html#Terminology.
OIDCNamespace SSOIssuerNamespace = "oidc"
)

View file

@ -0,0 +1,39 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"context"
"github.com/matrix-org/util"
)
func (t *UserInternalAPITrace) QueryLocalpartForSSO(ctx context.Context, req *QueryLocalpartForSSORequest, res *QueryLocalpartForSSOResponse) error {
err := t.Impl.QueryLocalpartForSSO(ctx, req, res)
util.GetLogger(ctx).Infof("QueryLocalpartForSSO req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) PerformForgetSSO(ctx context.Context, req *PerformForgetSSORequest, res *struct{}) error {
err := t.Impl.PerformForgetSSO(ctx, req, res)
util.GetLogger(ctx).Infof("PerformForgetSSO req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) PerformSaveSSOAssociation(ctx context.Context, req *PerformSaveSSOAssociationRequest, res *struct{}) error {
err := t.Impl.PerformSaveSSOAssociation(ctx, req, res)
util.GetLogger(ctx).Infof("PerformSaveSSOAssociation req=%+v res=%+v", js(req), js(res))
return err
}

View file

@ -0,0 +1,50 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"fmt"
"github.com/matrix-org/dendrite/userapi/api"
)
func (a *UserInternalAPI) QueryLocalpartForSSO(ctx context.Context, req *api.QueryLocalpartForSSORequest, res *api.QueryLocalpartForSSOResponse) error {
var err error
res.Localpart, err = a.DB.GetLocalpartForSSO(ctx, string(req.Namespace), req.Issuer, req.Subject)
return err
}
func (a *UserInternalAPI) PerformForgetSSO(ctx context.Context, req *api.PerformForgetSSORequest, res *struct{}) error {
return a.DB.RemoveSSOAssociation(ctx, string(req.Namespace), req.Issuer, req.Subject)
}
func (a *UserInternalAPI) PerformSaveSSOAssociation(ctx context.Context, req *api.PerformSaveSSOAssociationRequest, res *struct{}) error {
ns, err := validateSSOIssuerNamespace(req.Namespace)
if err != nil {
return err
}
return a.DB.SaveSSOAssociation(ctx, ns, req.Issuer, req.Subject, req.Localpart)
}
func validateSSOIssuerNamespace(ns api.SSOIssuerNamespace) (string, error) {
switch ns {
case api.OIDCNamespace:
return string(ns), nil
default:
return "", fmt.Errorf("invalid SSO issuer namespace: %s", ns)
}
}

View file

@ -0,0 +1,53 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package inthttp
import (
"context"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/opentracing/opentracing-go"
)
const (
PerformForgetSSOPath = "/userapi/performForgetSSO"
PerformSaveSSOAssociationPath = "/userapi/performSaveSSOAssociation"
QueryLocalpartForSSOPath = "/userapi/queryLocalpartForSSO"
)
func (h *httpUserInternalAPI) QueryLocalpartForSSO(ctx context.Context, req *api.QueryLocalpartForSSORequest, res *api.QueryLocalpartForSSOResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, QueryLocalpartForSSOPath)
defer span.Finish()
apiURL := h.apiURL + QueryLocalpartForSSOPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) PerformForgetSSO(ctx context.Context, req *api.PerformForgetSSORequest, res *struct{}) error {
span, ctx := opentracing.StartSpanFromContext(ctx, PerformForgetSSOPath)
defer span.Finish()
apiURL := h.apiURL + PerformForgetSSOPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) PerformSaveSSOAssociation(ctx context.Context, req *api.PerformSaveSSOAssociationRequest, res *struct{}) error {
span, ctx := opentracing.StartSpanFromContext(ctx, PerformSaveSSOAssociationPath)
defer span.Finish()
apiURL := h.apiURL + PerformSaveSSOAssociationPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}

View file

@ -28,6 +28,7 @@ import (
// nolint: gocyclo
func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
addRoutesLoginToken(internalAPIMux, s)
addRoutesSSO(internalAPIMux, s)
internalAPIMux.Handle(PerformAccountCreationPath,
httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse {

View file

@ -0,0 +1,66 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package inthttp
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
)
// addRoutesSSO adds routes for all SSO API calls.
func addRoutesSSO(internalAPIMux *mux.Router, s api.UserInternalAPI) {
internalAPIMux.Handle(QueryLocalpartForSSOPath,
httputil.MakeInternalAPI("queryLocalpartForSSO", func(req *http.Request) util.JSONResponse {
request := api.QueryLocalpartForSSORequest{}
response := api.QueryLocalpartForSSOResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryLocalpartForSSO(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformForgetSSOPath,
httputil.MakeInternalAPI("performForgetSSO", func(req *http.Request) util.JSONResponse {
request := api.PerformForgetSSORequest{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformForgetSSO(req.Context(), &request, &struct{}{}); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}}
}),
)
internalAPIMux.Handle(PerformSaveSSOAssociationPath,
httputil.MakeInternalAPI("performSaveSSOAssociation", func(req *http.Request) util.JSONResponse {
request := api.PerformSaveSSOAssociationRequest{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformSaveSSOAssociation(req.Context(), &request, &struct{}{}); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}}
}),
)
}

View file

@ -109,6 +109,12 @@ type Pusher interface {
RemovePushers(ctx context.Context, appid, pushkey string) error
}
type SSO interface {
SaveSSOAssociation(ctx context.Context, namespace, iss, sub, localpart string) error
RemoveSSOAssociation(ctx context.Context, namespace, iss, sub string) error
GetLocalpartForSSO(ctx context.Context, namespace, iss, sub string) (string, error)
}
type ThreePID interface {
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
@ -136,6 +142,7 @@ type Database interface {
OpenID
Profile
Pusher
SSO
Statistics
ThreePID
}

View file

@ -0,0 +1,93 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
const ssoSchema = `
-- Stores data about SSO associations.
CREATE TABLE IF NOT EXISTS account_sso (
-- The "iss" namespace. Must be "oidc".
namespace TEXT NOT NULL,
-- The issuer; for "oidc", a URL.
iss TEXT NOT NULL,
-- The subject (user ID).
sub TEXT NOT NULL,
-- The localpart of the Matrix user ID associated to this 3PID
localpart TEXT NOT NULL,
PRIMARY KEY(namespace, iss, sub)
);
`
const selectLocalpartForSSOSQL = "" +
"SELECT localpart FROM account_sso WHERE namespace = $1 AND iss = $2 AND sub = $3"
const insertSSOSQL = "" +
"INSERT INTO account_sso (namespace, iss, sub, localpart) VALUES ($1, $2, $3, $4)"
const deleteSSOSQL = "" +
"DELETE FROM account_sso WHERE namespace = $1 AND iss = $2 AND sub = $3"
type ssoStatements struct {
selectLocalpartForSSOStmt *sql.Stmt
insertSSOStmt *sql.Stmt
deleteSSOStmt *sql.Stmt
}
func NewPostgresSSOTable(db *sql.DB) (tables.SSOTable, error) {
s := &ssoStatements{}
_, err := db.Exec(ssoSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.selectLocalpartForSSOStmt, selectLocalpartForSSOSQL},
{&s.insertSSOStmt, insertSSOSQL},
{&s.deleteSSOStmt, deleteSSOSQL},
}.Prepare(db)
}
func (s *ssoStatements) SelectLocalpartForSSO(
ctx context.Context, txn *sql.Tx, namespace, iss, sub string,
) (localpart string, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForSSOStmt)
err = stmt.QueryRowContext(ctx, namespace, iss, sub).Scan(&localpart)
if err == sql.ErrNoRows {
return "", nil
}
return
}
func (s *ssoStatements) InsertSSO(
ctx context.Context, txn *sql.Tx, namespace, iss, sub, localpart string,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertSSOStmt)
_, err = stmt.ExecContext(ctx, namespace, iss, sub, localpart)
return
}
func (s *ssoStatements) DeleteSSO(
ctx context.Context, txn *sql.Tx, namespace, iss, sub string) (err error) {
stmt := sqlutil.TxStmt(txn, s.deleteSSOStmt)
_, err = stmt.ExecContext(ctx, namespace, iss, sub)
return
}

View file

@ -52,6 +52,7 @@ type Database struct {
LoginTokens tables.LoginTokenTable
Notifications tables.NotificationTable
Pushers tables.PusherTable
SSOs tables.SSOTable
Stats tables.StatsTable
LoginTokenLifetime time.Duration
ServerName gomatrixserverlib.ServerName
@ -225,6 +226,35 @@ func (d *Database) hashPassword(plaintext string) (hash string, err error) {
return string(hashBytes), err
}
var ErrSSOInUse = errors.New("this SSO account is already in use")
func (d *Database) SaveSSOAssociation(ctx context.Context, namespace, iss, sub, localpart string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
user, err := d.SSOs.SelectLocalpartForSSO(
ctx, txn, namespace, iss, sub,
)
if err != nil {
return err
}
if len(user) > 0 {
return Err3PIDInUse
}
return d.SSOs.InsertSSO(ctx, txn, namespace, iss, sub, localpart)
})
}
func (d *Database) RemoveSSOAssociation(ctx context.Context, namespace, iss, sub string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.SSOs.DeleteSSO(ctx, txn, namespace, iss, sub)
})
}
func (d *Database) GetLocalpartForSSO(ctx context.Context, namespace, iss, sub string) (string, error) {
return d.SSOs.SelectLocalpartForSSO(ctx, nil, namespace, iss, sub)
}
// Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("this third-party identifier is already in use")

View file

@ -0,0 +1,93 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
const ssoSchema = `
-- Stores data about SSO associations.
CREATE TABLE IF NOT EXISTS account_sso (
-- The "iss" namespace. Must be "oidc".
namespace TEXT NOT NULL,
-- The issuer; for "oidc", a URL.
iss TEXT NOT NULL,
-- The subject (user ID).
sub TEXT NOT NULL,
-- The localpart of the Matrix user ID associated to this 3PID
localpart TEXT NOT NULL,
PRIMARY KEY(namespace, iss, sub)
);
`
const selectLocalpartForSSOSQL = "" +
"SELECT localpart FROM account_sso WHERE namespace = $1 AND iss = $2 AND sub = $3"
const insertSSOSQL = "" +
"INSERT INTO account_sso (namespace, iss, sub, localpart) VALUES ($1, $2, $3, $4)"
const deleteSSOSQL = "" +
"DELETE FROM account_sso WHERE namespace = $1 AND iss = $2 AND sub = $3"
type ssoStatements struct {
selectLocalpartForSSOStmt *sql.Stmt
insertSSOStmt *sql.Stmt
deleteSSOStmt *sql.Stmt
}
func NewSQLiteSSOTable(db *sql.DB) (tables.SSOTable, error) {
s := &ssoStatements{}
_, err := db.Exec(ssoSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.selectLocalpartForSSOStmt, selectLocalpartForSSOSQL},
{&s.insertSSOStmt, insertSSOSQL},
{&s.deleteSSOStmt, deleteSSOSQL},
}.Prepare(db)
}
func (s *ssoStatements) SelectLocalpartForSSO(
ctx context.Context, txn *sql.Tx, namespace, iss, sub string,
) (localpart string, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForSSOStmt)
err = stmt.QueryRowContext(ctx, namespace, iss, sub).Scan(&localpart)
if err == sql.ErrNoRows {
return "", nil
}
return
}
func (s *ssoStatements) InsertSSO(
ctx context.Context, txn *sql.Tx, namespace, iss, sub, localpart string,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertSSOStmt)
_, err = stmt.ExecContext(ctx, namespace, iss, sub, localpart)
return
}
func (s *ssoStatements) DeleteSSO(
ctx context.Context, txn *sql.Tx, namespace, iss, sub string) (err error) {
stmt := sqlutil.TxStmt(txn, s.deleteSSOStmt)
_, err = stmt.ExecContext(ctx, namespace, iss, sub)
return
}

View file

@ -113,6 +113,12 @@ type NotificationTable interface {
SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error)
}
type SSOTable interface {
SelectLocalpartForSSO(ctx context.Context, txn *sql.Tx, namespace, iss, sub string) (string, error)
InsertSSO(ctx context.Context, txn *sql.Tx, namespace, iss, sub, localpart string) error
DeleteSSO(ctx context.Context, txn *sql.Tx, namespace, iss, sub string) error
}
type StatsTable interface {
UserStatistics(ctx context.Context, txn *sql.Tx) (*types.UserStatistics, *types.DatabaseEngine, error)
UpdateUserDailyVisits(ctx context.Context, txn *sql.Tx, startTime, lastUpdate time.Time) error