From 96f5b66eae7cc695f9ff841d1cdc925b03c5840b Mon Sep 17 00:00:00 2001 From: Anant Prakash Date: Sun, 11 Mar 2018 22:37:49 +0530 Subject: [PATCH] Auto-generate username if none provided during registration Signed-off-by: Anant Prakash --- .../auth/storage/accounts/accounts_table.go | 32 +++++++++ .../auth/storage/accounts/storage.go | 7 ++ .../dendrite/clientapi/routing/register.go | 65 ++++++++++++++++++- 3 files changed, 103 insertions(+), 1 deletion(-) diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/accounts_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/accounts_table.go index a29d616e9..c20a905c9 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/accounts_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/accounts_table.go @@ -49,12 +49,16 @@ const selectAccountByLocalpartSQL = "" + const selectPasswordHashSQL = "" + "SELECT password_hash FROM account_accounts WHERE localpart = $1" +const selectAllLocalpartSQL = "" + + "SELECT localpart FROM account_accounts" + // TODO: Update password type accountsStatements struct { insertAccountStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt + selectAllLocalpartStmt *sql.Stmt serverName gomatrixserverlib.ServerName } @@ -72,6 +76,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil { return } + if s.selectAllLocalpartStmt, err = db.Prepare(selectAllLocalpartSQL); err != nil { + return + } s.serverName = server return } @@ -123,6 +130,31 @@ func (s *accountsStatements) selectAccountByLocalpart( return &acc, err } +func (s *accountsStatements) selectAllLocalparts(ctx context.Context) ([]string, error) { + var localparts []string + + stmt := s.selectAllLocalpartStmt + rows, err := stmt.QueryContext(ctx) + + if err != nil { + return nil, err + } + + for rows.Next() { + var t string + err := rows.Scan(&t) + + if err != nil { + return nil, err + } + + localparts = append(localparts, t) + } + + // Ensure that rows is closed in case of error. + return localparts, rows.Close() +} + func makeUserID(localpart string, server gomatrixserverlib.ServerName) string { return fmt.Sprintf("@%s:%s", localpart, string(server)) } diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go index 571482739..bfa1f6148 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go @@ -358,3 +358,10 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin } return false, err } + +// GetAllLocalparts returns an slice header containing all localparts in database. +// If no records exist it returns an empty slice. +// Returns an error if something goes wrong. +func (d *Database) GetAllLocalparts(ctx context.Context) ([]string, error) { + return d.accounts.selectAllLocalparts(ctx) +} diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/register.go b/src/github.com/matrix-org/dendrite/clientapi/routing/register.go index 77e875ec1..f37bd1622 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/register.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/register.go @@ -27,6 +27,7 @@ import ( "net/url" "regexp" "sort" + "strconv" "strings" "time" @@ -237,6 +238,38 @@ func validateRecaptcha( return nil } +// generates username for registration if none provided. +// Sequentially checks for available username and returns the lowest available one. +func generateUsername( + ctx context.Context, accountDB *accounts.Database, cfg *config.Dendrite, +) (string, error) { + numericLps := make(map[string]bool) + + // re matches localparts with only digits + re := regexp.MustCompile(`\A\d+\z`) + localparts, err := accountDB.GetAllLocalparts(ctx) + + if err != nil { + return "", err + } + + for _, lp := range localparts { + if re.MatchString(lp) { + // Length == 2 implies there is a matched localpart + numericLps[lp] = true + } + } + + for i := 1; ; i++ { + lp := strconv.Itoa(i) + if _, ok := numericLps[lp]; !ok && !UsernameMatchesExclusiveNamespace( + cfg, lp, + ) { + return lp, nil + } + } +} + // UsernameIsWithinApplicationServiceNamespace checks to see if a username falls // within any of the namespaces of a given Application Service. If no // Application Service is given, it will check to see if it matches any @@ -269,6 +302,17 @@ func UsernameIsWithinApplicationServiceNamespace( return false } +// UsernameMatchesExclusiveNamespace will check if a given username matches +// an exclusive namespace. localpart should not match Application Service namespace. +func UsernameMatchesExclusiveNamespace( + cfg *config.Dendrite, + username string, +) bool { + // Check namespaces and see if there's a match + matchCount := countMatchingNamespaces(cfg, username) + return matchCount > 0 +} + // UsernameMatchesMultipleExclusiveNamespaces will check if a given username matches // more than one exclusive namespace. More than one is not allowed func UsernameMatchesMultipleExclusiveNamespaces( @@ -276,6 +320,15 @@ func UsernameMatchesMultipleExclusiveNamespaces( username string, ) bool { // Check namespaces and see if more than one match + matchCount := countMatchingNamespaces(cfg, username) + return matchCount > 1 +} + +func countMatchingNamespaces( + cfg *config.Dendrite, + username string, +) int { + // Check namespaces and count matches matchCount := 0 for _, appservice := range cfg.Derived.ApplicationServices { for _, namespaceSlice := range appservice.NamespaceMap { @@ -287,7 +340,7 @@ func UsernameMatchesMultipleExclusiveNamespaces( } } } - return matchCount > 1 + return matchCount } // validateApplicationService checks if a provided application service token @@ -361,6 +414,16 @@ func Register( sessionID = util.RandomString(sessionIDLength) } + // auto generate a numeric username if r.Username empty + if r.Username == "" { + var err error + r.Username, err = generateUsername(req.Context(), accountDB, cfg) + + if err != nil { + return jsonerror.InternalServerError() + } + } + // If no auth type is specified by the client, send back the list of available flows if r.Auth.Type == "" { return util.JSONResponse{