mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-29 09:43:10 -06:00
add digital signature challenge response authentication mechanism
using ed25519 keypairs Signed-off-by: Fabian Deifuß <deifussfabian@icloud.com>
This commit is contained in:
parent
a47b12dc7d
commit
f887bcea6f
|
|
@ -7,6 +7,7 @@ type LoginType string
|
|||
const (
|
||||
LoginTypePassword = "m.login.password"
|
||||
LoginTypeDummy = "m.login.dummy"
|
||||
LoginTypeChallengeResponse = "m.login.challenge_response"
|
||||
LoginTypeSharedSecret = "org.matrix.login.shared_secret"
|
||||
LoginTypeRecaptcha = "m.login.recaptcha"
|
||||
LoginTypeApplicationService = "m.login.application_service"
|
||||
|
|
|
|||
67
clientapi/auth/challenge_response.go
Normal file
67
clientapi/auth/challenge_response.go
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/util"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type GetAccountByChallengeResponse func(ctx context.Context, localpart, b64encodedSignature, challenge string) (*api.Account, error)
|
||||
|
||||
type ChallengeResponseRequest struct {
|
||||
Login
|
||||
Signature string `json:"signature"`
|
||||
}
|
||||
|
||||
// LoginTypeChallengeResponse using public key encryption
|
||||
type LoginTypeChallengeResponse struct {
|
||||
GetAccountByChallengeResponse GetAccountByChallengeResponse
|
||||
Config *config.ClientAPI
|
||||
}
|
||||
|
||||
func (t *LoginTypeChallengeResponse) Name() string {
|
||||
return authtypes.LoginTypeChallengeResponse
|
||||
}
|
||||
|
||||
func (t *LoginTypeChallengeResponse) Request() interface{} {
|
||||
return &ChallengeResponseRequest{}
|
||||
}
|
||||
|
||||
func (t *LoginTypeChallengeResponse) Login(ctx context.Context, req interface{}, challenge string) (*Login, *util.JSONResponse) {
|
||||
r := req.(*ChallengeResponseRequest)
|
||||
username := r.Username()
|
||||
if username == "" {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON("A username must be supplied."),
|
||||
}
|
||||
}
|
||||
localpart, err := userutil.ParseUsernameParam(username, &t.Config.Matrix.ServerName)
|
||||
if err != nil {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername(err.Error()),
|
||||
}
|
||||
}
|
||||
if r.Signature == "" {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidSignature("No signature provided"),
|
||||
}
|
||||
}
|
||||
_, err = t.GetAccountByChallengeResponse(ctx, localpart, r.Signature, challenge)
|
||||
if err != nil {
|
||||
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows
|
||||
// but that would leak the existence of the user.
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: jsonerror.Forbidden("The digital signature is incorrect or the account does not exist."),
|
||||
}
|
||||
}
|
||||
return &r.Login, nil
|
||||
}
|
||||
|
|
@ -47,7 +47,7 @@ func (t *LoginTypePassword) Request() interface{} {
|
|||
return &PasswordRequest{}
|
||||
}
|
||||
|
||||
func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) {
|
||||
func (t *LoginTypePassword) Login(ctx context.Context, req interface{}, challenge string) (*Login, *util.JSONResponse) {
|
||||
r := req.(*PasswordRequest)
|
||||
// Squash username to all lowercase letters
|
||||
username := strings.ToLower(r.Username())
|
||||
|
|
|
|||
|
|
@ -17,14 +17,15 @@ package auth
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"fmt"
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Type represents an auth type
|
||||
|
|
@ -43,7 +44,7 @@ type Type interface {
|
|||
// "If the homeserver decides that an attempt on a stage was unsuccessful, but the
|
||||
// client may make a second attempt, it returns the same HTTP status 401 response as above,
|
||||
// with the addition of the standard errcode and error fields describing the error."
|
||||
Login(ctx context.Context, req interface{}) (login *Login, errRes *util.JSONResponse)
|
||||
Login(ctx context.Context, req interface{}, challenge string) (login *Login, errRes *util.JSONResponse)
|
||||
// TODO: Extend to support Register() flow
|
||||
// Register(ctx context.Context, sessionID string, req interface{})
|
||||
}
|
||||
|
|
@ -108,14 +109,19 @@ type UserInteractive struct {
|
|||
// Map of login type to implementation
|
||||
Types map[string]Type
|
||||
// Map of session ID to completed login types, will need to be extended in future
|
||||
Sessions map[string][]string
|
||||
Sessions map[string][]string
|
||||
SessionParams map[string]map[string]string
|
||||
}
|
||||
|
||||
func NewUserInteractive(getAccByPass GetAccountByPassword, cfg *config.ClientAPI) *UserInteractive {
|
||||
func NewUserInteractive(getAccByPass GetAccountByPassword, getAccByChallengeResponse GetAccountByChallengeResponse, cfg *config.ClientAPI) *UserInteractive {
|
||||
typePassword := &LoginTypePassword{
|
||||
GetAccountByPassword: getAccByPass,
|
||||
Config: cfg,
|
||||
}
|
||||
typeChallengeResponse := &LoginTypeChallengeResponse{
|
||||
GetAccountByChallengeResponse: getAccByChallengeResponse,
|
||||
Config: cfg,
|
||||
}
|
||||
// TODO: Add SSO login
|
||||
return &UserInteractive{
|
||||
Completed: []string{},
|
||||
|
|
@ -123,11 +129,16 @@ func NewUserInteractive(getAccByPass GetAccountByPassword, cfg *config.ClientAPI
|
|||
{
|
||||
Stages: []string{typePassword.Name()},
|
||||
},
|
||||
{
|
||||
Stages: []string{typeChallengeResponse.Name()},
|
||||
},
|
||||
},
|
||||
Types: map[string]Type{
|
||||
typePassword.Name(): typePassword,
|
||||
typePassword.Name(): typePassword,
|
||||
typeChallengeResponse.Name(): typeChallengeResponse,
|
||||
},
|
||||
Sessions: make(map[string][]string),
|
||||
Sessions: make(map[string][]string),
|
||||
SessionParams: make(map[string]map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -148,6 +159,9 @@ func (u *UserInteractive) AddCompletedStage(sessionID, authType string) {
|
|||
|
||||
// Challenge returns an HTTP 401 with the supported flows for authenticating
|
||||
func (u *UserInteractive) Challenge(sessionID string) *util.JSONResponse {
|
||||
params := make(map[string]string)
|
||||
params["challenge"] = fmt.Sprintf("%d%s", time.Now().Unix(), sessionID)
|
||||
u.SessionParams[sessionID] = params
|
||||
return &util.JSONResponse{
|
||||
Code: 401,
|
||||
JSON: struct {
|
||||
|
|
@ -155,18 +169,18 @@ func (u *UserInteractive) Challenge(sessionID string) *util.JSONResponse {
|
|||
Flows []userInteractiveFlow `json:"flows"`
|
||||
Session string `json:"session"`
|
||||
// TODO: Return any additional `params`
|
||||
Params map[string]interface{} `json:"params"`
|
||||
Params map[string]string `json:"params"`
|
||||
}{
|
||||
u.Completed,
|
||||
u.Flows,
|
||||
sessionID,
|
||||
make(map[string]interface{}),
|
||||
params,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewSession returns a challenge with a new session ID and remembers the session ID
|
||||
func (u *UserInteractive) NewSession() *util.JSONResponse {
|
||||
func (u *UserInteractive) NewSession(authType string) *util.JSONResponse {
|
||||
sessionID, err := GenerateAccessToken()
|
||||
if err != nil {
|
||||
logrus.WithError(err).Error("failed to generate session ID")
|
||||
|
|
@ -207,15 +221,16 @@ func (u *UserInteractive) ResponseWithChallenge(sessionID string, response inter
|
|||
func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte, device *api.Device) (*Login, *util.JSONResponse) {
|
||||
// TODO: rate limit
|
||||
|
||||
// extract the type so we know which login type to use
|
||||
authType := gjson.GetBytes(bodyBytes, "auth.type").Str
|
||||
|
||||
// "A client should first make a request with no auth parameter. The homeserver returns an HTTP 401 response, with a JSON body"
|
||||
// https://matrix.org/docs/spec/client_server/r0.6.1#user-interactive-api-in-the-rest-api
|
||||
hasResponse := gjson.GetBytes(bodyBytes, "auth").Exists()
|
||||
if !hasResponse {
|
||||
return nil, u.NewSession()
|
||||
return nil, u.NewSession(authType)
|
||||
}
|
||||
|
||||
// extract the type so we know which login type to use
|
||||
authType := gjson.GetBytes(bodyBytes, "auth.type").Str
|
||||
loginType, ok := u.Types[authType]
|
||||
if !ok {
|
||||
return nil, &util.JSONResponse{
|
||||
|
|
@ -237,13 +252,13 @@ func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte, device *
|
|||
}
|
||||
|
||||
r := loginType.Request()
|
||||
if err := json.Unmarshal([]byte(gjson.GetBytes(bodyBytes, "auth").Raw), r); err != nil {
|
||||
if err := json.Unmarshal(bodyBytes, r); err != nil {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()),
|
||||
}
|
||||
}
|
||||
login, resErr := loginType.Login(ctx, r)
|
||||
login, resErr := loginType.Login(ctx, r, u.SessionParams[sessionID]["challenge"])
|
||||
if resErr == nil {
|
||||
u.AddCompletedStage(sessionID, authType)
|
||||
// TODO: Check if there's more stages to go and return an error
|
||||
|
|
|
|||
|
|
@ -32,13 +32,21 @@ func getAccountByPassword(ctx context.Context, localpart, plaintextPassword stri
|
|||
return acc, nil
|
||||
}
|
||||
|
||||
func getAccountByChallengeResponse(ctx context.Context, localpart, b64encodedSignature, challenge string) (*api.Account, error) {
|
||||
acc, ok := lookup[localpart+" "+b64encodedSignature+" "+challenge]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown user/pubkey")
|
||||
}
|
||||
return acc, nil
|
||||
}
|
||||
|
||||
func setup() *UserInteractive {
|
||||
cfg := &config.ClientAPI{
|
||||
Matrix: &config.Global{
|
||||
ServerName: serverName,
|
||||
},
|
||||
}
|
||||
return NewUserInteractive(getAccountByPassword, cfg)
|
||||
return NewUserInteractive(getAccountByPassword, getAccountByChallengeResponse, cfg)
|
||||
}
|
||||
|
||||
func TestUserInteractiveChallenge(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -15,7 +15,9 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/matrix-org/dendrite/clientapi/auth"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"unicode/utf8"
|
||||
|
|
@ -54,3 +56,19 @@ func UnmarshalJSONRequest(req *http.Request, iface interface{}) *util.JSONRespon
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetLoginType(req *http.Request) (string, *util.JSONResponse) {
|
||||
body, err := ioutil.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
resp := jsonerror.InternalServerError()
|
||||
return "", &resp
|
||||
}
|
||||
req.Body = ioutil.NopCloser(bytes.NewReader(body))
|
||||
r := &auth.Login{}
|
||||
jsonErr := UnmarshalJSONRequest(req, r)
|
||||
if jsonErr != nil {
|
||||
return "", jsonErr
|
||||
}
|
||||
req.Body = ioutil.NopCloser(bytes.NewReader(body))
|
||||
return r.Type, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ func UploadCrossSigningDeviceKeys(
|
|||
GetAccountByPassword: accountDB.GetAccountByPassword,
|
||||
Config: cfg,
|
||||
}
|
||||
if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil {
|
||||
if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest, ""); authErr != nil {
|
||||
return *authErr
|
||||
}
|
||||
AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
||||
|
|
|
|||
|
|
@ -16,6 +16,9 @@ package routing
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth"
|
||||
|
|
@ -44,42 +47,70 @@ type flow struct {
|
|||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
func passwordLogin() flows {
|
||||
func enabledLoginTypes(cfg *config.ClientAPI) flows {
|
||||
f := flows{}
|
||||
s := flow{
|
||||
Type: "m.login.password",
|
||||
for _, loginType := range cfg.LoginTypes {
|
||||
s := flow{
|
||||
Type: loginType,
|
||||
}
|
||||
f.Flows = append(f.Flows, s)
|
||||
}
|
||||
f.Flows = append(f.Flows, s)
|
||||
return f
|
||||
}
|
||||
|
||||
// Login implements GET and POST /login
|
||||
func Login(
|
||||
req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI,
|
||||
cfg *config.ClientAPI,
|
||||
cfg *config.ClientAPI, userInteractiveAuth *auth.UserInteractive,
|
||||
) util.JSONResponse {
|
||||
if req.Method == http.MethodGet {
|
||||
// TODO: support other forms of login other than password, depending on config options
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: passwordLogin(),
|
||||
JSON: enabledLoginTypes(cfg),
|
||||
}
|
||||
} else if req.Method == http.MethodPost {
|
||||
typePassword := auth.LoginTypePassword{
|
||||
GetAccountByPassword: accountDB.GetAccountByPassword,
|
||||
Config: cfg,
|
||||
loginType, err := httputil.GetLoginType(req)
|
||||
if err != nil {
|
||||
return *err
|
||||
}
|
||||
r := typePassword.Request()
|
||||
resErr := httputil.UnmarshalJSONRequest(req, r)
|
||||
if resErr != nil {
|
||||
return *resErr
|
||||
switch loginType {
|
||||
case authtypes.LoginTypePassword:
|
||||
typePassword := auth.LoginTypePassword{
|
||||
GetAccountByPassword: accountDB.GetAccountByPassword,
|
||||
Config: cfg,
|
||||
}
|
||||
r := typePassword.Request()
|
||||
resErr := httputil.UnmarshalJSONRequest(req, r)
|
||||
if resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
login, authErr := typePassword.Login(req.Context(), r, "")
|
||||
if authErr != nil {
|
||||
return *authErr
|
||||
}
|
||||
// make a device/access token
|
||||
return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent())
|
||||
case authtypes.LoginTypeChallengeResponse:
|
||||
defer req.Body.Close() // nolint:errcheck
|
||||
bodyBytes, err := ioutil.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON("The request body could not be read: " + err.Error()),
|
||||
}
|
||||
}
|
||||
login, resErr := userInteractiveAuth.Verify(req.Context(), bodyBytes, nil)
|
||||
if resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
// create a device/access token
|
||||
return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent())
|
||||
default:
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidParam(fmt.Sprintf("Unsupported login type: %s", loginType)),
|
||||
}
|
||||
}
|
||||
login, authErr := typePassword.Login(req.Context(), r)
|
||||
if authErr != nil {
|
||||
return *authErr
|
||||
}
|
||||
// make a device/access token
|
||||
return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent())
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusMethodNotAllowed,
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ func Password(
|
|||
GetAccountByPassword: accountDB.GetAccountByPassword,
|
||||
Config: cfg,
|
||||
}
|
||||
if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil {
|
||||
if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest, ""); authErr != nil {
|
||||
return *authErr
|
||||
}
|
||||
AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
||||
|
|
|
|||
|
|
@ -121,9 +121,10 @@ var (
|
|||
// previous parameters with the ones supplied. This mean you cannot "build up" request params.
|
||||
type registerRequest struct {
|
||||
// registration parameters
|
||||
Password string `json:"password"`
|
||||
Username string `json:"username"`
|
||||
Admin bool `json:"admin"`
|
||||
Password string `json:"password"`
|
||||
Username string `json:"username"`
|
||||
B64encodedPublicKey string `json:"b64PublicKey"`
|
||||
Admin bool `json:"admin"`
|
||||
// user-interactive auth params
|
||||
Auth authDict `json:"auth"`
|
||||
|
||||
|
|
@ -517,9 +518,10 @@ func Register(
|
|||
|
||||
logger := util.GetLogger(req.Context())
|
||||
logger.WithFields(log.Fields{
|
||||
"username": r.Username,
|
||||
"auth.type": r.Auth.Type,
|
||||
"session_id": r.Auth.Session,
|
||||
"username": r.Username,
|
||||
"b64PublicKey": r.B64encodedPublicKey,
|
||||
"auth.type": r.Auth.Type,
|
||||
"session_id": r.Auth.Session,
|
||||
}).Info("Processing registration request")
|
||||
|
||||
return handleRegistrationFlow(req, r, sessionID, cfg, userAPI, accessToken, accessTokenErr)
|
||||
|
|
@ -700,7 +702,7 @@ func handleApplicationServiceRegistration(
|
|||
// Don't need to worry about appending to registration stages as
|
||||
// application service registration is entirely separate.
|
||||
return completeRegistration(
|
||||
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(),
|
||||
req.Context(), userAPI, r.Username, "", "", appserviceID, req.RemoteAddr, req.UserAgent(),
|
||||
r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
|
||||
)
|
||||
}
|
||||
|
|
@ -719,7 +721,7 @@ func checkAndCompleteFlow(
|
|||
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
|
||||
// This flow was completed, registration can continue
|
||||
return completeRegistration(
|
||||
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(),
|
||||
req.Context(), userAPI, r.Username, r.Password, r.B64encodedPublicKey, "", req.RemoteAddr, req.UserAgent(),
|
||||
r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
|
||||
)
|
||||
}
|
||||
|
|
@ -739,13 +741,7 @@ func checkAndCompleteFlow(
|
|||
// registerRequest, as this function serves requests encoded as both
|
||||
// registerRequests and legacyRegisterRequests, which share some attributes but
|
||||
// not all
|
||||
func completeRegistration(
|
||||
ctx context.Context,
|
||||
userAPI userapi.UserInternalAPI,
|
||||
username, password, appserviceID, ipAddr, userAgent string,
|
||||
inhibitLogin eventutil.WeakBoolean,
|
||||
displayName, deviceID *string,
|
||||
) util.JSONResponse {
|
||||
func completeRegistration(ctx context.Context, userAPI userapi.UserInternalAPI, username, password, b64encodedPublicKey, appserviceID, ipAddr, userAgent string, inhibitLogin eventutil.WeakBoolean, displayName, deviceID *string) util.JSONResponse {
|
||||
if username == "" {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
|
|
@ -762,11 +758,12 @@ func completeRegistration(
|
|||
|
||||
var accRes userapi.PerformAccountCreationResponse
|
||||
err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{
|
||||
AppServiceID: appserviceID,
|
||||
Localpart: username,
|
||||
Password: password,
|
||||
AccountType: userapi.AccountTypeUser,
|
||||
OnConflict: userapi.ConflictAbort,
|
||||
AppServiceID: appserviceID,
|
||||
Localpart: username,
|
||||
Password: password,
|
||||
B64encodedPublicKey: b64encodedPublicKey,
|
||||
AccountType: userapi.AccountTypeUser,
|
||||
OnConflict: userapi.ConflictAbort,
|
||||
}, &accRes)
|
||||
if err != nil {
|
||||
if _, ok := err.(*userapi.ErrorConflict); ok { // user already exists
|
||||
|
|
@ -963,5 +960,5 @@ func handleSharedSecretRegistration(userAPI userapi.UserInternalAPI, sr *SharedS
|
|||
return *resErr
|
||||
}
|
||||
deviceID := "shared_secret_registration"
|
||||
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), false, &ssrr.User, &deviceID)
|
||||
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", "", req.RemoteAddr, req.UserAgent(), false, &ssrr.User, &deviceID)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ func Setup(
|
|||
mscCfg *config.MSCs,
|
||||
) {
|
||||
rateLimits := httputil.NewRateLimits(&cfg.RateLimiting)
|
||||
userInteractiveAuth := auth.NewUserInteractive(accountDB.GetAccountByPassword, cfg)
|
||||
userInteractiveAuth := auth.NewUserInteractive(accountDB.GetAccountByPassword, accountDB.GetAccountByChallengeResponse, cfg)
|
||||
|
||||
unstableFeatures := map[string]bool{
|
||||
"org.matrix.e2e_cross_signing": true,
|
||||
|
|
@ -498,7 +498,7 @@ func Setup(
|
|||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
}
|
||||
return Login(req, accountDB, userAPI, cfg)
|
||||
return Login(req, accountDB, userAPI, cfg, userInteractiveAuth)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,9 @@ package main
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
|
|
@ -39,6 +42,8 @@ Example:
|
|||
|
||||
# provide password by parameter
|
||||
%s --config dendrite.yaml -username alice -password foobarbaz
|
||||
# auto generate keypair
|
||||
%s --config dendrite.yaml -username alice -password foobarbaz -udk NuJ7J4BsaE8QZT1ULNTc3s8ZjLFmDPh91l1i0Urf/ls=
|
||||
# use password from file
|
||||
%s --config dendrite.yaml -username alice -passwordfile my.pass
|
||||
# ask user to provide password
|
||||
|
|
@ -52,17 +57,18 @@ Arguments:
|
|||
`
|
||||
|
||||
var (
|
||||
username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')")
|
||||
password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)")
|
||||
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
|
||||
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
|
||||
askPass = flag.Bool("ask-pass", false, "Ask for the password to use")
|
||||
username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')")
|
||||
password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)")
|
||||
createKeypair = flag.Bool("create-keypair", false, "Whether to create an Ed25519 keypair for the account to create (optional)")
|
||||
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
|
||||
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
|
||||
askPass = flag.Bool("ask-pass", false, "Ask for the password to use")
|
||||
)
|
||||
|
||||
func main() {
|
||||
name := os.Args[0]
|
||||
flag.Usage = func() {
|
||||
_, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name)
|
||||
_, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name, name)
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
cfg := setup.ParseFlags(true)
|
||||
|
|
@ -81,7 +87,32 @@ func main() {
|
|||
logrus.Fatalln("Failed to connect to the database:", err.Error())
|
||||
}
|
||||
|
||||
_, err = accountDB.CreateAccount(context.Background(), *username, pass, "")
|
||||
var pub64 string
|
||||
if *createKeypair {
|
||||
pub, priv, err2 := ed25519.GenerateKey(rand.Reader)
|
||||
pub64 = base64.StdEncoding.EncodeToString(priv.Public().(ed25519.PublicKey))
|
||||
if err2 != nil {
|
||||
logrus.Fatalln(err2)
|
||||
}
|
||||
err2 = os.WriteFile("private.key", priv, 0644)
|
||||
if err2 != nil {
|
||||
logrus.Fatalln(err2)
|
||||
}
|
||||
err2 = os.WriteFile("private.key.seed", priv.Seed(), 0644)
|
||||
if err2 != nil {
|
||||
logrus.Fatalln(err2)
|
||||
}
|
||||
err2 = os.WriteFile("./public.key", pub, 0644)
|
||||
if err2 != nil {
|
||||
logrus.Fatalln(err2)
|
||||
}
|
||||
err2 = os.WriteFile("./public.key.b64", []byte(pub64), 0644)
|
||||
if err2 != nil {
|
||||
logrus.Fatalln(err2)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = accountDB.CreateAccount(context.Background(), *username, pass, pub64, "")
|
||||
if err != nil {
|
||||
logrus.Fatalln("Failed to create the account:", err.Error())
|
||||
}
|
||||
|
|
|
|||
61
cmd/sign-challenge/main.go
Normal file
61
cmd/sign-challenge/main.go
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/sirupsen/logrus"
|
||||
"os"
|
||||
)
|
||||
|
||||
const usage = `Usage: %s
|
||||
|
||||
Sign a string using an Ed25519 private key
|
||||
|
||||
Arguments:
|
||||
|
||||
`
|
||||
|
||||
var (
|
||||
privateKeyFile = flag.String("private-key", "", "An Ed25519 private key seed used to sign the input")
|
||||
input = flag.String("input", "", "The input to sign")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(os.Stderr, usage, os.Args[0])
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if *privateKeyFile == "" || *input == "" {
|
||||
flag.Usage()
|
||||
return
|
||||
}
|
||||
|
||||
if *privateKeyFile != "" && *input != "" {
|
||||
seedBytes, err := os.ReadFile(*privateKeyFile)
|
||||
if err != nil {
|
||||
logrus.Fatalln(err)
|
||||
}
|
||||
priv := ed25519.NewKeyFromSeed(seedBytes)
|
||||
sig := ed25519.Sign(priv, []byte(*input))
|
||||
logrus.Infoln("Signature: " + base64.StdEncoding.EncodeToString(sig))
|
||||
}
|
||||
}
|
||||
|
|
@ -145,6 +145,11 @@ client_api:
|
|||
external_api:
|
||||
listen: http://[::]:8071
|
||||
|
||||
# Which of the authentication flows to support on this server
|
||||
login_types:
|
||||
- m.login.password
|
||||
- m.login.challenge_response
|
||||
|
||||
# Prevents new users from being able to register on this homeserver, except when
|
||||
# using the registration shared secret below.
|
||||
registration_disabled: false
|
||||
|
|
|
|||
|
|
@ -12,6 +12,10 @@ type ClientAPI struct {
|
|||
InternalAPI InternalAPIOptions `yaml:"internal_api"`
|
||||
ExternalAPI ExternalAPIOptions `yaml:"external_api"`
|
||||
|
||||
// What authentication mechanisms shall be supported
|
||||
// by this server
|
||||
LoginTypes []string `yaml:"login_types"`
|
||||
|
||||
// If set disables new users from registering (except via shared
|
||||
// secrets)
|
||||
RegistrationDisabled bool `yaml:"registration_disabled"`
|
||||
|
|
@ -45,6 +49,7 @@ func (c *ClientAPI) Defaults(generate bool) {
|
|||
c.InternalAPI.Listen = "http://localhost:7771"
|
||||
c.InternalAPI.Connect = "http://localhost:7771"
|
||||
c.ExternalAPI.Listen = "http://[::]:8071"
|
||||
c.LoginTypes = []string{"m.login.password"}
|
||||
c.RegistrationSharedSecret = ""
|
||||
c.RecaptchaPublicKey = ""
|
||||
c.RecaptchaPrivateKey = ""
|
||||
|
|
|
|||
|
|
@ -243,9 +243,10 @@ type PerformAccountCreationRequest struct {
|
|||
AccountType AccountType // Required: whether this is a guest or user account
|
||||
Localpart string // Required: The localpart for this account. Ignored if account type is guest.
|
||||
|
||||
AppServiceID string // optional: the application service ID (not user ID) creating this account, if any.
|
||||
Password string // optional: if missing then this account will be a passwordless account
|
||||
OnConflict Conflict
|
||||
AppServiceID string // optional: the application service ID (not user ID) creating this account, if any.
|
||||
Password string // optional: if missing then this account will be a passwordless account
|
||||
B64encodedPublicKey string // optional: if missing then this account will not allow digital signature login until a public key is added
|
||||
OnConflict Conflict
|
||||
}
|
||||
|
||||
// PerformAccountCreationResponse is the response for PerformAccountCreation
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
|||
res.Account = acc
|
||||
return nil
|
||||
}
|
||||
acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID)
|
||||
acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.B64encodedPublicKey, req.AppServiceID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
||||
switch req.OnConflict {
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ import (
|
|||
type Database interface {
|
||||
internal.PartitionStorer
|
||||
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
|
||||
GetAccountByChallengeResponse(ctx context.Context, localpart, b64encodedSignature, challenge string) (*api.Account, error)
|
||||
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
|
||||
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
|
||||
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
|
||||
|
|
@ -34,7 +35,7 @@ type Database interface {
|
|||
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||
// account already exists, it will return nil, ErrUserExists.
|
||||
CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*api.Account, error)
|
||||
CreateAccount(ctx context.Context, localpart, plaintextPassword, b64encodedPublicKey, appserviceID string) (*api.Account, error)
|
||||
CreateGuestAccount(ctx context.Context) (*api.Account, error)
|
||||
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
|
||||
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
|
||||
|
|
|
|||
|
|
@ -36,6 +36,8 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
|||
created_ts BIGINT NOT NULL,
|
||||
-- The password hash for this account. Can be NULL if this is a passwordless account.
|
||||
password_hash TEXT,
|
||||
-- The public key for this account, base64 encoded.
|
||||
b64_public_key TEXT,
|
||||
-- Identifies which application service this account belongs to, if any.
|
||||
appservice_id TEXT,
|
||||
-- If the account is currently active
|
||||
|
|
@ -48,7 +50,7 @@ CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
|
|||
`
|
||||
|
||||
const insertAccountSQL = "" +
|
||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
|
||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, b64_public_key, appservice_id) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const updatePasswordSQL = "" +
|
||||
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||
|
|
@ -62,6 +64,9 @@ const selectAccountByLocalpartSQL = "" +
|
|||
const selectPasswordHashSQL = "" +
|
||||
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
|
||||
|
||||
const selectb64PubKeySQL = "" +
|
||||
"SELECT b64_public_key FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
||||
|
||||
const selectNewNumericLocalpartSQL = "" +
|
||||
"SELECT nextval('numeric_username_seq')"
|
||||
|
||||
|
|
@ -71,6 +76,7 @@ type accountsStatements struct {
|
|||
deactivateAccountStmt *sql.Stmt
|
||||
selectAccountByLocalpartStmt *sql.Stmt
|
||||
selectPasswordHashStmt *sql.Stmt
|
||||
selectb64PubKeyStmt *sql.Stmt
|
||||
selectNewNumericLocalpartStmt *sql.Stmt
|
||||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
|
@ -88,6 +94,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
|||
{&s.deactivateAccountStmt, deactivateAccountSQL},
|
||||
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
|
||||
{&s.selectPasswordHashStmt, selectPasswordHashSQL},
|
||||
{&s.selectb64PubKeyStmt, selectb64PubKeySQL},
|
||||
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
|
@ -95,17 +102,15 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
|||
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
|
||||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||
// on success.
|
||||
func (s *accountsStatements) insertAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
|
||||
) (*api.Account, error) {
|
||||
func (s *accountsStatements) insertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, b64PubKey, appserviceID string) (*api.Account, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
||||
|
||||
var err error
|
||||
if appserviceID == "" {
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil)
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, b64PubKey, nil)
|
||||
} else {
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, appserviceID)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -140,6 +145,13 @@ func (s *accountsStatements) selectPasswordHash(
|
|||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) selectb64PubKey(
|
||||
ctx context.Context, localpart string,
|
||||
) (b64PubKey string, err error) {
|
||||
err = s.selectb64PubKeyStmt.QueryRowContext(ctx, localpart).Scan(&b64PubKey)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) selectAccountByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) (*api.Account, error) {
|
||||
|
|
|
|||
|
|
@ -16,7 +16,9 @@ package postgres
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
|
@ -120,6 +122,29 @@ func (d *Database) GetAccountByPassword(
|
|||
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||
}
|
||||
|
||||
// GetAccountByChallengeResponse returns the account associated with the given localpart and public key
|
||||
// if the given signature can be traced back to the public key.
|
||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||
func (d *Database) GetAccountByChallengeResponse(ctx context.Context, localpart, b64encodedSignature, challenge string) (*api.Account, error) {
|
||||
b64PubKey, err := d.accounts.selectb64PubKey(ctx, localpart)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pubKey, err := base64.StdEncoding.DecodeString(b64PubKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sig, err := base64.StdEncoding.DecodeString(b64encodedSignature)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
verified := ed25519.Verify(pubKey, []byte(challenge), sig)
|
||||
if !verified {
|
||||
return nil, errors.New("Authentication error: Invalid signature")
|
||||
}
|
||||
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||
}
|
||||
|
||||
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
||||
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
||||
func (d *Database) GetProfileByLocalpart(
|
||||
|
|
@ -165,7 +190,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
|
|||
return err
|
||||
}
|
||||
localpart := strconv.FormatInt(numLocalpart, 10)
|
||||
acc, err = d.createAccount(ctx, txn, localpart, "", "")
|
||||
acc, err = d.createAccount(ctx, txn, localpart, "", "", "")
|
||||
return err
|
||||
})
|
||||
return acc, err
|
||||
|
|
@ -174,19 +199,15 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
|
|||
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||
// account already exists, it will return nil, sqlutil.ErrUserExists.
|
||||
func (d *Database) CreateAccount(
|
||||
ctx context.Context, localpart, plaintextPassword, appserviceID string,
|
||||
) (acc *api.Account, err error) {
|
||||
func (d *Database) CreateAccount(ctx context.Context, localpart, plaintextPassword, b64encodedPublicKey, appserviceID string) (acc *api.Account, err error) {
|
||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
|
||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, b64encodedPublicKey, appserviceID)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) createAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
|
||||
) (*api.Account, error) {
|
||||
func (d *Database) createAccount(ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, publicKey, appserviceID string) (*api.Account, error) {
|
||||
var account *api.Account
|
||||
var err error
|
||||
// Generate a password hash if this is not a password-less user
|
||||
|
|
@ -197,7 +218,7 @@ func (d *Database) createAccount(
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil {
|
||||
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, publicKey, appserviceID); err != nil {
|
||||
if sqlutil.IsUniqueConstraintViolationErr(err) {
|
||||
return nil, sqlutil.ErrUserExists
|
||||
}
|
||||
|
|
|
|||
|
|
@ -36,6 +36,8 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
|||
created_ts BIGINT NOT NULL,
|
||||
-- The password hash for this account. Can be NULL if this is a passwordless account.
|
||||
password_hash TEXT,
|
||||
-- The public key for this account, base64 encoded.
|
||||
b64_public_key TEXT,
|
||||
-- Identifies which application service this account belongs to, if any.
|
||||
appservice_id TEXT,
|
||||
-- If the account is currently active
|
||||
|
|
@ -46,7 +48,7 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
|||
`
|
||||
|
||||
const insertAccountSQL = "" +
|
||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
|
||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, b64_public_key, appservice_id) VALUES ($1, $2, $3, $4, $5)"
|
||||
|
||||
const updatePasswordSQL = "" +
|
||||
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||
|
|
@ -60,6 +62,9 @@ const selectAccountByLocalpartSQL = "" +
|
|||
const selectPasswordHashSQL = "" +
|
||||
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
||||
|
||||
const selectb64PubKeySQL = "" +
|
||||
"SELECT b64_public_key FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
||||
|
||||
const selectNewNumericLocalpartSQL = "" +
|
||||
"SELECT COUNT(localpart) FROM account_accounts"
|
||||
|
||||
|
|
@ -70,6 +75,7 @@ type accountsStatements struct {
|
|||
deactivateAccountStmt *sql.Stmt
|
||||
selectAccountByLocalpartStmt *sql.Stmt
|
||||
selectPasswordHashStmt *sql.Stmt
|
||||
selectb64PubKeyStmt *sql.Stmt
|
||||
selectNewNumericLocalpartStmt *sql.Stmt
|
||||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
|
@ -88,6 +94,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
|||
{&s.deactivateAccountStmt, deactivateAccountSQL},
|
||||
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
|
||||
{&s.selectPasswordHashStmt, selectPasswordHashSQL},
|
||||
{&s.selectb64PubKeyStmt, selectb64PubKeySQL},
|
||||
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
|
@ -95,17 +102,15 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
|||
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
|
||||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||
// on success.
|
||||
func (s *accountsStatements) insertAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
|
||||
) (*api.Account, error) {
|
||||
func (s *accountsStatements) insertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, b64PubKey, appserviceID string) (*api.Account, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
stmt := s.insertAccountStmt
|
||||
|
||||
var err error
|
||||
if appserviceID == "" {
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, b64PubKey, nil)
|
||||
} else {
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, appserviceID)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -140,6 +145,13 @@ func (s *accountsStatements) selectPasswordHash(
|
|||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) selectb64PubKey(
|
||||
ctx context.Context, localpart string,
|
||||
) (b64PubKey string, err error) {
|
||||
err = s.selectb64PubKeyStmt.QueryRowContext(ctx, localpart).Scan(&b64PubKey)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) selectAccountByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) (*api.Account, error) {
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ CREATE TABLE account_accounts (
|
|||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
created_ts BIGINT NOT NULL,
|
||||
password_hash TEXT,
|
||||
b64_public_key TEXT,
|
||||
appservice_id TEXT,
|
||||
is_deactivated BOOLEAN DEFAULT 0
|
||||
);
|
||||
|
|
@ -47,6 +48,7 @@ CREATE TABLE account_accounts (
|
|||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
created_ts BIGINT NOT NULL,
|
||||
password_hash TEXT,
|
||||
b64_public_key TEXT,
|
||||
appservice_id TEXT
|
||||
);
|
||||
INSERT
|
||||
|
|
|
|||
|
|
@ -16,7 +16,9 @@ package sqlite3
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
|
@ -125,6 +127,29 @@ func (d *Database) GetAccountByPassword(
|
|||
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||
}
|
||||
|
||||
// GetAccountByChallengeResponse returns the account associated with the given localpart and public key
|
||||
// if the given signature can be traced back to the public key.
|
||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||
func (d *Database) GetAccountByChallengeResponse(ctx context.Context, localpart, b64encodedSignature, challenge string) (*api.Account, error) {
|
||||
b64PubKey, err := d.accounts.selectb64PubKey(ctx, localpart)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pubKey, err := base64.StdEncoding.DecodeString(b64PubKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sig, err := base64.StdEncoding.DecodeString(b64encodedSignature)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
verified := ed25519.Verify(pubKey, []byte(challenge), sig)
|
||||
if !verified {
|
||||
return nil, errors.New("Authentication error: Invalid signature")
|
||||
}
|
||||
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||
}
|
||||
|
||||
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
||||
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
||||
func (d *Database) GetProfileByLocalpart(
|
||||
|
|
@ -191,7 +216,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
|
|||
return err
|
||||
}
|
||||
localpart := strconv.FormatInt(numLocalpart, 10)
|
||||
acc, err = d.createAccount(ctx, txn, localpart, "", "")
|
||||
acc, err = d.createAccount(ctx, txn, localpart, "", "", "")
|
||||
return err
|
||||
})
|
||||
return acc, err
|
||||
|
|
@ -200,9 +225,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
|
|||
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||
// account already exists, it will return nil, ErrUserExists.
|
||||
func (d *Database) CreateAccount(
|
||||
ctx context.Context, localpart, plaintextPassword, appserviceID string,
|
||||
) (acc *api.Account, err error) {
|
||||
func (d *Database) CreateAccount(ctx context.Context, localpart, plaintextPassword, b64encodedPublicKey, appserviceID string) (acc *api.Account, err error) {
|
||||
// Create one account at a time else we can get 'database is locked'.
|
||||
d.profilesMu.Lock()
|
||||
d.accountDatasMu.Lock()
|
||||
|
|
@ -211,7 +234,7 @@ func (d *Database) CreateAccount(
|
|||
defer d.accountDatasMu.Unlock()
|
||||
defer d.accountsMu.Unlock()
|
||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
|
||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, b64encodedPublicKey, appserviceID)
|
||||
return err
|
||||
})
|
||||
return
|
||||
|
|
@ -219,9 +242,7 @@ func (d *Database) CreateAccount(
|
|||
|
||||
// WARNING! This function assumes that the relevant mutexes have already
|
||||
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
|
||||
func (d *Database) createAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
|
||||
) (*api.Account, error) {
|
||||
func (d *Database) createAccount(ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, publicKey, appserviceID string) (*api.Account, error) {
|
||||
var err error
|
||||
var account *api.Account
|
||||
// Generate a password hash if this is not a password-less user
|
||||
|
|
@ -232,7 +253,7 @@ func (d *Database) createAccount(
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil {
|
||||
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, publicKey, appserviceID); err != nil {
|
||||
return nil, sqlutil.ErrUserExists
|
||||
}
|
||||
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ func TestQueryProfile(t *testing.T) {
|
|||
aliceAvatarURL := "mxc://example.com/alice"
|
||||
aliceDisplayName := "Alice"
|
||||
userAPI, accountDB := MustMakeInternalAPI(t)
|
||||
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "")
|
||||
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make account: %s", err)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue