mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-12 09:23:09 -06:00
Auto-generate username if none provided during registration
Signed-off-by: Anant Prakash <anantprakashjsr@gmail.com>
This commit is contained in:
parent
6b55972183
commit
96f5b66eae
|
|
@ -49,12 +49,16 @@ const selectAccountByLocalpartSQL = "" +
|
|||
const selectPasswordHashSQL = "" +
|
||||
"SELECT password_hash FROM account_accounts WHERE localpart = $1"
|
||||
|
||||
const selectAllLocalpartSQL = "" +
|
||||
"SELECT localpart FROM account_accounts"
|
||||
|
||||
// TODO: Update password
|
||||
|
||||
type accountsStatements struct {
|
||||
insertAccountStmt *sql.Stmt
|
||||
selectAccountByLocalpartStmt *sql.Stmt
|
||||
selectPasswordHashStmt *sql.Stmt
|
||||
selectAllLocalpartStmt *sql.Stmt
|
||||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
|
|
@ -72,6 +76,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
|||
if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectAllLocalpartStmt, err = db.Prepare(selectAllLocalpartSQL); err != nil {
|
||||
return
|
||||
}
|
||||
s.serverName = server
|
||||
return
|
||||
}
|
||||
|
|
@ -123,6 +130,31 @@ func (s *accountsStatements) selectAccountByLocalpart(
|
|||
return &acc, err
|
||||
}
|
||||
|
||||
func (s *accountsStatements) selectAllLocalparts(ctx context.Context) ([]string, error) {
|
||||
var localparts []string
|
||||
|
||||
stmt := s.selectAllLocalpartStmt
|
||||
rows, err := stmt.QueryContext(ctx)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var t string
|
||||
err := rows.Scan(&t)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
localparts = append(localparts, t)
|
||||
}
|
||||
|
||||
// Ensure that rows is closed in case of error.
|
||||
return localparts, rows.Close()
|
||||
}
|
||||
|
||||
func makeUserID(localpart string, server gomatrixserverlib.ServerName) string {
|
||||
return fmt.Sprintf("@%s:%s", localpart, string(server))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -358,3 +358,10 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin
|
|||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
// GetAllLocalparts returns an slice header containing all localparts in database.
|
||||
// If no records exist it returns an empty slice.
|
||||
// Returns an error if something goes wrong.
|
||||
func (d *Database) GetAllLocalparts(ctx context.Context) ([]string, error) {
|
||||
return d.accounts.selectAllLocalparts(ctx)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ import (
|
|||
"net/url"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
|
@ -237,6 +238,38 @@ func validateRecaptcha(
|
|||
return nil
|
||||
}
|
||||
|
||||
// generates username for registration if none provided.
|
||||
// Sequentially checks for available username and returns the lowest available one.
|
||||
func generateUsername(
|
||||
ctx context.Context, accountDB *accounts.Database, cfg *config.Dendrite,
|
||||
) (string, error) {
|
||||
numericLps := make(map[string]bool)
|
||||
|
||||
// re matches localparts with only digits
|
||||
re := regexp.MustCompile(`\A\d+\z`)
|
||||
localparts, err := accountDB.GetAllLocalparts(ctx)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for _, lp := range localparts {
|
||||
if re.MatchString(lp) {
|
||||
// Length == 2 implies there is a matched localpart
|
||||
numericLps[lp] = true
|
||||
}
|
||||
}
|
||||
|
||||
for i := 1; ; i++ {
|
||||
lp := strconv.Itoa(i)
|
||||
if _, ok := numericLps[lp]; !ok && !UsernameMatchesExclusiveNamespace(
|
||||
cfg, lp,
|
||||
) {
|
||||
return lp, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UsernameIsWithinApplicationServiceNamespace checks to see if a username falls
|
||||
// within any of the namespaces of a given Application Service. If no
|
||||
// Application Service is given, it will check to see if it matches any
|
||||
|
|
@ -269,6 +302,17 @@ func UsernameIsWithinApplicationServiceNamespace(
|
|||
return false
|
||||
}
|
||||
|
||||
// UsernameMatchesExclusiveNamespace will check if a given username matches
|
||||
// an exclusive namespace. localpart should not match Application Service namespace.
|
||||
func UsernameMatchesExclusiveNamespace(
|
||||
cfg *config.Dendrite,
|
||||
username string,
|
||||
) bool {
|
||||
// Check namespaces and see if there's a match
|
||||
matchCount := countMatchingNamespaces(cfg, username)
|
||||
return matchCount > 0
|
||||
}
|
||||
|
||||
// UsernameMatchesMultipleExclusiveNamespaces will check if a given username matches
|
||||
// more than one exclusive namespace. More than one is not allowed
|
||||
func UsernameMatchesMultipleExclusiveNamespaces(
|
||||
|
|
@ -276,6 +320,15 @@ func UsernameMatchesMultipleExclusiveNamespaces(
|
|||
username string,
|
||||
) bool {
|
||||
// Check namespaces and see if more than one match
|
||||
matchCount := countMatchingNamespaces(cfg, username)
|
||||
return matchCount > 1
|
||||
}
|
||||
|
||||
func countMatchingNamespaces(
|
||||
cfg *config.Dendrite,
|
||||
username string,
|
||||
) int {
|
||||
// Check namespaces and count matches
|
||||
matchCount := 0
|
||||
for _, appservice := range cfg.Derived.ApplicationServices {
|
||||
for _, namespaceSlice := range appservice.NamespaceMap {
|
||||
|
|
@ -287,7 +340,7 @@ func UsernameMatchesMultipleExclusiveNamespaces(
|
|||
}
|
||||
}
|
||||
}
|
||||
return matchCount > 1
|
||||
return matchCount
|
||||
}
|
||||
|
||||
// validateApplicationService checks if a provided application service token
|
||||
|
|
@ -361,6 +414,16 @@ func Register(
|
|||
sessionID = util.RandomString(sessionIDLength)
|
||||
}
|
||||
|
||||
// auto generate a numeric username if r.Username empty
|
||||
if r.Username == "" {
|
||||
var err error
|
||||
r.Username, err = generateUsername(req.Context(), accountDB, cfg)
|
||||
|
||||
if err != nil {
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
}
|
||||
|
||||
// If no auth type is specified by the client, send back the list of available flows
|
||||
if r.Auth.Type == "" {
|
||||
return util.JSONResponse{
|
||||
|
|
|
|||
Loading…
Reference in a new issue