Auto-generate username if none provided during registration

Signed-off-by: Anant Prakash <anantprakashjsr@gmail.com>
This commit is contained in:
Anant Prakash 2018-03-11 22:37:49 +05:30
parent 6b55972183
commit 96f5b66eae
No known key found for this signature in database
GPG key ID: C5D399F626523045
3 changed files with 103 additions and 1 deletions

View file

@ -49,12 +49,16 @@ const selectAccountByLocalpartSQL = "" +
const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1"
const selectAllLocalpartSQL = "" +
"SELECT localpart FROM account_accounts"
// TODO: Update password
type accountsStatements struct {
insertAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectAllLocalpartStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
@ -72,6 +76,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil {
return
}
if s.selectAllLocalpartStmt, err = db.Prepare(selectAllLocalpartSQL); err != nil {
return
}
s.serverName = server
return
}
@ -123,6 +130,31 @@ func (s *accountsStatements) selectAccountByLocalpart(
return &acc, err
}
func (s *accountsStatements) selectAllLocalparts(ctx context.Context) ([]string, error) {
var localparts []string
stmt := s.selectAllLocalpartStmt
rows, err := stmt.QueryContext(ctx)
if err != nil {
return nil, err
}
for rows.Next() {
var t string
err := rows.Scan(&t)
if err != nil {
return nil, err
}
localparts = append(localparts, t)
}
// Ensure that rows is closed in case of error.
return localparts, rows.Close()
}
func makeUserID(localpart string, server gomatrixserverlib.ServerName) string {
return fmt.Sprintf("@%s:%s", localpart, string(server))
}

View file

@ -358,3 +358,10 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin
}
return false, err
}
// GetAllLocalparts returns an slice header containing all localparts in database.
// If no records exist it returns an empty slice.
// Returns an error if something goes wrong.
func (d *Database) GetAllLocalparts(ctx context.Context) ([]string, error) {
return d.accounts.selectAllLocalparts(ctx)
}

View file

@ -27,6 +27,7 @@ import (
"net/url"
"regexp"
"sort"
"strconv"
"strings"
"time"
@ -237,6 +238,38 @@ func validateRecaptcha(
return nil
}
// generates username for registration if none provided.
// Sequentially checks for available username and returns the lowest available one.
func generateUsername(
ctx context.Context, accountDB *accounts.Database, cfg *config.Dendrite,
) (string, error) {
numericLps := make(map[string]bool)
// re matches localparts with only digits
re := regexp.MustCompile(`\A\d+\z`)
localparts, err := accountDB.GetAllLocalparts(ctx)
if err != nil {
return "", err
}
for _, lp := range localparts {
if re.MatchString(lp) {
// Length == 2 implies there is a matched localpart
numericLps[lp] = true
}
}
for i := 1; ; i++ {
lp := strconv.Itoa(i)
if _, ok := numericLps[lp]; !ok && !UsernameMatchesExclusiveNamespace(
cfg, lp,
) {
return lp, nil
}
}
}
// UsernameIsWithinApplicationServiceNamespace checks to see if a username falls
// within any of the namespaces of a given Application Service. If no
// Application Service is given, it will check to see if it matches any
@ -269,6 +302,17 @@ func UsernameIsWithinApplicationServiceNamespace(
return false
}
// UsernameMatchesExclusiveNamespace will check if a given username matches
// an exclusive namespace. localpart should not match Application Service namespace.
func UsernameMatchesExclusiveNamespace(
cfg *config.Dendrite,
username string,
) bool {
// Check namespaces and see if there's a match
matchCount := countMatchingNamespaces(cfg, username)
return matchCount > 0
}
// UsernameMatchesMultipleExclusiveNamespaces will check if a given username matches
// more than one exclusive namespace. More than one is not allowed
func UsernameMatchesMultipleExclusiveNamespaces(
@ -276,6 +320,15 @@ func UsernameMatchesMultipleExclusiveNamespaces(
username string,
) bool {
// Check namespaces and see if more than one match
matchCount := countMatchingNamespaces(cfg, username)
return matchCount > 1
}
func countMatchingNamespaces(
cfg *config.Dendrite,
username string,
) int {
// Check namespaces and count matches
matchCount := 0
for _, appservice := range cfg.Derived.ApplicationServices {
for _, namespaceSlice := range appservice.NamespaceMap {
@ -287,7 +340,7 @@ func UsernameMatchesMultipleExclusiveNamespaces(
}
}
}
return matchCount > 1
return matchCount
}
// validateApplicationService checks if a provided application service token
@ -361,6 +414,16 @@ func Register(
sessionID = util.RandomString(sessionIDLength)
}
// auto generate a numeric username if r.Username empty
if r.Username == "" {
var err error
r.Username, err = generateUsername(req.Context(), accountDB, cfg)
if err != nil {
return jsonerror.InternalServerError()
}
}
// If no auth type is specified by the client, send back the list of available flows
if r.Auth.Type == "" {
return util.JSONResponse{