Auto-generate username if none provided during registration (#470)
* Auto-generate username if none provided during registration * Remove rogue backtick * Add appropriate log msg
This commit is contained in:
parent
05be8d1c99
commit
1f570d0e92
|
@ -38,6 +38,8 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
||||||
-- TODO:
|
-- TODO:
|
||||||
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
|
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
|
||||||
);
|
);
|
||||||
|
-- Create sequence for autogenerated numeric usernames
|
||||||
|
CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertAccountSQL = "" +
|
const insertAccountSQL = "" +
|
||||||
|
@ -49,13 +51,17 @@ const selectAccountByLocalpartSQL = "" +
|
||||||
const selectPasswordHashSQL = "" +
|
const selectPasswordHashSQL = "" +
|
||||||
"SELECT password_hash FROM account_accounts WHERE localpart = $1"
|
"SELECT password_hash FROM account_accounts WHERE localpart = $1"
|
||||||
|
|
||||||
|
const selectNewNumericLocalpartSQL = "" +
|
||||||
|
"SELECT nextval('numeric_username_seq')"
|
||||||
|
|
||||||
// TODO: Update password
|
// TODO: Update password
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
insertAccountStmt *sql.Stmt
|
insertAccountStmt *sql.Stmt
|
||||||
selectAccountByLocalpartStmt *sql.Stmt
|
selectAccountByLocalpartStmt *sql.Stmt
|
||||||
selectPasswordHashStmt *sql.Stmt
|
selectPasswordHashStmt *sql.Stmt
|
||||||
serverName gomatrixserverlib.ServerName
|
selectNewNumericLocalpartStmt *sql.Stmt
|
||||||
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
||||||
|
@ -72,6 +78,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
||||||
if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil {
|
if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
s.serverName = server
|
s.serverName = server
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -121,3 +130,10 @@ func (s *accountsStatements) selectAccountByLocalpart(
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *accountsStatements) selectNewNumericLocalpart(
|
||||||
|
ctx context.Context,
|
||||||
|
) (id int64, err error) {
|
||||||
|
err = s.selectNewNumericLocalpartStmt.QueryRowContext(ctx).Scan(&id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -267,6 +267,13 @@ func (d *Database) GetAccountDataByType(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
|
||||||
|
func (d *Database) GetNewNumericLocalpart(
|
||||||
|
ctx context.Context,
|
||||||
|
) (int64, error) {
|
||||||
|
return d.accounts.selectNewNumericLocalpart(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
func hashPassword(plaintext string) (hash string, err error) {
|
func hashPassword(plaintext string) (hash string, err error) {
|
||||||
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
|
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
|
||||||
return string(hashBytes), err
|
return string(hashBytes), err
|
||||||
|
|
|
@ -27,6 +27,7 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"regexp"
|
"regexp"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -403,6 +404,23 @@ func Register(
|
||||||
sessionID = util.RandomString(sessionIDLength)
|
sessionID = util.RandomString(sessionIDLength)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Don't allow numeric usernames less than MAX_INT64.
|
||||||
|
if _, err := strconv.ParseInt(r.Username, 10, 64); err == nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Auto generate a numeric username if r.Username is empty
|
||||||
|
if r.Username == "" {
|
||||||
|
id, err := accountDB.GetNewNumericLocalpart(req.Context())
|
||||||
|
if err != nil {
|
||||||
|
return httputil.LogThenError(req, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Username = strconv.FormatInt(id, 10)
|
||||||
|
}
|
||||||
|
|
||||||
// If no auth type is specified by the client, send back the list of available flows
|
// If no auth type is specified by the client, send back the list of available flows
|
||||||
if r.Auth.Type == "" {
|
if r.Auth.Type == "" {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
|
Loading…
Reference in a new issue