diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index 060b82f97..a1e254f8d 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -18,13 +18,18 @@ import ( "context" "flag" "fmt" + "io" + "io/ioutil" "os" + "strings" + "syscall" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" + "golang.org/x/term" ) const usage = `Usage: %s @@ -33,7 +38,15 @@ Creates a new user account on the homeserver. Example: - ./create-account --config dendrite.yaml --username alice --password foobarbaz + # provide password by parameter + %s --config dendrite.yaml -username alice -password foobarbaz + # use password from file + %s --config dendrite.yaml -username alice -passwordfile my.pass + # ask user to provide password + %s --config dendrite.yaml -username alice -ask-pass + # read password from stdin + %s --config dendrite.yaml -username alice -passwordstdin < my.pass + cat my.pass | %s --config dendrite.yaml -username alice -passwordstdin Arguments: @@ -42,11 +55,15 @@ Arguments: var ( username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)") + pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") + pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") + askPass = flag.Bool("ask-pass", false, "Ask for the password to use") ) func main() { + name := os.Args[0] flag.Usage = func() { - fmt.Fprintf(os.Stderr, usage, os.Args[0]) + _, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name) flag.PrintDefaults() } cfg := setup.ParseFlags(true) @@ -56,6 +73,8 @@ func main() { os.Exit(1) } + pass := getPassword(password, pwdFile, pwdStdin, askPass, os.Stdin) + accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{ ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString, }, cfg.Global.ServerName, bcrypt.DefaultCost, cfg.UserAPI.OpenIDTokenLifetimeMS) @@ -63,10 +82,61 @@ func main() { logrus.Fatalln("Failed to connect to the database:", err.Error()) } - _, err = accountDB.CreateAccount(context.Background(), *username, *password, "") + _, err = accountDB.CreateAccount(context.Background(), *username, pass, "") if err != nil { logrus.Fatalln("Failed to create the account:", err.Error()) } logrus.Infoln("Created account", *username) } + +func getPassword(password, pwdFile *string, pwdStdin, askPass *bool, r io.Reader) string { + // no password option set, use empty password + if password == nil && pwdFile == nil && pwdStdin == nil && askPass == nil { + return "" + } + // password defined as parameter + if password != nil && *password != "" { + return *password + } + + // read password from file + if pwdFile != nil && *pwdFile != "" { + pw, err := ioutil.ReadFile(*pwdFile) + if err != nil { + logrus.Fatalln("Unable to read password from file:", err) + } + return strings.TrimSpace(string(pw)) + } + + // read password from stdin + if pwdStdin != nil && *pwdStdin { + data, err := ioutil.ReadAll(r) + if err != nil { + logrus.Fatalln("Unable to read password from stdin:", err) + } + return strings.TrimSpace(string(data)) + } + + // ask the user to provide the password + if *askPass { + fmt.Print("Enter Password: ") + bytePassword, err := term.ReadPassword(syscall.Stdin) + if err != nil { + logrus.Fatalln("Unable to read password:", err) + } + fmt.Println() + fmt.Print("Confirm Password: ") + bytePassword2, err := term.ReadPassword(syscall.Stdin) + if err != nil { + logrus.Fatalln("Unable to read password:", err) + } + fmt.Println() + if strings.TrimSpace(string(bytePassword)) != strings.TrimSpace(string(bytePassword2)) { + logrus.Fatalln("Entered passwords don't match") + } + return strings.TrimSpace(string(bytePassword)) + } + + return "" +} diff --git a/cmd/create-account/main_test.go b/cmd/create-account/main_test.go new file mode 100644 index 000000000..d06eafe46 --- /dev/null +++ b/cmd/create-account/main_test.go @@ -0,0 +1,62 @@ +package main + +import ( + "bytes" + "io" + "testing" +) + +func Test_getPassword(t *testing.T) { + type args struct { + password *string + pwdFile *string + pwdStdin *bool + askPass *bool + reader io.Reader + } + + pass := "mySecretPass" + passwordFile := "testdata/my.pass" + passwordStdin := true + reader := &bytes.Buffer{} + _, err := reader.WriteString(pass) + if err != nil { + t.Errorf("unable to write to buffer: %+v", err) + } + tests := []struct { + name string + args args + want string + }{ + { + name: "no password defined", + args: args{}, + want: "", + }, + { + name: "password defined", + args: args{password: &pass}, + want: pass, + }, + { + name: "pwdFile defined", + args: args{pwdFile: &passwordFile}, + want: pass, + }, + { + name: "read pass from stdin defined", + args: args{ + pwdStdin: &passwordStdin, + reader: reader, + }, + want: pass, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getPassword(tt.args.password, tt.args.pwdFile, tt.args.pwdStdin, tt.args.askPass, tt.args.reader); got != tt.want { + t.Errorf("getPassword() = '%v', want '%v'", got, tt.want) + } + }) + } +} diff --git a/cmd/create-account/testdata/my.pass b/cmd/create-account/testdata/my.pass new file mode 100644 index 000000000..c1f7156f0 --- /dev/null +++ b/cmd/create-account/testdata/my.pass @@ -0,0 +1 @@ +mySecretPass \ No newline at end of file