diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index 7f6d5105e..a26e3672c 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -15,18 +15,25 @@ package main import ( + "bytes" "context" + "crypto/hmac" + "crypto/sha1" + "encoding/hex" + "encoding/json" "flag" "fmt" + "github.com/tidwall/gjson" "io" "io/ioutil" + "net/http" "os" "regexp" "strings" + "time" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage" "github.com/sirupsen/logrus" "golang.org/x/term" @@ -59,9 +66,9 @@ var ( password = flag.String("password", "", "The password to associate with the account") pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") - pwdLess = flag.Bool("passwordless", false, "Create a passwordless account, e.g. if only an accesstoken is required") isAdmin = flag.Bool("admin", false, "Create an admin account") resetPassword = flag.Bool("reset-password", false, "Resets the password for the given username") + serverURL = flag.String("url", "https://localhost:8448", "The URL to connect to.") validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) ) @@ -78,10 +85,6 @@ func main() { os.Exit(1) } - if *pwdLess && *resetPassword { - logrus.Fatalf("Can not reset to an empty password, unable to login afterwards.") - } - if !validUsernameRegex.MatchString(*username) { logrus.Warn("Username can only contain characters a-z, 0-9, or '_-./='") os.Exit(1) @@ -91,44 +94,31 @@ func main() { logrus.Fatalf("Username can not be longer than 255 characters: %s", fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName)) } - var pass string - var err error - if !*pwdLess { - pass, err = getPassword(*password, *pwdFile, *pwdStdin, os.Stdin) - if err != nil { - logrus.Fatalln(err) - } - } - - // avoid warning about open registration - cfg.ClientAPI.RegistrationDisabled = true - - b := base.NewBaseDendrite(cfg, "") - defer b.Close() // nolint: errcheck - - accountDB, err := storage.NewUserAPIDatabase( - b, - &cfg.UserAPI.AccountDatabase, - cfg.Global.ServerName, - cfg.UserAPI.BCryptCost, - cfg.UserAPI.OpenIDTokenLifetimeMS, - 0, // TODO - cfg.Global.ServerNotices.LocalPart, - ) + pass, err := getPassword(*password, *pwdFile, *pwdStdin, os.Stdin) if err != nil { - logrus.WithError(err).Fatalln("Failed to connect to the database") + logrus.Fatalln(err) } - accType := api.AccountTypeUser - if *isAdmin { - accType = api.AccountTypeAdmin - } - - available, err := accountDB.CheckAccountAvailability(context.Background(), *username) - if err != nil { - logrus.Fatalln("Unable check username existence.") - } if *resetPassword { + b := base.NewBaseDendrite(cfg, "") + defer b.Close() // nolint: errcheck + accountDB, err := storage.NewUserAPIDatabase( + b, + &cfg.UserAPI.AccountDatabase, + cfg.Global.ServerName, + cfg.UserAPI.BCryptCost, + cfg.UserAPI.OpenIDTokenLifetimeMS, + 0, // TODO + cfg.Global.ServerNotices.LocalPart, + ) + if err != nil { + logrus.WithError(err).Fatalln("Failed to connect to the database") + } + + available, err := accountDB.CheckAccountAvailability(context.Background(), *username) + if err != nil { + logrus.Fatalln("Unable check username existence.") + } if available { logrus.Fatalln("Username could not be found.") } @@ -142,16 +132,95 @@ func main() { logrus.Infof("Updated password for user %s and invalidated all logins\n", *username) return } - if !available { - logrus.Fatalln("Username is already in use.") - } - _, err = accountDB.CreateAccount(context.Background(), *username, pass, "", accType) + accessToken, err := sharedSecretRegister(cfg.ClientAPI.RegistrationSharedSecret, *serverURL, *username, pass, *isAdmin) if err != nil { logrus.Fatalln("Failed to create the account:", err.Error()) } - logrus.Infoln("Created account", *username) + logrus.Infof("Created account: %s (AccessToken: %s)", *username, accessToken) +} + +type sharedSecretRegistrationRequest struct { + User string `json:"username"` + Password string `json:"password"` + Nonce string `json:"nonce"` + MacStr string `json:"mac"` + Admin bool `json:"admin"` +} + +func sharedSecretRegister(sharedSecret, serverURL, localpart, password string, admin bool) (accesToken string, err error) { + registerURL := fmt.Sprintf("%s/_synapse/admin/v1/register", serverURL) + cl := http.Client{ + Timeout: time.Second * 10, + Transport: http.DefaultTransport, + } + nonceReq, err := http.NewRequest(http.MethodGet, registerURL, nil) + if err != nil { + return "", fmt.Errorf("unable to create http request: %w", err) + } + + nonceResp, err := cl.Do(nonceReq) + if err != nil { + return "", fmt.Errorf("unable to get nonce: %w", err) + } + body, err := ioutil.ReadAll(nonceResp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response body: %w", err) + } + defer nonceResp.Body.Close() + + nonce := gjson.GetBytes(body, "nonce").Str + + adminStr := "notadmin" + if admin { + adminStr = "admin" + } + reg := sharedSecretRegistrationRequest{ + User: localpart, + Password: password, + Nonce: nonce, + Admin: admin, + } + macStr, err := getRegisterMac(sharedSecret, nonce, localpart, password, adminStr) + if err != nil { + return "", err + } + reg.MacStr = macStr + + js, err := json.Marshal(reg) + if err != nil { + return "", fmt.Errorf("unable to marshal json: %w", err) + } + registerReq, err := http.NewRequest(http.MethodPost, registerURL, bytes.NewBuffer(js)) + if err != nil { + return "", fmt.Errorf("unable to create http request: %w", err) + + } + regResp, err := cl.Do(registerReq) + if err != nil { + return "", fmt.Errorf("unable to create account: %w", err) + } + defer regResp.Body.Close() + if regResp.StatusCode < 200 || regResp.StatusCode >= 300 { + body, _ = ioutil.ReadAll(regResp.Body) + return "", fmt.Errorf(gjson.GetBytes(body, "error").Str) + } + r, _ := ioutil.ReadAll(regResp.Body) + + return gjson.GetBytes(r, "access_token").Str, nil +} + +func getRegisterMac(sharedSecret, nonce, localpart, password, adminStr string) (string, error) { + joined := strings.Join([]string{nonce, localpart, password, adminStr}, "\x00") + mac := hmac.New(sha1.New, []byte(sharedSecret)) + _, err := mac.Write([]byte(joined)) + if err != nil { + return "", fmt.Errorf("unable to construct mac: %w", err) + } + regMac := mac.Sum(nil) + + return hex.EncodeToString(regMac), nil } func getPassword(password, pwdFile string, pwdStdin bool, r io.Reader) (string, error) {