Auto-generate username if none provided during registration (#470)

* Auto-generate username if none provided during registration

* Remove rogue backtick

* Add appropriate log msg
This commit is contained in:
Anant Prakash 2018-05-31 20:06:15 +05:30 committed by Andrew Morgan
parent 05be8d1c99
commit 1f570d0e92
3 changed files with 45 additions and 4 deletions

View file

@ -38,6 +38,8 @@ CREATE TABLE IF NOT EXISTS account_accounts (
-- TODO: -- TODO:
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
); );
-- Create sequence for autogenerated numeric usernames
CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
` `
const insertAccountSQL = "" + const insertAccountSQL = "" +
@ -49,12 +51,16 @@ const selectAccountByLocalpartSQL = "" +
const selectPasswordHashSQL = "" + const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1" "SELECT password_hash FROM account_accounts WHERE localpart = $1"
const selectNewNumericLocalpartSQL = "" +
"SELECT nextval('numeric_username_seq')"
// TODO: Update password // TODO: Update password
type accountsStatements struct { type accountsStatements struct {
insertAccountStmt *sql.Stmt insertAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
@ -72,6 +78,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil { if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil {
return return
} }
if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil {
return
}
s.serverName = server s.serverName = server
return return
} }
@ -121,3 +130,10 @@ func (s *accountsStatements) selectAccountByLocalpart(
} }
return return
} }
func (s *accountsStatements) selectNewNumericLocalpart(
ctx context.Context,
) (id int64, err error) {
err = s.selectNewNumericLocalpartStmt.QueryRowContext(ctx).Scan(&id)
return
}

View file

@ -267,6 +267,13 @@ func (d *Database) GetAccountDataByType(
) )
} }
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx)
}
func hashPassword(plaintext string) (hash string, err error) { func hashPassword(plaintext string) (hash string, err error) {
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost) hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
return string(hashBytes), err return string(hashBytes), err

View file

@ -27,6 +27,7 @@ import (
"net/url" "net/url"
"regexp" "regexp"
"sort" "sort"
"strconv"
"strings" "strings"
"time" "time"
@ -403,6 +404,23 @@ func Register(
sessionID = util.RandomString(sessionIDLength) sessionID = util.RandomString(sessionIDLength)
} }
// Don't allow numeric usernames less than MAX_INT64.
if _, err := strconv.ParseInt(r.Username, 10, 64); err == nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"),
}
}
// Auto generate a numeric username if r.Username is empty
if r.Username == "" {
id, err := accountDB.GetNewNumericLocalpart(req.Context())
if err != nil {
return httputil.LogThenError(req, err)
}
r.Username = strconv.FormatInt(id, 10)
}
// If no auth type is specified by the client, send back the list of available flows // If no auth type is specified by the client, send back the list of available flows
if r.Auth.Type == "" { if r.Auth.Type == "" {
return util.JSONResponse{ return util.JSONResponse{