mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-31 18:53:10 -06:00
add digital signature challenge response authentication mechanism
using ed25519 keypairs Signed-off-by: Fabian Deifuß <deifussfabian@icloud.com>
This commit is contained in:
parent
a47b12dc7d
commit
f887bcea6f
|
|
@ -7,6 +7,7 @@ type LoginType string
|
||||||
const (
|
const (
|
||||||
LoginTypePassword = "m.login.password"
|
LoginTypePassword = "m.login.password"
|
||||||
LoginTypeDummy = "m.login.dummy"
|
LoginTypeDummy = "m.login.dummy"
|
||||||
|
LoginTypeChallengeResponse = "m.login.challenge_response"
|
||||||
LoginTypeSharedSecret = "org.matrix.login.shared_secret"
|
LoginTypeSharedSecret = "org.matrix.login.shared_secret"
|
||||||
LoginTypeRecaptcha = "m.login.recaptcha"
|
LoginTypeRecaptcha = "m.login.recaptcha"
|
||||||
LoginTypeApplicationService = "m.login.application_service"
|
LoginTypeApplicationService = "m.login.application_service"
|
||||||
|
|
|
||||||
67
clientapi/auth/challenge_response.go
Normal file
67
clientapi/auth/challenge_response.go
Normal file
|
|
@ -0,0 +1,67 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GetAccountByChallengeResponse func(ctx context.Context, localpart, b64encodedSignature, challenge string) (*api.Account, error)
|
||||||
|
|
||||||
|
type ChallengeResponseRequest struct {
|
||||||
|
Login
|
||||||
|
Signature string `json:"signature"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginTypeChallengeResponse using public key encryption
|
||||||
|
type LoginTypeChallengeResponse struct {
|
||||||
|
GetAccountByChallengeResponse GetAccountByChallengeResponse
|
||||||
|
Config *config.ClientAPI
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *LoginTypeChallengeResponse) Name() string {
|
||||||
|
return authtypes.LoginTypeChallengeResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *LoginTypeChallengeResponse) Request() interface{} {
|
||||||
|
return &ChallengeResponseRequest{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *LoginTypeChallengeResponse) Login(ctx context.Context, req interface{}, challenge string) (*Login, *util.JSONResponse) {
|
||||||
|
r := req.(*ChallengeResponseRequest)
|
||||||
|
username := r.Username()
|
||||||
|
if username == "" {
|
||||||
|
return nil, &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.BadJSON("A username must be supplied."),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
localpart, err := userutil.ParseUsernameParam(username, &t.Config.Matrix.ServerName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidUsername(err.Error()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.Signature == "" {
|
||||||
|
return nil, &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidSignature("No signature provided"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, err = t.GetAccountByChallengeResponse(ctx, localpart, r.Signature, challenge)
|
||||||
|
if err != nil {
|
||||||
|
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows
|
||||||
|
// but that would leak the existence of the user.
|
||||||
|
return nil, &util.JSONResponse{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
JSON: jsonerror.Forbidden("The digital signature is incorrect or the account does not exist."),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &r.Login, nil
|
||||||
|
}
|
||||||
|
|
@ -47,7 +47,7 @@ func (t *LoginTypePassword) Request() interface{} {
|
||||||
return &PasswordRequest{}
|
return &PasswordRequest{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) {
|
func (t *LoginTypePassword) Login(ctx context.Context, req interface{}, challenge string) (*Login, *util.JSONResponse) {
|
||||||
r := req.(*PasswordRequest)
|
r := req.(*PasswordRequest)
|
||||||
// Squash username to all lowercase letters
|
// Squash username to all lowercase letters
|
||||||
username := strings.ToLower(r.Username())
|
username := strings.ToLower(r.Username())
|
||||||
|
|
|
||||||
|
|
@ -17,14 +17,15 @@ package auth
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Type represents an auth type
|
// Type represents an auth type
|
||||||
|
|
@ -43,7 +44,7 @@ type Type interface {
|
||||||
// "If the homeserver decides that an attempt on a stage was unsuccessful, but the
|
// "If the homeserver decides that an attempt on a stage was unsuccessful, but the
|
||||||
// client may make a second attempt, it returns the same HTTP status 401 response as above,
|
// client may make a second attempt, it returns the same HTTP status 401 response as above,
|
||||||
// with the addition of the standard errcode and error fields describing the error."
|
// with the addition of the standard errcode and error fields describing the error."
|
||||||
Login(ctx context.Context, req interface{}) (login *Login, errRes *util.JSONResponse)
|
Login(ctx context.Context, req interface{}, challenge string) (login *Login, errRes *util.JSONResponse)
|
||||||
// TODO: Extend to support Register() flow
|
// TODO: Extend to support Register() flow
|
||||||
// Register(ctx context.Context, sessionID string, req interface{})
|
// Register(ctx context.Context, sessionID string, req interface{})
|
||||||
}
|
}
|
||||||
|
|
@ -108,14 +109,19 @@ type UserInteractive struct {
|
||||||
// Map of login type to implementation
|
// Map of login type to implementation
|
||||||
Types map[string]Type
|
Types map[string]Type
|
||||||
// Map of session ID to completed login types, will need to be extended in future
|
// Map of session ID to completed login types, will need to be extended in future
|
||||||
Sessions map[string][]string
|
Sessions map[string][]string
|
||||||
|
SessionParams map[string]map[string]string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserInteractive(getAccByPass GetAccountByPassword, cfg *config.ClientAPI) *UserInteractive {
|
func NewUserInteractive(getAccByPass GetAccountByPassword, getAccByChallengeResponse GetAccountByChallengeResponse, cfg *config.ClientAPI) *UserInteractive {
|
||||||
typePassword := &LoginTypePassword{
|
typePassword := &LoginTypePassword{
|
||||||
GetAccountByPassword: getAccByPass,
|
GetAccountByPassword: getAccByPass,
|
||||||
Config: cfg,
|
Config: cfg,
|
||||||
}
|
}
|
||||||
|
typeChallengeResponse := &LoginTypeChallengeResponse{
|
||||||
|
GetAccountByChallengeResponse: getAccByChallengeResponse,
|
||||||
|
Config: cfg,
|
||||||
|
}
|
||||||
// TODO: Add SSO login
|
// TODO: Add SSO login
|
||||||
return &UserInteractive{
|
return &UserInteractive{
|
||||||
Completed: []string{},
|
Completed: []string{},
|
||||||
|
|
@ -123,11 +129,16 @@ func NewUserInteractive(getAccByPass GetAccountByPassword, cfg *config.ClientAPI
|
||||||
{
|
{
|
||||||
Stages: []string{typePassword.Name()},
|
Stages: []string{typePassword.Name()},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Stages: []string{typeChallengeResponse.Name()},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Types: map[string]Type{
|
Types: map[string]Type{
|
||||||
typePassword.Name(): typePassword,
|
typePassword.Name(): typePassword,
|
||||||
|
typeChallengeResponse.Name(): typeChallengeResponse,
|
||||||
},
|
},
|
||||||
Sessions: make(map[string][]string),
|
Sessions: make(map[string][]string),
|
||||||
|
SessionParams: make(map[string]map[string]string),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -148,6 +159,9 @@ func (u *UserInteractive) AddCompletedStage(sessionID, authType string) {
|
||||||
|
|
||||||
// Challenge returns an HTTP 401 with the supported flows for authenticating
|
// Challenge returns an HTTP 401 with the supported flows for authenticating
|
||||||
func (u *UserInteractive) Challenge(sessionID string) *util.JSONResponse {
|
func (u *UserInteractive) Challenge(sessionID string) *util.JSONResponse {
|
||||||
|
params := make(map[string]string)
|
||||||
|
params["challenge"] = fmt.Sprintf("%d%s", time.Now().Unix(), sessionID)
|
||||||
|
u.SessionParams[sessionID] = params
|
||||||
return &util.JSONResponse{
|
return &util.JSONResponse{
|
||||||
Code: 401,
|
Code: 401,
|
||||||
JSON: struct {
|
JSON: struct {
|
||||||
|
|
@ -155,18 +169,18 @@ func (u *UserInteractive) Challenge(sessionID string) *util.JSONResponse {
|
||||||
Flows []userInteractiveFlow `json:"flows"`
|
Flows []userInteractiveFlow `json:"flows"`
|
||||||
Session string `json:"session"`
|
Session string `json:"session"`
|
||||||
// TODO: Return any additional `params`
|
// TODO: Return any additional `params`
|
||||||
Params map[string]interface{} `json:"params"`
|
Params map[string]string `json:"params"`
|
||||||
}{
|
}{
|
||||||
u.Completed,
|
u.Completed,
|
||||||
u.Flows,
|
u.Flows,
|
||||||
sessionID,
|
sessionID,
|
||||||
make(map[string]interface{}),
|
params,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSession returns a challenge with a new session ID and remembers the session ID
|
// NewSession returns a challenge with a new session ID and remembers the session ID
|
||||||
func (u *UserInteractive) NewSession() *util.JSONResponse {
|
func (u *UserInteractive) NewSession(authType string) *util.JSONResponse {
|
||||||
sessionID, err := GenerateAccessToken()
|
sessionID, err := GenerateAccessToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Error("failed to generate session ID")
|
logrus.WithError(err).Error("failed to generate session ID")
|
||||||
|
|
@ -207,15 +221,16 @@ func (u *UserInteractive) ResponseWithChallenge(sessionID string, response inter
|
||||||
func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte, device *api.Device) (*Login, *util.JSONResponse) {
|
func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte, device *api.Device) (*Login, *util.JSONResponse) {
|
||||||
// TODO: rate limit
|
// TODO: rate limit
|
||||||
|
|
||||||
|
// extract the type so we know which login type to use
|
||||||
|
authType := gjson.GetBytes(bodyBytes, "auth.type").Str
|
||||||
|
|
||||||
// "A client should first make a request with no auth parameter. The homeserver returns an HTTP 401 response, with a JSON body"
|
// "A client should first make a request with no auth parameter. The homeserver returns an HTTP 401 response, with a JSON body"
|
||||||
// https://matrix.org/docs/spec/client_server/r0.6.1#user-interactive-api-in-the-rest-api
|
// https://matrix.org/docs/spec/client_server/r0.6.1#user-interactive-api-in-the-rest-api
|
||||||
hasResponse := gjson.GetBytes(bodyBytes, "auth").Exists()
|
hasResponse := gjson.GetBytes(bodyBytes, "auth").Exists()
|
||||||
if !hasResponse {
|
if !hasResponse {
|
||||||
return nil, u.NewSession()
|
return nil, u.NewSession(authType)
|
||||||
}
|
}
|
||||||
|
|
||||||
// extract the type so we know which login type to use
|
|
||||||
authType := gjson.GetBytes(bodyBytes, "auth.type").Str
|
|
||||||
loginType, ok := u.Types[authType]
|
loginType, ok := u.Types[authType]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, &util.JSONResponse{
|
return nil, &util.JSONResponse{
|
||||||
|
|
@ -237,13 +252,13 @@ func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte, device *
|
||||||
}
|
}
|
||||||
|
|
||||||
r := loginType.Request()
|
r := loginType.Request()
|
||||||
if err := json.Unmarshal([]byte(gjson.GetBytes(bodyBytes, "auth").Raw), r); err != nil {
|
if err := json.Unmarshal(bodyBytes, r); err != nil {
|
||||||
return nil, &util.JSONResponse{
|
return nil, &util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()),
|
JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
login, resErr := loginType.Login(ctx, r)
|
login, resErr := loginType.Login(ctx, r, u.SessionParams[sessionID]["challenge"])
|
||||||
if resErr == nil {
|
if resErr == nil {
|
||||||
u.AddCompletedStage(sessionID, authType)
|
u.AddCompletedStage(sessionID, authType)
|
||||||
// TODO: Check if there's more stages to go and return an error
|
// TODO: Check if there's more stages to go and return an error
|
||||||
|
|
|
||||||
|
|
@ -32,13 +32,21 @@ func getAccountByPassword(ctx context.Context, localpart, plaintextPassword stri
|
||||||
return acc, nil
|
return acc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getAccountByChallengeResponse(ctx context.Context, localpart, b64encodedSignature, challenge string) (*api.Account, error) {
|
||||||
|
acc, ok := lookup[localpart+" "+b64encodedSignature+" "+challenge]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unknown user/pubkey")
|
||||||
|
}
|
||||||
|
return acc, nil
|
||||||
|
}
|
||||||
|
|
||||||
func setup() *UserInteractive {
|
func setup() *UserInteractive {
|
||||||
cfg := &config.ClientAPI{
|
cfg := &config.ClientAPI{
|
||||||
Matrix: &config.Global{
|
Matrix: &config.Global{
|
||||||
ServerName: serverName,
|
ServerName: serverName,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
return NewUserInteractive(getAccountByPassword, cfg)
|
return NewUserInteractive(getAccountByPassword, getAccountByChallengeResponse, cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUserInteractiveChallenge(t *testing.T) {
|
func TestUserInteractiveChallenge(t *testing.T) {
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,9 @@
|
||||||
package httputil
|
package httputil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
@ -54,3 +56,19 @@ func UnmarshalJSONRequest(req *http.Request, iface interface{}) *util.JSONRespon
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetLoginType(req *http.Request) (string, *util.JSONResponse) {
|
||||||
|
body, err := ioutil.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
resp := jsonerror.InternalServerError()
|
||||||
|
return "", &resp
|
||||||
|
}
|
||||||
|
req.Body = ioutil.NopCloser(bytes.NewReader(body))
|
||||||
|
r := &auth.Login{}
|
||||||
|
jsonErr := UnmarshalJSONRequest(req, r)
|
||||||
|
if jsonErr != nil {
|
||||||
|
return "", jsonErr
|
||||||
|
}
|
||||||
|
req.Body = ioutil.NopCloser(bytes.NewReader(body))
|
||||||
|
return r.Type, nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -67,7 +67,7 @@ func UploadCrossSigningDeviceKeys(
|
||||||
GetAccountByPassword: accountDB.GetAccountByPassword,
|
GetAccountByPassword: accountDB.GetAccountByPassword,
|
||||||
Config: cfg,
|
Config: cfg,
|
||||||
}
|
}
|
||||||
if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil {
|
if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest, ""); authErr != nil {
|
||||||
return *authErr
|
return *authErr
|
||||||
}
|
}
|
||||||
AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,9 @@ package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth"
|
"github.com/matrix-org/dendrite/clientapi/auth"
|
||||||
|
|
@ -44,42 +47,70 @@ type flow struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func passwordLogin() flows {
|
func enabledLoginTypes(cfg *config.ClientAPI) flows {
|
||||||
f := flows{}
|
f := flows{}
|
||||||
s := flow{
|
for _, loginType := range cfg.LoginTypes {
|
||||||
Type: "m.login.password",
|
s := flow{
|
||||||
|
Type: loginType,
|
||||||
|
}
|
||||||
|
f.Flows = append(f.Flows, s)
|
||||||
}
|
}
|
||||||
f.Flows = append(f.Flows, s)
|
|
||||||
return f
|
return f
|
||||||
}
|
}
|
||||||
|
|
||||||
// Login implements GET and POST /login
|
// Login implements GET and POST /login
|
||||||
func Login(
|
func Login(
|
||||||
req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI,
|
req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI,
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI, userInteractiveAuth *auth.UserInteractive,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
if req.Method == http.MethodGet {
|
if req.Method == http.MethodGet {
|
||||||
// TODO: support other forms of login other than password, depending on config options
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
JSON: passwordLogin(),
|
JSON: enabledLoginTypes(cfg),
|
||||||
}
|
}
|
||||||
} else if req.Method == http.MethodPost {
|
} else if req.Method == http.MethodPost {
|
||||||
typePassword := auth.LoginTypePassword{
|
loginType, err := httputil.GetLoginType(req)
|
||||||
GetAccountByPassword: accountDB.GetAccountByPassword,
|
if err != nil {
|
||||||
Config: cfg,
|
return *err
|
||||||
}
|
}
|
||||||
r := typePassword.Request()
|
switch loginType {
|
||||||
resErr := httputil.UnmarshalJSONRequest(req, r)
|
case authtypes.LoginTypePassword:
|
||||||
if resErr != nil {
|
typePassword := auth.LoginTypePassword{
|
||||||
return *resErr
|
GetAccountByPassword: accountDB.GetAccountByPassword,
|
||||||
|
Config: cfg,
|
||||||
|
}
|
||||||
|
r := typePassword.Request()
|
||||||
|
resErr := httputil.UnmarshalJSONRequest(req, r)
|
||||||
|
if resErr != nil {
|
||||||
|
return *resErr
|
||||||
|
}
|
||||||
|
login, authErr := typePassword.Login(req.Context(), r, "")
|
||||||
|
if authErr != nil {
|
||||||
|
return *authErr
|
||||||
|
}
|
||||||
|
// make a device/access token
|
||||||
|
return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent())
|
||||||
|
case authtypes.LoginTypeChallengeResponse:
|
||||||
|
defer req.Body.Close() // nolint:errcheck
|
||||||
|
bodyBytes, err := ioutil.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.BadJSON("The request body could not be read: " + err.Error()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
login, resErr := userInteractiveAuth.Verify(req.Context(), bodyBytes, nil)
|
||||||
|
if resErr != nil {
|
||||||
|
return *resErr
|
||||||
|
}
|
||||||
|
// create a device/access token
|
||||||
|
return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent())
|
||||||
|
default:
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidParam(fmt.Sprintf("Unsupported login type: %s", loginType)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
login, authErr := typePassword.Login(req.Context(), r)
|
|
||||||
if authErr != nil {
|
|
||||||
return *authErr
|
|
||||||
}
|
|
||||||
// make a device/access token
|
|
||||||
return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent())
|
|
||||||
}
|
}
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusMethodNotAllowed,
|
Code: http.StatusMethodNotAllowed,
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,7 @@ func Password(
|
||||||
GetAccountByPassword: accountDB.GetAccountByPassword,
|
GetAccountByPassword: accountDB.GetAccountByPassword,
|
||||||
Config: cfg,
|
Config: cfg,
|
||||||
}
|
}
|
||||||
if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil {
|
if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest, ""); authErr != nil {
|
||||||
return *authErr
|
return *authErr
|
||||||
}
|
}
|
||||||
AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
||||||
|
|
|
||||||
|
|
@ -121,9 +121,10 @@ var (
|
||||||
// previous parameters with the ones supplied. This mean you cannot "build up" request params.
|
// previous parameters with the ones supplied. This mean you cannot "build up" request params.
|
||||||
type registerRequest struct {
|
type registerRequest struct {
|
||||||
// registration parameters
|
// registration parameters
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Admin bool `json:"admin"`
|
B64encodedPublicKey string `json:"b64PublicKey"`
|
||||||
|
Admin bool `json:"admin"`
|
||||||
// user-interactive auth params
|
// user-interactive auth params
|
||||||
Auth authDict `json:"auth"`
|
Auth authDict `json:"auth"`
|
||||||
|
|
||||||
|
|
@ -517,9 +518,10 @@ func Register(
|
||||||
|
|
||||||
logger := util.GetLogger(req.Context())
|
logger := util.GetLogger(req.Context())
|
||||||
logger.WithFields(log.Fields{
|
logger.WithFields(log.Fields{
|
||||||
"username": r.Username,
|
"username": r.Username,
|
||||||
"auth.type": r.Auth.Type,
|
"b64PublicKey": r.B64encodedPublicKey,
|
||||||
"session_id": r.Auth.Session,
|
"auth.type": r.Auth.Type,
|
||||||
|
"session_id": r.Auth.Session,
|
||||||
}).Info("Processing registration request")
|
}).Info("Processing registration request")
|
||||||
|
|
||||||
return handleRegistrationFlow(req, r, sessionID, cfg, userAPI, accessToken, accessTokenErr)
|
return handleRegistrationFlow(req, r, sessionID, cfg, userAPI, accessToken, accessTokenErr)
|
||||||
|
|
@ -700,7 +702,7 @@ func handleApplicationServiceRegistration(
|
||||||
// Don't need to worry about appending to registration stages as
|
// Don't need to worry about appending to registration stages as
|
||||||
// application service registration is entirely separate.
|
// application service registration is entirely separate.
|
||||||
return completeRegistration(
|
return completeRegistration(
|
||||||
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(),
|
req.Context(), userAPI, r.Username, "", "", appserviceID, req.RemoteAddr, req.UserAgent(),
|
||||||
r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
|
r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
@ -719,7 +721,7 @@ func checkAndCompleteFlow(
|
||||||
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
|
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
|
||||||
// This flow was completed, registration can continue
|
// This flow was completed, registration can continue
|
||||||
return completeRegistration(
|
return completeRegistration(
|
||||||
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(),
|
req.Context(), userAPI, r.Username, r.Password, r.B64encodedPublicKey, "", req.RemoteAddr, req.UserAgent(),
|
||||||
r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
|
r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
@ -739,13 +741,7 @@ func checkAndCompleteFlow(
|
||||||
// registerRequest, as this function serves requests encoded as both
|
// registerRequest, as this function serves requests encoded as both
|
||||||
// registerRequests and legacyRegisterRequests, which share some attributes but
|
// registerRequests and legacyRegisterRequests, which share some attributes but
|
||||||
// not all
|
// not all
|
||||||
func completeRegistration(
|
func completeRegistration(ctx context.Context, userAPI userapi.UserInternalAPI, username, password, b64encodedPublicKey, appserviceID, ipAddr, userAgent string, inhibitLogin eventutil.WeakBoolean, displayName, deviceID *string) util.JSONResponse {
|
||||||
ctx context.Context,
|
|
||||||
userAPI userapi.UserInternalAPI,
|
|
||||||
username, password, appserviceID, ipAddr, userAgent string,
|
|
||||||
inhibitLogin eventutil.WeakBoolean,
|
|
||||||
displayName, deviceID *string,
|
|
||||||
) util.JSONResponse {
|
|
||||||
if username == "" {
|
if username == "" {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
|
|
@ -762,11 +758,12 @@ func completeRegistration(
|
||||||
|
|
||||||
var accRes userapi.PerformAccountCreationResponse
|
var accRes userapi.PerformAccountCreationResponse
|
||||||
err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{
|
err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{
|
||||||
AppServiceID: appserviceID,
|
AppServiceID: appserviceID,
|
||||||
Localpart: username,
|
Localpart: username,
|
||||||
Password: password,
|
Password: password,
|
||||||
AccountType: userapi.AccountTypeUser,
|
B64encodedPublicKey: b64encodedPublicKey,
|
||||||
OnConflict: userapi.ConflictAbort,
|
AccountType: userapi.AccountTypeUser,
|
||||||
|
OnConflict: userapi.ConflictAbort,
|
||||||
}, &accRes)
|
}, &accRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(*userapi.ErrorConflict); ok { // user already exists
|
if _, ok := err.(*userapi.ErrorConflict); ok { // user already exists
|
||||||
|
|
@ -963,5 +960,5 @@ func handleSharedSecretRegistration(userAPI userapi.UserInternalAPI, sr *SharedS
|
||||||
return *resErr
|
return *resErr
|
||||||
}
|
}
|
||||||
deviceID := "shared_secret_registration"
|
deviceID := "shared_secret_registration"
|
||||||
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), false, &ssrr.User, &deviceID)
|
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", "", req.RemoteAddr, req.UserAgent(), false, &ssrr.User, &deviceID)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -62,7 +62,7 @@ func Setup(
|
||||||
mscCfg *config.MSCs,
|
mscCfg *config.MSCs,
|
||||||
) {
|
) {
|
||||||
rateLimits := httputil.NewRateLimits(&cfg.RateLimiting)
|
rateLimits := httputil.NewRateLimits(&cfg.RateLimiting)
|
||||||
userInteractiveAuth := auth.NewUserInteractive(accountDB.GetAccountByPassword, cfg)
|
userInteractiveAuth := auth.NewUserInteractive(accountDB.GetAccountByPassword, accountDB.GetAccountByChallengeResponse, cfg)
|
||||||
|
|
||||||
unstableFeatures := map[string]bool{
|
unstableFeatures := map[string]bool{
|
||||||
"org.matrix.e2e_cross_signing": true,
|
"org.matrix.e2e_cross_signing": true,
|
||||||
|
|
@ -498,7 +498,7 @@ func Setup(
|
||||||
if r := rateLimits.Limit(req); r != nil {
|
if r := rateLimits.Limit(req); r != nil {
|
||||||
return *r
|
return *r
|
||||||
}
|
}
|
||||||
return Login(req, accountDB, userAPI, cfg)
|
return Login(req, accountDB, userAPI, cfg, userInteractiveAuth)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,9 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
@ -39,6 +42,8 @@ Example:
|
||||||
|
|
||||||
# provide password by parameter
|
# provide password by parameter
|
||||||
%s --config dendrite.yaml -username alice -password foobarbaz
|
%s --config dendrite.yaml -username alice -password foobarbaz
|
||||||
|
# auto generate keypair
|
||||||
|
%s --config dendrite.yaml -username alice -password foobarbaz -udk NuJ7J4BsaE8QZT1ULNTc3s8ZjLFmDPh91l1i0Urf/ls=
|
||||||
# use password from file
|
# use password from file
|
||||||
%s --config dendrite.yaml -username alice -passwordfile my.pass
|
%s --config dendrite.yaml -username alice -passwordfile my.pass
|
||||||
# ask user to provide password
|
# ask user to provide password
|
||||||
|
|
@ -52,17 +57,18 @@ Arguments:
|
||||||
`
|
`
|
||||||
|
|
||||||
var (
|
var (
|
||||||
username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')")
|
username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')")
|
||||||
password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)")
|
password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)")
|
||||||
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
|
createKeypair = flag.Bool("create-keypair", false, "Whether to create an Ed25519 keypair for the account to create (optional)")
|
||||||
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
|
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
|
||||||
askPass = flag.Bool("ask-pass", false, "Ask for the password to use")
|
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
|
||||||
|
askPass = flag.Bool("ask-pass", false, "Ask for the password to use")
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
name := os.Args[0]
|
name := os.Args[0]
|
||||||
flag.Usage = func() {
|
flag.Usage = func() {
|
||||||
_, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name)
|
_, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name, name)
|
||||||
flag.PrintDefaults()
|
flag.PrintDefaults()
|
||||||
}
|
}
|
||||||
cfg := setup.ParseFlags(true)
|
cfg := setup.ParseFlags(true)
|
||||||
|
|
@ -81,7 +87,32 @@ func main() {
|
||||||
logrus.Fatalln("Failed to connect to the database:", err.Error())
|
logrus.Fatalln("Failed to connect to the database:", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = accountDB.CreateAccount(context.Background(), *username, pass, "")
|
var pub64 string
|
||||||
|
if *createKeypair {
|
||||||
|
pub, priv, err2 := ed25519.GenerateKey(rand.Reader)
|
||||||
|
pub64 = base64.StdEncoding.EncodeToString(priv.Public().(ed25519.PublicKey))
|
||||||
|
if err2 != nil {
|
||||||
|
logrus.Fatalln(err2)
|
||||||
|
}
|
||||||
|
err2 = os.WriteFile("private.key", priv, 0644)
|
||||||
|
if err2 != nil {
|
||||||
|
logrus.Fatalln(err2)
|
||||||
|
}
|
||||||
|
err2 = os.WriteFile("private.key.seed", priv.Seed(), 0644)
|
||||||
|
if err2 != nil {
|
||||||
|
logrus.Fatalln(err2)
|
||||||
|
}
|
||||||
|
err2 = os.WriteFile("./public.key", pub, 0644)
|
||||||
|
if err2 != nil {
|
||||||
|
logrus.Fatalln(err2)
|
||||||
|
}
|
||||||
|
err2 = os.WriteFile("./public.key.b64", []byte(pub64), 0644)
|
||||||
|
if err2 != nil {
|
||||||
|
logrus.Fatalln(err2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = accountDB.CreateAccount(context.Background(), *username, pass, pub64, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Fatalln("Failed to create the account:", err.Error())
|
logrus.Fatalln("Failed to create the account:", err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
61
cmd/sign-challenge/main.go
Normal file
61
cmd/sign-challenge/main.go
Normal file
|
|
@ -0,0 +1,61 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"encoding/base64"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
const usage = `Usage: %s
|
||||||
|
|
||||||
|
Sign a string using an Ed25519 private key
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
`
|
||||||
|
|
||||||
|
var (
|
||||||
|
privateKeyFile = flag.String("private-key", "", "An Ed25519 private key seed used to sign the input")
|
||||||
|
input = flag.String("input", "", "The input to sign")
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.Usage = func() {
|
||||||
|
fmt.Fprintf(os.Stderr, usage, os.Args[0])
|
||||||
|
flag.PrintDefaults()
|
||||||
|
}
|
||||||
|
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if *privateKeyFile == "" || *input == "" {
|
||||||
|
flag.Usage()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if *privateKeyFile != "" && *input != "" {
|
||||||
|
seedBytes, err := os.ReadFile(*privateKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Fatalln(err)
|
||||||
|
}
|
||||||
|
priv := ed25519.NewKeyFromSeed(seedBytes)
|
||||||
|
sig := ed25519.Sign(priv, []byte(*input))
|
||||||
|
logrus.Infoln("Signature: " + base64.StdEncoding.EncodeToString(sig))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -145,6 +145,11 @@ client_api:
|
||||||
external_api:
|
external_api:
|
||||||
listen: http://[::]:8071
|
listen: http://[::]:8071
|
||||||
|
|
||||||
|
# Which of the authentication flows to support on this server
|
||||||
|
login_types:
|
||||||
|
- m.login.password
|
||||||
|
- m.login.challenge_response
|
||||||
|
|
||||||
# Prevents new users from being able to register on this homeserver, except when
|
# Prevents new users from being able to register on this homeserver, except when
|
||||||
# using the registration shared secret below.
|
# using the registration shared secret below.
|
||||||
registration_disabled: false
|
registration_disabled: false
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,10 @@ type ClientAPI struct {
|
||||||
InternalAPI InternalAPIOptions `yaml:"internal_api"`
|
InternalAPI InternalAPIOptions `yaml:"internal_api"`
|
||||||
ExternalAPI ExternalAPIOptions `yaml:"external_api"`
|
ExternalAPI ExternalAPIOptions `yaml:"external_api"`
|
||||||
|
|
||||||
|
// What authentication mechanisms shall be supported
|
||||||
|
// by this server
|
||||||
|
LoginTypes []string `yaml:"login_types"`
|
||||||
|
|
||||||
// If set disables new users from registering (except via shared
|
// If set disables new users from registering (except via shared
|
||||||
// secrets)
|
// secrets)
|
||||||
RegistrationDisabled bool `yaml:"registration_disabled"`
|
RegistrationDisabled bool `yaml:"registration_disabled"`
|
||||||
|
|
@ -45,6 +49,7 @@ func (c *ClientAPI) Defaults(generate bool) {
|
||||||
c.InternalAPI.Listen = "http://localhost:7771"
|
c.InternalAPI.Listen = "http://localhost:7771"
|
||||||
c.InternalAPI.Connect = "http://localhost:7771"
|
c.InternalAPI.Connect = "http://localhost:7771"
|
||||||
c.ExternalAPI.Listen = "http://[::]:8071"
|
c.ExternalAPI.Listen = "http://[::]:8071"
|
||||||
|
c.LoginTypes = []string{"m.login.password"}
|
||||||
c.RegistrationSharedSecret = ""
|
c.RegistrationSharedSecret = ""
|
||||||
c.RecaptchaPublicKey = ""
|
c.RecaptchaPublicKey = ""
|
||||||
c.RecaptchaPrivateKey = ""
|
c.RecaptchaPrivateKey = ""
|
||||||
|
|
|
||||||
|
|
@ -243,9 +243,10 @@ type PerformAccountCreationRequest struct {
|
||||||
AccountType AccountType // Required: whether this is a guest or user account
|
AccountType AccountType // Required: whether this is a guest or user account
|
||||||
Localpart string // Required: The localpart for this account. Ignored if account type is guest.
|
Localpart string // Required: The localpart for this account. Ignored if account type is guest.
|
||||||
|
|
||||||
AppServiceID string // optional: the application service ID (not user ID) creating this account, if any.
|
AppServiceID string // optional: the application service ID (not user ID) creating this account, if any.
|
||||||
Password string // optional: if missing then this account will be a passwordless account
|
Password string // optional: if missing then this account will be a passwordless account
|
||||||
OnConflict Conflict
|
B64encodedPublicKey string // optional: if missing then this account will not allow digital signature login until a public key is added
|
||||||
|
OnConflict Conflict
|
||||||
}
|
}
|
||||||
|
|
||||||
// PerformAccountCreationResponse is the response for PerformAccountCreation
|
// PerformAccountCreationResponse is the response for PerformAccountCreation
|
||||||
|
|
|
||||||
|
|
@ -67,7 +67,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
||||||
res.Account = acc
|
res.Account = acc
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID)
|
acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.B64encodedPublicKey, req.AppServiceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
||||||
switch req.OnConflict {
|
switch req.OnConflict {
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ import (
|
||||||
type Database interface {
|
type Database interface {
|
||||||
internal.PartitionStorer
|
internal.PartitionStorer
|
||||||
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
|
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
|
||||||
|
GetAccountByChallengeResponse(ctx context.Context, localpart, b64encodedSignature, challenge string) (*api.Account, error)
|
||||||
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
|
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
|
||||||
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
|
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
|
||||||
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
|
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
|
||||||
|
|
@ -34,7 +35,7 @@ type Database interface {
|
||||||
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
||||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||||
// account already exists, it will return nil, ErrUserExists.
|
// account already exists, it will return nil, ErrUserExists.
|
||||||
CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*api.Account, error)
|
CreateAccount(ctx context.Context, localpart, plaintextPassword, b64encodedPublicKey, appserviceID string) (*api.Account, error)
|
||||||
CreateGuestAccount(ctx context.Context) (*api.Account, error)
|
CreateGuestAccount(ctx context.Context) (*api.Account, error)
|
||||||
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
|
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
|
||||||
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
|
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,8 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
||||||
created_ts BIGINT NOT NULL,
|
created_ts BIGINT NOT NULL,
|
||||||
-- The password hash for this account. Can be NULL if this is a passwordless account.
|
-- The password hash for this account. Can be NULL if this is a passwordless account.
|
||||||
password_hash TEXT,
|
password_hash TEXT,
|
||||||
|
-- The public key for this account, base64 encoded.
|
||||||
|
b64_public_key TEXT,
|
||||||
-- Identifies which application service this account belongs to, if any.
|
-- Identifies which application service this account belongs to, if any.
|
||||||
appservice_id TEXT,
|
appservice_id TEXT,
|
||||||
-- If the account is currently active
|
-- If the account is currently active
|
||||||
|
|
@ -48,7 +50,7 @@ CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertAccountSQL = "" +
|
const insertAccountSQL = "" +
|
||||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
|
"INSERT INTO account_accounts(localpart, created_ts, password_hash, b64_public_key, appservice_id) VALUES ($1, $2, $3, $4)"
|
||||||
|
|
||||||
const updatePasswordSQL = "" +
|
const updatePasswordSQL = "" +
|
||||||
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||||
|
|
@ -62,6 +64,9 @@ const selectAccountByLocalpartSQL = "" +
|
||||||
const selectPasswordHashSQL = "" +
|
const selectPasswordHashSQL = "" +
|
||||||
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
|
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
|
||||||
|
|
||||||
|
const selectb64PubKeySQL = "" +
|
||||||
|
"SELECT b64_public_key FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
||||||
|
|
||||||
const selectNewNumericLocalpartSQL = "" +
|
const selectNewNumericLocalpartSQL = "" +
|
||||||
"SELECT nextval('numeric_username_seq')"
|
"SELECT nextval('numeric_username_seq')"
|
||||||
|
|
||||||
|
|
@ -71,6 +76,7 @@ type accountsStatements struct {
|
||||||
deactivateAccountStmt *sql.Stmt
|
deactivateAccountStmt *sql.Stmt
|
||||||
selectAccountByLocalpartStmt *sql.Stmt
|
selectAccountByLocalpartStmt *sql.Stmt
|
||||||
selectPasswordHashStmt *sql.Stmt
|
selectPasswordHashStmt *sql.Stmt
|
||||||
|
selectb64PubKeyStmt *sql.Stmt
|
||||||
selectNewNumericLocalpartStmt *sql.Stmt
|
selectNewNumericLocalpartStmt *sql.Stmt
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
@ -88,6 +94,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
||||||
{&s.deactivateAccountStmt, deactivateAccountSQL},
|
{&s.deactivateAccountStmt, deactivateAccountSQL},
|
||||||
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
|
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
|
||||||
{&s.selectPasswordHashStmt, selectPasswordHashSQL},
|
{&s.selectPasswordHashStmt, selectPasswordHashSQL},
|
||||||
|
{&s.selectb64PubKeyStmt, selectb64PubKeySQL},
|
||||||
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
|
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
@ -95,17 +102,15 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
||||||
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
|
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
|
||||||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||||
// on success.
|
// on success.
|
||||||
func (s *accountsStatements) insertAccount(
|
func (s *accountsStatements) insertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, b64PubKey, appserviceID string) (*api.Account, error) {
|
||||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
|
|
||||||
) (*api.Account, error) {
|
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if appserviceID == "" {
|
if appserviceID == "" {
|
||||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil)
|
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, b64PubKey, nil)
|
||||||
} else {
|
} else {
|
||||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
|
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, appserviceID)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -140,6 +145,13 @@ func (s *accountsStatements) selectPasswordHash(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *accountsStatements) selectb64PubKey(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) (b64PubKey string, err error) {
|
||||||
|
err = s.selectb64PubKeyStmt.QueryRowContext(ctx, localpart).Scan(&b64PubKey)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) selectAccountByLocalpart(
|
func (s *accountsStatements) selectAccountByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,9 @@ package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
@ -120,6 +122,29 @@ func (d *Database) GetAccountByPassword(
|
||||||
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAccountByChallengeResponse returns the account associated with the given localpart and public key
|
||||||
|
// if the given signature can be traced back to the public key.
|
||||||
|
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||||
|
func (d *Database) GetAccountByChallengeResponse(ctx context.Context, localpart, b64encodedSignature, challenge string) (*api.Account, error) {
|
||||||
|
b64PubKey, err := d.accounts.selectb64PubKey(ctx, localpart)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
pubKey, err := base64.StdEncoding.DecodeString(b64PubKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sig, err := base64.StdEncoding.DecodeString(b64encodedSignature)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
verified := ed25519.Verify(pubKey, []byte(challenge), sig)
|
||||||
|
if !verified {
|
||||||
|
return nil, errors.New("Authentication error: Invalid signature")
|
||||||
|
}
|
||||||
|
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
||||||
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
||||||
func (d *Database) GetProfileByLocalpart(
|
func (d *Database) GetProfileByLocalpart(
|
||||||
|
|
@ -165,7 +190,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
localpart := strconv.FormatInt(numLocalpart, 10)
|
localpart := strconv.FormatInt(numLocalpart, 10)
|
||||||
acc, err = d.createAccount(ctx, txn, localpart, "", "")
|
acc, err = d.createAccount(ctx, txn, localpart, "", "", "")
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return acc, err
|
return acc, err
|
||||||
|
|
@ -174,19 +199,15 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
|
||||||
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
||||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||||
// account already exists, it will return nil, sqlutil.ErrUserExists.
|
// account already exists, it will return nil, sqlutil.ErrUserExists.
|
||||||
func (d *Database) CreateAccount(
|
func (d *Database) CreateAccount(ctx context.Context, localpart, plaintextPassword, b64encodedPublicKey, appserviceID string) (acc *api.Account, err error) {
|
||||||
ctx context.Context, localpart, plaintextPassword, appserviceID string,
|
|
||||||
) (acc *api.Account, err error) {
|
|
||||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
|
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, b64encodedPublicKey, appserviceID)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) createAccount(
|
func (d *Database) createAccount(ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, publicKey, appserviceID string) (*api.Account, error) {
|
||||||
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
|
|
||||||
) (*api.Account, error) {
|
|
||||||
var account *api.Account
|
var account *api.Account
|
||||||
var err error
|
var err error
|
||||||
// Generate a password hash if this is not a password-less user
|
// Generate a password hash if this is not a password-less user
|
||||||
|
|
@ -197,7 +218,7 @@ func (d *Database) createAccount(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil {
|
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, publicKey, appserviceID); err != nil {
|
||||||
if sqlutil.IsUniqueConstraintViolationErr(err) {
|
if sqlutil.IsUniqueConstraintViolationErr(err) {
|
||||||
return nil, sqlutil.ErrUserExists
|
return nil, sqlutil.ErrUserExists
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,8 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
||||||
created_ts BIGINT NOT NULL,
|
created_ts BIGINT NOT NULL,
|
||||||
-- The password hash for this account. Can be NULL if this is a passwordless account.
|
-- The password hash for this account. Can be NULL if this is a passwordless account.
|
||||||
password_hash TEXT,
|
password_hash TEXT,
|
||||||
|
-- The public key for this account, base64 encoded.
|
||||||
|
b64_public_key TEXT,
|
||||||
-- Identifies which application service this account belongs to, if any.
|
-- Identifies which application service this account belongs to, if any.
|
||||||
appservice_id TEXT,
|
appservice_id TEXT,
|
||||||
-- If the account is currently active
|
-- If the account is currently active
|
||||||
|
|
@ -46,7 +48,7 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertAccountSQL = "" +
|
const insertAccountSQL = "" +
|
||||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
|
"INSERT INTO account_accounts(localpart, created_ts, password_hash, b64_public_key, appservice_id) VALUES ($1, $2, $3, $4, $5)"
|
||||||
|
|
||||||
const updatePasswordSQL = "" +
|
const updatePasswordSQL = "" +
|
||||||
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||||
|
|
@ -60,6 +62,9 @@ const selectAccountByLocalpartSQL = "" +
|
||||||
const selectPasswordHashSQL = "" +
|
const selectPasswordHashSQL = "" +
|
||||||
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
||||||
|
|
||||||
|
const selectb64PubKeySQL = "" +
|
||||||
|
"SELECT b64_public_key FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
||||||
|
|
||||||
const selectNewNumericLocalpartSQL = "" +
|
const selectNewNumericLocalpartSQL = "" +
|
||||||
"SELECT COUNT(localpart) FROM account_accounts"
|
"SELECT COUNT(localpart) FROM account_accounts"
|
||||||
|
|
||||||
|
|
@ -70,6 +75,7 @@ type accountsStatements struct {
|
||||||
deactivateAccountStmt *sql.Stmt
|
deactivateAccountStmt *sql.Stmt
|
||||||
selectAccountByLocalpartStmt *sql.Stmt
|
selectAccountByLocalpartStmt *sql.Stmt
|
||||||
selectPasswordHashStmt *sql.Stmt
|
selectPasswordHashStmt *sql.Stmt
|
||||||
|
selectb64PubKeyStmt *sql.Stmt
|
||||||
selectNewNumericLocalpartStmt *sql.Stmt
|
selectNewNumericLocalpartStmt *sql.Stmt
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
@ -88,6 +94,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
||||||
{&s.deactivateAccountStmt, deactivateAccountSQL},
|
{&s.deactivateAccountStmt, deactivateAccountSQL},
|
||||||
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
|
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
|
||||||
{&s.selectPasswordHashStmt, selectPasswordHashSQL},
|
{&s.selectPasswordHashStmt, selectPasswordHashSQL},
|
||||||
|
{&s.selectb64PubKeyStmt, selectb64PubKeySQL},
|
||||||
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
|
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
@ -95,17 +102,15 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
||||||
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
|
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
|
||||||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||||
// on success.
|
// on success.
|
||||||
func (s *accountsStatements) insertAccount(
|
func (s *accountsStatements) insertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, b64PubKey, appserviceID string) (*api.Account, error) {
|
||||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
|
|
||||||
) (*api.Account, error) {
|
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
stmt := s.insertAccountStmt
|
stmt := s.insertAccountStmt
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if appserviceID == "" {
|
if appserviceID == "" {
|
||||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
|
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, b64PubKey, nil)
|
||||||
} else {
|
} else {
|
||||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
|
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, appserviceID)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -140,6 +145,13 @@ func (s *accountsStatements) selectPasswordHash(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *accountsStatements) selectb64PubKey(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) (b64PubKey string, err error) {
|
||||||
|
err = s.selectb64PubKeyStmt.QueryRowContext(ctx, localpart).Scan(&b64PubKey)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) selectAccountByLocalpart(
|
func (s *accountsStatements) selectAccountByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ CREATE TABLE account_accounts (
|
||||||
localpart TEXT NOT NULL PRIMARY KEY,
|
localpart TEXT NOT NULL PRIMARY KEY,
|
||||||
created_ts BIGINT NOT NULL,
|
created_ts BIGINT NOT NULL,
|
||||||
password_hash TEXT,
|
password_hash TEXT,
|
||||||
|
b64_public_key TEXT,
|
||||||
appservice_id TEXT,
|
appservice_id TEXT,
|
||||||
is_deactivated BOOLEAN DEFAULT 0
|
is_deactivated BOOLEAN DEFAULT 0
|
||||||
);
|
);
|
||||||
|
|
@ -47,6 +48,7 @@ CREATE TABLE account_accounts (
|
||||||
localpart TEXT NOT NULL PRIMARY KEY,
|
localpart TEXT NOT NULL PRIMARY KEY,
|
||||||
created_ts BIGINT NOT NULL,
|
created_ts BIGINT NOT NULL,
|
||||||
password_hash TEXT,
|
password_hash TEXT,
|
||||||
|
b64_public_key TEXT,
|
||||||
appservice_id TEXT
|
appservice_id TEXT
|
||||||
);
|
);
|
||||||
INSERT
|
INSERT
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,9 @@ package sqlite3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
@ -125,6 +127,29 @@ func (d *Database) GetAccountByPassword(
|
||||||
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAccountByChallengeResponse returns the account associated with the given localpart and public key
|
||||||
|
// if the given signature can be traced back to the public key.
|
||||||
|
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||||
|
func (d *Database) GetAccountByChallengeResponse(ctx context.Context, localpart, b64encodedSignature, challenge string) (*api.Account, error) {
|
||||||
|
b64PubKey, err := d.accounts.selectb64PubKey(ctx, localpart)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
pubKey, err := base64.StdEncoding.DecodeString(b64PubKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sig, err := base64.StdEncoding.DecodeString(b64encodedSignature)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
verified := ed25519.Verify(pubKey, []byte(challenge), sig)
|
||||||
|
if !verified {
|
||||||
|
return nil, errors.New("Authentication error: Invalid signature")
|
||||||
|
}
|
||||||
|
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
||||||
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
||||||
func (d *Database) GetProfileByLocalpart(
|
func (d *Database) GetProfileByLocalpart(
|
||||||
|
|
@ -191,7 +216,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
localpart := strconv.FormatInt(numLocalpart, 10)
|
localpart := strconv.FormatInt(numLocalpart, 10)
|
||||||
acc, err = d.createAccount(ctx, txn, localpart, "", "")
|
acc, err = d.createAccount(ctx, txn, localpart, "", "", "")
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return acc, err
|
return acc, err
|
||||||
|
|
@ -200,9 +225,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
|
||||||
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
||||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||||
// account already exists, it will return nil, ErrUserExists.
|
// account already exists, it will return nil, ErrUserExists.
|
||||||
func (d *Database) CreateAccount(
|
func (d *Database) CreateAccount(ctx context.Context, localpart, plaintextPassword, b64encodedPublicKey, appserviceID string) (acc *api.Account, err error) {
|
||||||
ctx context.Context, localpart, plaintextPassword, appserviceID string,
|
|
||||||
) (acc *api.Account, err error) {
|
|
||||||
// Create one account at a time else we can get 'database is locked'.
|
// Create one account at a time else we can get 'database is locked'.
|
||||||
d.profilesMu.Lock()
|
d.profilesMu.Lock()
|
||||||
d.accountDatasMu.Lock()
|
d.accountDatasMu.Lock()
|
||||||
|
|
@ -211,7 +234,7 @@ func (d *Database) CreateAccount(
|
||||||
defer d.accountDatasMu.Unlock()
|
defer d.accountDatasMu.Unlock()
|
||||||
defer d.accountsMu.Unlock()
|
defer d.accountsMu.Unlock()
|
||||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
|
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, b64encodedPublicKey, appserviceID)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
|
@ -219,9 +242,7 @@ func (d *Database) CreateAccount(
|
||||||
|
|
||||||
// WARNING! This function assumes that the relevant mutexes have already
|
// WARNING! This function assumes that the relevant mutexes have already
|
||||||
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
|
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
|
||||||
func (d *Database) createAccount(
|
func (d *Database) createAccount(ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, publicKey, appserviceID string) (*api.Account, error) {
|
||||||
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
|
|
||||||
) (*api.Account, error) {
|
|
||||||
var err error
|
var err error
|
||||||
var account *api.Account
|
var account *api.Account
|
||||||
// Generate a password hash if this is not a password-less user
|
// Generate a password hash if this is not a password-less user
|
||||||
|
|
@ -232,7 +253,7 @@ func (d *Database) createAccount(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil {
|
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, publicKey, appserviceID); err != nil {
|
||||||
return nil, sqlutil.ErrUserExists
|
return nil, sqlutil.ErrUserExists
|
||||||
}
|
}
|
||||||
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
|
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,7 @@ func TestQueryProfile(t *testing.T) {
|
||||||
aliceAvatarURL := "mxc://example.com/alice"
|
aliceAvatarURL := "mxc://example.com/alice"
|
||||||
aliceDisplayName := "Alice"
|
aliceDisplayName := "Alice"
|
||||||
userAPI, accountDB := MustMakeInternalAPI(t)
|
userAPI, accountDB := MustMakeInternalAPI(t)
|
||||||
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "")
|
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to make account: %s", err)
|
t.Fatalf("failed to make account: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue