add digital signature challenge response authentication mechanism

using ed25519 keypairs

Signed-off-by: Fabian Deifuß <deifussfabian@icloud.com>
This commit is contained in:
Fabian Deifuß 2021-10-12 21:23:32 +02:00
parent a47b12dc7d
commit f887bcea6f
24 changed files with 414 additions and 105 deletions

View file

@ -7,6 +7,7 @@ type LoginType string
const ( const (
LoginTypePassword = "m.login.password" LoginTypePassword = "m.login.password"
LoginTypeDummy = "m.login.dummy" LoginTypeDummy = "m.login.dummy"
LoginTypeChallengeResponse = "m.login.challenge_response"
LoginTypeSharedSecret = "org.matrix.login.shared_secret" LoginTypeSharedSecret = "org.matrix.login.shared_secret"
LoginTypeRecaptcha = "m.login.recaptcha" LoginTypeRecaptcha = "m.login.recaptcha"
LoginTypeApplicationService = "m.login.application_service" LoginTypeApplicationService = "m.login.application_service"

View file

@ -0,0 +1,67 @@
package auth
import (
"context"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
"net/http"
)
type GetAccountByChallengeResponse func(ctx context.Context, localpart, b64encodedSignature, challenge string) (*api.Account, error)
type ChallengeResponseRequest struct {
Login
Signature string `json:"signature"`
}
// LoginTypeChallengeResponse using public key encryption
type LoginTypeChallengeResponse struct {
GetAccountByChallengeResponse GetAccountByChallengeResponse
Config *config.ClientAPI
}
func (t *LoginTypeChallengeResponse) Name() string {
return authtypes.LoginTypeChallengeResponse
}
func (t *LoginTypeChallengeResponse) Request() interface{} {
return &ChallengeResponseRequest{}
}
func (t *LoginTypeChallengeResponse) Login(ctx context.Context, req interface{}, challenge string) (*Login, *util.JSONResponse) {
r := req.(*ChallengeResponseRequest)
username := r.Username()
if username == "" {
return nil, &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("A username must be supplied."),
}
}
localpart, err := userutil.ParseUsernameParam(username, &t.Config.Matrix.ServerName)
if err != nil {
return nil, &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername(err.Error()),
}
}
if r.Signature == "" {
return nil, &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidSignature("No signature provided"),
}
}
_, err = t.GetAccountByChallengeResponse(ctx, localpart, r.Signature, challenge)
if err != nil {
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows
// but that would leak the existence of the user.
return nil, &util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("The digital signature is incorrect or the account does not exist."),
}
}
return &r.Login, nil
}

View file

@ -47,7 +47,7 @@ func (t *LoginTypePassword) Request() interface{} {
return &PasswordRequest{} return &PasswordRequest{}
} }
func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) { func (t *LoginTypePassword) Login(ctx context.Context, req interface{}, challenge string) (*Login, *util.JSONResponse) {
r := req.(*PasswordRequest) r := req.(*PasswordRequest)
// Squash username to all lowercase letters // Squash username to all lowercase letters
username := strings.ToLower(r.Username()) username := strings.ToLower(r.Username())

View file

@ -17,14 +17,15 @@ package auth
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"net/http" "fmt"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"net/http"
"time"
) )
// Type represents an auth type // Type represents an auth type
@ -43,7 +44,7 @@ type Type interface {
// "If the homeserver decides that an attempt on a stage was unsuccessful, but the // "If the homeserver decides that an attempt on a stage was unsuccessful, but the
// client may make a second attempt, it returns the same HTTP status 401 response as above, // client may make a second attempt, it returns the same HTTP status 401 response as above,
// with the addition of the standard errcode and error fields describing the error." // with the addition of the standard errcode and error fields describing the error."
Login(ctx context.Context, req interface{}) (login *Login, errRes *util.JSONResponse) Login(ctx context.Context, req interface{}, challenge string) (login *Login, errRes *util.JSONResponse)
// TODO: Extend to support Register() flow // TODO: Extend to support Register() flow
// Register(ctx context.Context, sessionID string, req interface{}) // Register(ctx context.Context, sessionID string, req interface{})
} }
@ -109,13 +110,18 @@ type UserInteractive struct {
Types map[string]Type Types map[string]Type
// Map of session ID to completed login types, will need to be extended in future // Map of session ID to completed login types, will need to be extended in future
Sessions map[string][]string Sessions map[string][]string
SessionParams map[string]map[string]string
} }
func NewUserInteractive(getAccByPass GetAccountByPassword, cfg *config.ClientAPI) *UserInteractive { func NewUserInteractive(getAccByPass GetAccountByPassword, getAccByChallengeResponse GetAccountByChallengeResponse, cfg *config.ClientAPI) *UserInteractive {
typePassword := &LoginTypePassword{ typePassword := &LoginTypePassword{
GetAccountByPassword: getAccByPass, GetAccountByPassword: getAccByPass,
Config: cfg, Config: cfg,
} }
typeChallengeResponse := &LoginTypeChallengeResponse{
GetAccountByChallengeResponse: getAccByChallengeResponse,
Config: cfg,
}
// TODO: Add SSO login // TODO: Add SSO login
return &UserInteractive{ return &UserInteractive{
Completed: []string{}, Completed: []string{},
@ -123,11 +129,16 @@ func NewUserInteractive(getAccByPass GetAccountByPassword, cfg *config.ClientAPI
{ {
Stages: []string{typePassword.Name()}, Stages: []string{typePassword.Name()},
}, },
{
Stages: []string{typeChallengeResponse.Name()},
},
}, },
Types: map[string]Type{ Types: map[string]Type{
typePassword.Name(): typePassword, typePassword.Name(): typePassword,
typeChallengeResponse.Name(): typeChallengeResponse,
}, },
Sessions: make(map[string][]string), Sessions: make(map[string][]string),
SessionParams: make(map[string]map[string]string),
} }
} }
@ -148,6 +159,9 @@ func (u *UserInteractive) AddCompletedStage(sessionID, authType string) {
// Challenge returns an HTTP 401 with the supported flows for authenticating // Challenge returns an HTTP 401 with the supported flows for authenticating
func (u *UserInteractive) Challenge(sessionID string) *util.JSONResponse { func (u *UserInteractive) Challenge(sessionID string) *util.JSONResponse {
params := make(map[string]string)
params["challenge"] = fmt.Sprintf("%d%s", time.Now().Unix(), sessionID)
u.SessionParams[sessionID] = params
return &util.JSONResponse{ return &util.JSONResponse{
Code: 401, Code: 401,
JSON: struct { JSON: struct {
@ -155,18 +169,18 @@ func (u *UserInteractive) Challenge(sessionID string) *util.JSONResponse {
Flows []userInteractiveFlow `json:"flows"` Flows []userInteractiveFlow `json:"flows"`
Session string `json:"session"` Session string `json:"session"`
// TODO: Return any additional `params` // TODO: Return any additional `params`
Params map[string]interface{} `json:"params"` Params map[string]string `json:"params"`
}{ }{
u.Completed, u.Completed,
u.Flows, u.Flows,
sessionID, sessionID,
make(map[string]interface{}), params,
}, },
} }
} }
// NewSession returns a challenge with a new session ID and remembers the session ID // NewSession returns a challenge with a new session ID and remembers the session ID
func (u *UserInteractive) NewSession() *util.JSONResponse { func (u *UserInteractive) NewSession(authType string) *util.JSONResponse {
sessionID, err := GenerateAccessToken() sessionID, err := GenerateAccessToken()
if err != nil { if err != nil {
logrus.WithError(err).Error("failed to generate session ID") logrus.WithError(err).Error("failed to generate session ID")
@ -207,15 +221,16 @@ func (u *UserInteractive) ResponseWithChallenge(sessionID string, response inter
func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte, device *api.Device) (*Login, *util.JSONResponse) { func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte, device *api.Device) (*Login, *util.JSONResponse) {
// TODO: rate limit // TODO: rate limit
// extract the type so we know which login type to use
authType := gjson.GetBytes(bodyBytes, "auth.type").Str
// "A client should first make a request with no auth parameter. The homeserver returns an HTTP 401 response, with a JSON body" // "A client should first make a request with no auth parameter. The homeserver returns an HTTP 401 response, with a JSON body"
// https://matrix.org/docs/spec/client_server/r0.6.1#user-interactive-api-in-the-rest-api // https://matrix.org/docs/spec/client_server/r0.6.1#user-interactive-api-in-the-rest-api
hasResponse := gjson.GetBytes(bodyBytes, "auth").Exists() hasResponse := gjson.GetBytes(bodyBytes, "auth").Exists()
if !hasResponse { if !hasResponse {
return nil, u.NewSession() return nil, u.NewSession(authType)
} }
// extract the type so we know which login type to use
authType := gjson.GetBytes(bodyBytes, "auth.type").Str
loginType, ok := u.Types[authType] loginType, ok := u.Types[authType]
if !ok { if !ok {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
@ -237,13 +252,13 @@ func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte, device *
} }
r := loginType.Request() r := loginType.Request()
if err := json.Unmarshal([]byte(gjson.GetBytes(bodyBytes, "auth").Raw), r); err != nil { if err := json.Unmarshal(bodyBytes, r); err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()),
} }
} }
login, resErr := loginType.Login(ctx, r) login, resErr := loginType.Login(ctx, r, u.SessionParams[sessionID]["challenge"])
if resErr == nil { if resErr == nil {
u.AddCompletedStage(sessionID, authType) u.AddCompletedStage(sessionID, authType)
// TODO: Check if there's more stages to go and return an error // TODO: Check if there's more stages to go and return an error

View file

@ -32,13 +32,21 @@ func getAccountByPassword(ctx context.Context, localpart, plaintextPassword stri
return acc, nil return acc, nil
} }
func getAccountByChallengeResponse(ctx context.Context, localpart, b64encodedSignature, challenge string) (*api.Account, error) {
acc, ok := lookup[localpart+" "+b64encodedSignature+" "+challenge]
if !ok {
return nil, fmt.Errorf("unknown user/pubkey")
}
return acc, nil
}
func setup() *UserInteractive { func setup() *UserInteractive {
cfg := &config.ClientAPI{ cfg := &config.ClientAPI{
Matrix: &config.Global{ Matrix: &config.Global{
ServerName: serverName, ServerName: serverName,
}, },
} }
return NewUserInteractive(getAccountByPassword, cfg) return NewUserInteractive(getAccountByPassword, getAccountByChallengeResponse, cfg)
} }
func TestUserInteractiveChallenge(t *testing.T) { func TestUserInteractiveChallenge(t *testing.T) {

View file

@ -15,7 +15,9 @@
package httputil package httputil
import ( import (
"bytes"
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/clientapi/auth"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"unicode/utf8" "unicode/utf8"
@ -54,3 +56,19 @@ func UnmarshalJSONRequest(req *http.Request, iface interface{}) *util.JSONRespon
} }
return nil return nil
} }
func GetLoginType(req *http.Request) (string, *util.JSONResponse) {
body, err := ioutil.ReadAll(req.Body)
if err != nil {
resp := jsonerror.InternalServerError()
return "", &resp
}
req.Body = ioutil.NopCloser(bytes.NewReader(body))
r := &auth.Login{}
jsonErr := UnmarshalJSONRequest(req, r)
if jsonErr != nil {
return "", jsonErr
}
req.Body = ioutil.NopCloser(bytes.NewReader(body))
return r.Type, nil
}

View file

@ -67,7 +67,7 @@ func UploadCrossSigningDeviceKeys(
GetAccountByPassword: accountDB.GetAccountByPassword, GetAccountByPassword: accountDB.GetAccountByPassword,
Config: cfg, Config: cfg,
} }
if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil { if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest, ""); authErr != nil {
return *authErr return *authErr
} }
AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword) AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword)

View file

@ -16,6 +16,9 @@ package routing
import ( import (
"context" "context"
"fmt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"io/ioutil"
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth"
@ -44,27 +47,34 @@ type flow struct {
Type string `json:"type"` Type string `json:"type"`
} }
func passwordLogin() flows { func enabledLoginTypes(cfg *config.ClientAPI) flows {
f := flows{} f := flows{}
for _, loginType := range cfg.LoginTypes {
s := flow{ s := flow{
Type: "m.login.password", Type: loginType,
} }
f.Flows = append(f.Flows, s) f.Flows = append(f.Flows, s)
}
return f return f
} }
// Login implements GET and POST /login // Login implements GET and POST /login
func Login( func Login(
req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI, req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI,
cfg *config.ClientAPI, cfg *config.ClientAPI, userInteractiveAuth *auth.UserInteractive,
) util.JSONResponse { ) util.JSONResponse {
if req.Method == http.MethodGet { if req.Method == http.MethodGet {
// TODO: support other forms of login other than password, depending on config options
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: passwordLogin(), JSON: enabledLoginTypes(cfg),
} }
} else if req.Method == http.MethodPost { } else if req.Method == http.MethodPost {
loginType, err := httputil.GetLoginType(req)
if err != nil {
return *err
}
switch loginType {
case authtypes.LoginTypePassword:
typePassword := auth.LoginTypePassword{ typePassword := auth.LoginTypePassword{
GetAccountByPassword: accountDB.GetAccountByPassword, GetAccountByPassword: accountDB.GetAccountByPassword,
Config: cfg, Config: cfg,
@ -74,12 +84,33 @@ func Login(
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }
login, authErr := typePassword.Login(req.Context(), r) login, authErr := typePassword.Login(req.Context(), r, "")
if authErr != nil { if authErr != nil {
return *authErr return *authErr
} }
// make a device/access token // make a device/access token
return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent()) return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent())
case authtypes.LoginTypeChallengeResponse:
defer req.Body.Close() // nolint:errcheck
bodyBytes, err := ioutil.ReadAll(req.Body)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("The request body could not be read: " + err.Error()),
}
}
login, resErr := userInteractiveAuth.Verify(req.Context(), bodyBytes, nil)
if resErr != nil {
return *resErr
}
// create a device/access token
return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent())
default:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidParam(fmt.Sprintf("Unsupported login type: %s", loginType)),
}
}
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusMethodNotAllowed, Code: http.StatusMethodNotAllowed,

View file

@ -71,7 +71,7 @@ func Password(
GetAccountByPassword: accountDB.GetAccountByPassword, GetAccountByPassword: accountDB.GetAccountByPassword,
Config: cfg, Config: cfg,
} }
if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil { if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest, ""); authErr != nil {
return *authErr return *authErr
} }
AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword) AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword)

View file

@ -123,6 +123,7 @@ type registerRequest struct {
// registration parameters // registration parameters
Password string `json:"password"` Password string `json:"password"`
Username string `json:"username"` Username string `json:"username"`
B64encodedPublicKey string `json:"b64PublicKey"`
Admin bool `json:"admin"` Admin bool `json:"admin"`
// user-interactive auth params // user-interactive auth params
Auth authDict `json:"auth"` Auth authDict `json:"auth"`
@ -518,6 +519,7 @@ func Register(
logger := util.GetLogger(req.Context()) logger := util.GetLogger(req.Context())
logger.WithFields(log.Fields{ logger.WithFields(log.Fields{
"username": r.Username, "username": r.Username,
"b64PublicKey": r.B64encodedPublicKey,
"auth.type": r.Auth.Type, "auth.type": r.Auth.Type,
"session_id": r.Auth.Session, "session_id": r.Auth.Session,
}).Info("Processing registration request") }).Info("Processing registration request")
@ -700,7 +702,7 @@ func handleApplicationServiceRegistration(
// Don't need to worry about appending to registration stages as // Don't need to worry about appending to registration stages as
// application service registration is entirely separate. // application service registration is entirely separate.
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), req.Context(), userAPI, r.Username, "", "", appserviceID, req.RemoteAddr, req.UserAgent(),
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
) )
} }
@ -719,7 +721,7 @@ func checkAndCompleteFlow(
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
// This flow was completed, registration can continue // This flow was completed, registration can continue
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), req.Context(), userAPI, r.Username, r.Password, r.B64encodedPublicKey, "", req.RemoteAddr, req.UserAgent(),
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
) )
} }
@ -739,13 +741,7 @@ func checkAndCompleteFlow(
// registerRequest, as this function serves requests encoded as both // registerRequest, as this function serves requests encoded as both
// registerRequests and legacyRegisterRequests, which share some attributes but // registerRequests and legacyRegisterRequests, which share some attributes but
// not all // not all
func completeRegistration( func completeRegistration(ctx context.Context, userAPI userapi.UserInternalAPI, username, password, b64encodedPublicKey, appserviceID, ipAddr, userAgent string, inhibitLogin eventutil.WeakBoolean, displayName, deviceID *string) util.JSONResponse {
ctx context.Context,
userAPI userapi.UserInternalAPI,
username, password, appserviceID, ipAddr, userAgent string,
inhibitLogin eventutil.WeakBoolean,
displayName, deviceID *string,
) util.JSONResponse {
if username == "" { if username == "" {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
@ -765,6 +761,7 @@ func completeRegistration(
AppServiceID: appserviceID, AppServiceID: appserviceID,
Localpart: username, Localpart: username,
Password: password, Password: password,
B64encodedPublicKey: b64encodedPublicKey,
AccountType: userapi.AccountTypeUser, AccountType: userapi.AccountTypeUser,
OnConflict: userapi.ConflictAbort, OnConflict: userapi.ConflictAbort,
}, &accRes) }, &accRes)
@ -963,5 +960,5 @@ func handleSharedSecretRegistration(userAPI userapi.UserInternalAPI, sr *SharedS
return *resErr return *resErr
} }
deviceID := "shared_secret_registration" deviceID := "shared_secret_registration"
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), false, &ssrr.User, &deviceID) return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", "", req.RemoteAddr, req.UserAgent(), false, &ssrr.User, &deviceID)
} }

View file

@ -62,7 +62,7 @@ func Setup(
mscCfg *config.MSCs, mscCfg *config.MSCs,
) { ) {
rateLimits := httputil.NewRateLimits(&cfg.RateLimiting) rateLimits := httputil.NewRateLimits(&cfg.RateLimiting)
userInteractiveAuth := auth.NewUserInteractive(accountDB.GetAccountByPassword, cfg) userInteractiveAuth := auth.NewUserInteractive(accountDB.GetAccountByPassword, accountDB.GetAccountByChallengeResponse, cfg)
unstableFeatures := map[string]bool{ unstableFeatures := map[string]bool{
"org.matrix.e2e_cross_signing": true, "org.matrix.e2e_cross_signing": true,
@ -498,7 +498,7 @@ func Setup(
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
} }
return Login(req, accountDB, userAPI, cfg) return Login(req, accountDB, userAPI, cfg, userInteractiveAuth)
}), }),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)

View file

@ -16,6 +16,9 @@ package main
import ( import (
"context" "context"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"flag" "flag"
"fmt" "fmt"
"io" "io"
@ -39,6 +42,8 @@ Example:
# provide password by parameter # provide password by parameter
%s --config dendrite.yaml -username alice -password foobarbaz %s --config dendrite.yaml -username alice -password foobarbaz
# auto generate keypair
%s --config dendrite.yaml -username alice -password foobarbaz -udk NuJ7J4BsaE8QZT1ULNTc3s8ZjLFmDPh91l1i0Urf/ls=
# use password from file # use password from file
%s --config dendrite.yaml -username alice -passwordfile my.pass %s --config dendrite.yaml -username alice -passwordfile my.pass
# ask user to provide password # ask user to provide password
@ -54,6 +59,7 @@ Arguments:
var ( var (
username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')")
password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)") password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)")
createKeypair = flag.Bool("create-keypair", false, "Whether to create an Ed25519 keypair for the account to create (optional)")
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
askPass = flag.Bool("ask-pass", false, "Ask for the password to use") askPass = flag.Bool("ask-pass", false, "Ask for the password to use")
@ -62,7 +68,7 @@ var (
func main() { func main() {
name := os.Args[0] name := os.Args[0]
flag.Usage = func() { flag.Usage = func() {
_, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name) _, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name, name)
flag.PrintDefaults() flag.PrintDefaults()
} }
cfg := setup.ParseFlags(true) cfg := setup.ParseFlags(true)
@ -81,7 +87,32 @@ func main() {
logrus.Fatalln("Failed to connect to the database:", err.Error()) logrus.Fatalln("Failed to connect to the database:", err.Error())
} }
_, err = accountDB.CreateAccount(context.Background(), *username, pass, "") var pub64 string
if *createKeypair {
pub, priv, err2 := ed25519.GenerateKey(rand.Reader)
pub64 = base64.StdEncoding.EncodeToString(priv.Public().(ed25519.PublicKey))
if err2 != nil {
logrus.Fatalln(err2)
}
err2 = os.WriteFile("private.key", priv, 0644)
if err2 != nil {
logrus.Fatalln(err2)
}
err2 = os.WriteFile("private.key.seed", priv.Seed(), 0644)
if err2 != nil {
logrus.Fatalln(err2)
}
err2 = os.WriteFile("./public.key", pub, 0644)
if err2 != nil {
logrus.Fatalln(err2)
}
err2 = os.WriteFile("./public.key.b64", []byte(pub64), 0644)
if err2 != nil {
logrus.Fatalln(err2)
}
}
_, err = accountDB.CreateAccount(context.Background(), *username, pass, pub64, "")
if err != nil { if err != nil {
logrus.Fatalln("Failed to create the account:", err.Error()) logrus.Fatalln("Failed to create the account:", err.Error())
} }

View file

@ -0,0 +1,61 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"crypto/ed25519"
"encoding/base64"
"flag"
"fmt"
"github.com/sirupsen/logrus"
"os"
)
const usage = `Usage: %s
Sign a string using an Ed25519 private key
Arguments:
`
var (
privateKeyFile = flag.String("private-key", "", "An Ed25519 private key seed used to sign the input")
input = flag.String("input", "", "The input to sign")
)
func main() {
flag.Usage = func() {
fmt.Fprintf(os.Stderr, usage, os.Args[0])
flag.PrintDefaults()
}
flag.Parse()
if *privateKeyFile == "" || *input == "" {
flag.Usage()
return
}
if *privateKeyFile != "" && *input != "" {
seedBytes, err := os.ReadFile(*privateKeyFile)
if err != nil {
logrus.Fatalln(err)
}
priv := ed25519.NewKeyFromSeed(seedBytes)
sig := ed25519.Sign(priv, []byte(*input))
logrus.Infoln("Signature: " + base64.StdEncoding.EncodeToString(sig))
}
}

View file

@ -145,6 +145,11 @@ client_api:
external_api: external_api:
listen: http://[::]:8071 listen: http://[::]:8071
# Which of the authentication flows to support on this server
login_types:
- m.login.password
- m.login.challenge_response
# Prevents new users from being able to register on this homeserver, except when # Prevents new users from being able to register on this homeserver, except when
# using the registration shared secret below. # using the registration shared secret below.
registration_disabled: false registration_disabled: false

View file

@ -12,6 +12,10 @@ type ClientAPI struct {
InternalAPI InternalAPIOptions `yaml:"internal_api"` InternalAPI InternalAPIOptions `yaml:"internal_api"`
ExternalAPI ExternalAPIOptions `yaml:"external_api"` ExternalAPI ExternalAPIOptions `yaml:"external_api"`
// What authentication mechanisms shall be supported
// by this server
LoginTypes []string `yaml:"login_types"`
// If set disables new users from registering (except via shared // If set disables new users from registering (except via shared
// secrets) // secrets)
RegistrationDisabled bool `yaml:"registration_disabled"` RegistrationDisabled bool `yaml:"registration_disabled"`
@ -45,6 +49,7 @@ func (c *ClientAPI) Defaults(generate bool) {
c.InternalAPI.Listen = "http://localhost:7771" c.InternalAPI.Listen = "http://localhost:7771"
c.InternalAPI.Connect = "http://localhost:7771" c.InternalAPI.Connect = "http://localhost:7771"
c.ExternalAPI.Listen = "http://[::]:8071" c.ExternalAPI.Listen = "http://[::]:8071"
c.LoginTypes = []string{"m.login.password"}
c.RegistrationSharedSecret = "" c.RegistrationSharedSecret = ""
c.RecaptchaPublicKey = "" c.RecaptchaPublicKey = ""
c.RecaptchaPrivateKey = "" c.RecaptchaPrivateKey = ""

View file

@ -245,6 +245,7 @@ type PerformAccountCreationRequest struct {
AppServiceID string // optional: the application service ID (not user ID) creating this account, if any. AppServiceID string // optional: the application service ID (not user ID) creating this account, if any.
Password string // optional: if missing then this account will be a passwordless account Password string // optional: if missing then this account will be a passwordless account
B64encodedPublicKey string // optional: if missing then this account will not allow digital signature login until a public key is added
OnConflict Conflict OnConflict Conflict
} }

View file

@ -67,7 +67,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
res.Account = acc res.Account = acc
return nil return nil
} }
acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID) acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.B64encodedPublicKey, req.AppServiceID)
if err != nil { if err != nil {
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
switch req.OnConflict { switch req.OnConflict {

View file

@ -27,6 +27,7 @@ import (
type Database interface { type Database interface {
internal.PartitionStorer internal.PartitionStorer
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
GetAccountByChallengeResponse(ctx context.Context, localpart, b64encodedSignature, challenge string) (*api.Account, error)
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
@ -34,7 +35,7 @@ type Database interface {
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists. // account already exists, it will return nil, ErrUserExists.
CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*api.Account, error) CreateAccount(ctx context.Context, localpart, plaintextPassword, b64encodedPublicKey, appserviceID string) (*api.Account, error)
CreateGuestAccount(ctx context.Context) (*api.Account, error) CreateGuestAccount(ctx context.Context) (*api.Account, error)
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)

View file

@ -36,6 +36,8 @@ CREATE TABLE IF NOT EXISTS account_accounts (
created_ts BIGINT NOT NULL, created_ts BIGINT NOT NULL,
-- The password hash for this account. Can be NULL if this is a passwordless account. -- The password hash for this account. Can be NULL if this is a passwordless account.
password_hash TEXT, password_hash TEXT,
-- The public key for this account, base64 encoded.
b64_public_key TEXT,
-- Identifies which application service this account belongs to, if any. -- Identifies which application service this account belongs to, if any.
appservice_id TEXT, appservice_id TEXT,
-- If the account is currently active -- If the account is currently active
@ -48,7 +50,7 @@ CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
` `
const insertAccountSQL = "" + const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" "INSERT INTO account_accounts(localpart, created_ts, password_hash, b64_public_key, appservice_id) VALUES ($1, $2, $3, $4)"
const updatePasswordSQL = "" + const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
@ -62,6 +64,9 @@ const selectAccountByLocalpartSQL = "" +
const selectPasswordHashSQL = "" + const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE" "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
const selectb64PubKeySQL = "" +
"SELECT b64_public_key FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
const selectNewNumericLocalpartSQL = "" + const selectNewNumericLocalpartSQL = "" +
"SELECT nextval('numeric_username_seq')" "SELECT nextval('numeric_username_seq')"
@ -71,6 +76,7 @@ type accountsStatements struct {
deactivateAccountStmt *sql.Stmt deactivateAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt
selectb64PubKeyStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
@ -88,6 +94,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
{&s.deactivateAccountStmt, deactivateAccountSQL}, {&s.deactivateAccountStmt, deactivateAccountSQL},
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL}, {&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
{&s.selectPasswordHashStmt, selectPasswordHashSQL}, {&s.selectPasswordHashStmt, selectPasswordHashSQL},
{&s.selectb64PubKeyStmt, selectb64PubKeySQL},
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL}, {&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -95,17 +102,15 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) insertAccount( func (s *accountsStatements) insertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, b64PubKey, appserviceID string) (*api.Account, error) {
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
var err error var err error
if appserviceID == "" { if appserviceID == "" {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil) _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, b64PubKey, nil)
} else { } else {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, appserviceID)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@ -140,6 +145,13 @@ func (s *accountsStatements) selectPasswordHash(
return return
} }
func (s *accountsStatements) selectb64PubKey(
ctx context.Context, localpart string,
) (b64PubKey string, err error) {
err = s.selectb64PubKeyStmt.QueryRowContext(ctx, localpart).Scan(&b64PubKey)
return
}
func (s *accountsStatements) selectAccountByLocalpart( func (s *accountsStatements) selectAccountByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (*api.Account, error) { ) (*api.Account, error) {

View file

@ -16,7 +16,9 @@ package postgres
import ( import (
"context" "context"
"crypto/ed25519"
"database/sql" "database/sql"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -120,6 +122,29 @@ func (d *Database) GetAccountByPassword(
return d.accounts.selectAccountByLocalpart(ctx, localpart) return d.accounts.selectAccountByLocalpart(ctx, localpart)
} }
// GetAccountByChallengeResponse returns the account associated with the given localpart and public key
// if the given signature can be traced back to the public key.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByChallengeResponse(ctx context.Context, localpart, b64encodedSignature, challenge string) (*api.Account, error) {
b64PubKey, err := d.accounts.selectb64PubKey(ctx, localpart)
if err != nil {
return nil, err
}
pubKey, err := base64.StdEncoding.DecodeString(b64PubKey)
if err != nil {
return nil, err
}
sig, err := base64.StdEncoding.DecodeString(b64encodedSignature)
if err != nil {
return nil, err
}
verified := ed25519.Verify(pubKey, []byte(challenge), sig)
if !verified {
return nil, errors.New("Authentication error: Invalid signature")
}
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// GetProfileByLocalpart returns the profile associated with the given localpart. // GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart. // Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart( func (d *Database) GetProfileByLocalpart(
@ -165,7 +190,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
return err return err
} }
localpart := strconv.FormatInt(numLocalpart, 10) localpart := strconv.FormatInt(numLocalpart, 10)
acc, err = d.createAccount(ctx, txn, localpart, "", "") acc, err = d.createAccount(ctx, txn, localpart, "", "", "")
return err return err
}) })
return acc, err return acc, err
@ -174,19 +199,15 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, sqlutil.ErrUserExists. // account already exists, it will return nil, sqlutil.ErrUserExists.
func (d *Database) CreateAccount( func (d *Database) CreateAccount(ctx context.Context, localpart, plaintextPassword, b64encodedPublicKey, appserviceID string) (acc *api.Account, err error) {
ctx context.Context, localpart, plaintextPassword, appserviceID string,
) (acc *api.Account, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, b64encodedPublicKey, appserviceID)
return err return err
}) })
return return
} }
func (d *Database) createAccount( func (d *Database) createAccount(ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, publicKey, appserviceID string) (*api.Account, error) {
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
) (*api.Account, error) {
var account *api.Account var account *api.Account
var err error var err error
// Generate a password hash if this is not a password-less user // Generate a password hash if this is not a password-less user
@ -197,7 +218,7 @@ func (d *Database) createAccount(
return nil, err return nil, err
} }
} }
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil { if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, publicKey, appserviceID); err != nil {
if sqlutil.IsUniqueConstraintViolationErr(err) { if sqlutil.IsUniqueConstraintViolationErr(err) {
return nil, sqlutil.ErrUserExists return nil, sqlutil.ErrUserExists
} }

View file

@ -36,6 +36,8 @@ CREATE TABLE IF NOT EXISTS account_accounts (
created_ts BIGINT NOT NULL, created_ts BIGINT NOT NULL,
-- The password hash for this account. Can be NULL if this is a passwordless account. -- The password hash for this account. Can be NULL if this is a passwordless account.
password_hash TEXT, password_hash TEXT,
-- The public key for this account, base64 encoded.
b64_public_key TEXT,
-- Identifies which application service this account belongs to, if any. -- Identifies which application service this account belongs to, if any.
appservice_id TEXT, appservice_id TEXT,
-- If the account is currently active -- If the account is currently active
@ -46,7 +48,7 @@ CREATE TABLE IF NOT EXISTS account_accounts (
` `
const insertAccountSQL = "" + const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" "INSERT INTO account_accounts(localpart, created_ts, password_hash, b64_public_key, appservice_id) VALUES ($1, $2, $3, $4, $5)"
const updatePasswordSQL = "" + const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
@ -60,6 +62,9 @@ const selectAccountByLocalpartSQL = "" +
const selectPasswordHashSQL = "" + const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0" "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
const selectb64PubKeySQL = "" +
"SELECT b64_public_key FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
const selectNewNumericLocalpartSQL = "" + const selectNewNumericLocalpartSQL = "" +
"SELECT COUNT(localpart) FROM account_accounts" "SELECT COUNT(localpart) FROM account_accounts"
@ -70,6 +75,7 @@ type accountsStatements struct {
deactivateAccountStmt *sql.Stmt deactivateAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt
selectb64PubKeyStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
@ -88,6 +94,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
{&s.deactivateAccountStmt, deactivateAccountSQL}, {&s.deactivateAccountStmt, deactivateAccountSQL},
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL}, {&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
{&s.selectPasswordHashStmt, selectPasswordHashSQL}, {&s.selectPasswordHashStmt, selectPasswordHashSQL},
{&s.selectb64PubKeyStmt, selectb64PubKeySQL},
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL}, {&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -95,17 +102,15 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) insertAccount( func (s *accountsStatements) insertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, b64PubKey, appserviceID string) (*api.Account, error) {
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt stmt := s.insertAccountStmt
var err error var err error
if appserviceID == "" { if appserviceID == "" {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil) _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, b64PubKey, nil)
} else { } else {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, appserviceID)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@ -140,6 +145,13 @@ func (s *accountsStatements) selectPasswordHash(
return return
} }
func (s *accountsStatements) selectb64PubKey(
ctx context.Context, localpart string,
) (b64PubKey string, err error) {
err = s.selectb64PubKeyStmt.QueryRowContext(ctx, localpart).Scan(&b64PubKey)
return
}
func (s *accountsStatements) selectAccountByLocalpart( func (s *accountsStatements) selectAccountByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (*api.Account, error) { ) (*api.Account, error) {

View file

@ -23,6 +23,7 @@ CREATE TABLE account_accounts (
localpart TEXT NOT NULL PRIMARY KEY, localpart TEXT NOT NULL PRIMARY KEY,
created_ts BIGINT NOT NULL, created_ts BIGINT NOT NULL,
password_hash TEXT, password_hash TEXT,
b64_public_key TEXT,
appservice_id TEXT, appservice_id TEXT,
is_deactivated BOOLEAN DEFAULT 0 is_deactivated BOOLEAN DEFAULT 0
); );
@ -47,6 +48,7 @@ CREATE TABLE account_accounts (
localpart TEXT NOT NULL PRIMARY KEY, localpart TEXT NOT NULL PRIMARY KEY,
created_ts BIGINT NOT NULL, created_ts BIGINT NOT NULL,
password_hash TEXT, password_hash TEXT,
b64_public_key TEXT,
appservice_id TEXT appservice_id TEXT
); );
INSERT INSERT

View file

@ -16,7 +16,9 @@ package sqlite3
import ( import (
"context" "context"
"crypto/ed25519"
"database/sql" "database/sql"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -125,6 +127,29 @@ func (d *Database) GetAccountByPassword(
return d.accounts.selectAccountByLocalpart(ctx, localpart) return d.accounts.selectAccountByLocalpart(ctx, localpart)
} }
// GetAccountByChallengeResponse returns the account associated with the given localpart and public key
// if the given signature can be traced back to the public key.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByChallengeResponse(ctx context.Context, localpart, b64encodedSignature, challenge string) (*api.Account, error) {
b64PubKey, err := d.accounts.selectb64PubKey(ctx, localpart)
if err != nil {
return nil, err
}
pubKey, err := base64.StdEncoding.DecodeString(b64PubKey)
if err != nil {
return nil, err
}
sig, err := base64.StdEncoding.DecodeString(b64encodedSignature)
if err != nil {
return nil, err
}
verified := ed25519.Verify(pubKey, []byte(challenge), sig)
if !verified {
return nil, errors.New("Authentication error: Invalid signature")
}
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// GetProfileByLocalpart returns the profile associated with the given localpart. // GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart. // Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart( func (d *Database) GetProfileByLocalpart(
@ -191,7 +216,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
return err return err
} }
localpart := strconv.FormatInt(numLocalpart, 10) localpart := strconv.FormatInt(numLocalpart, 10)
acc, err = d.createAccount(ctx, txn, localpart, "", "") acc, err = d.createAccount(ctx, txn, localpart, "", "", "")
return err return err
}) })
return acc, err return acc, err
@ -200,9 +225,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists. // account already exists, it will return nil, ErrUserExists.
func (d *Database) CreateAccount( func (d *Database) CreateAccount(ctx context.Context, localpart, plaintextPassword, b64encodedPublicKey, appserviceID string) (acc *api.Account, err error) {
ctx context.Context, localpart, plaintextPassword, appserviceID string,
) (acc *api.Account, err error) {
// Create one account at a time else we can get 'database is locked'. // Create one account at a time else we can get 'database is locked'.
d.profilesMu.Lock() d.profilesMu.Lock()
d.accountDatasMu.Lock() d.accountDatasMu.Lock()
@ -211,7 +234,7 @@ func (d *Database) CreateAccount(
defer d.accountDatasMu.Unlock() defer d.accountDatasMu.Unlock()
defer d.accountsMu.Unlock() defer d.accountsMu.Unlock()
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, b64encodedPublicKey, appserviceID)
return err return err
}) })
return return
@ -219,9 +242,7 @@ func (d *Database) CreateAccount(
// WARNING! This function assumes that the relevant mutexes have already // WARNING! This function assumes that the relevant mutexes have already
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). // been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
func (d *Database) createAccount( func (d *Database) createAccount(ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, publicKey, appserviceID string) (*api.Account, error) {
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
) (*api.Account, error) {
var err error var err error
var account *api.Account var account *api.Account
// Generate a password hash if this is not a password-less user // Generate a password hash if this is not a password-less user
@ -232,7 +253,7 @@ func (d *Database) createAccount(
return nil, err return nil, err
} }
} }
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil { if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, publicKey, appserviceID); err != nil {
return nil, sqlutil.ErrUserExists return nil, sqlutil.ErrUserExists
} }
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil { if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {

View file

@ -48,7 +48,7 @@ func TestQueryProfile(t *testing.T) {
aliceAvatarURL := "mxc://example.com/alice" aliceAvatarURL := "mxc://example.com/alice"
aliceDisplayName := "Alice" aliceDisplayName := "Alice"
userAPI, accountDB := MustMakeInternalAPI(t) userAPI, accountDB := MustMakeInternalAPI(t)
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "") _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", "")
if err != nil { if err != nil {
t.Fatalf("failed to make account: %s", err) t.Fatalf("failed to make account: %s", err)
} }