From 2c339a6bfd90c148ab9524c66ee7e7301479dc72 Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Mon, 5 Jun 2023 23:10:27 +0530 Subject: [PATCH] refactoring, implement db layer --- clientapi/routing/admin.go | 48 ++++++++++++++----- roomserver/internal/perform/perform_admin.go | 10 +++- roomserver/storage/interface.go | 1 + .../postgres/registration_tokens_table.go | 27 +++++++++++ roomserver/storage/postgres/storage.go | 3 ++ roomserver/storage/shared/storage.go | 23 +++++---- roomserver/storage/tables/interface.go | 4 ++ 7 files changed, 92 insertions(+), 24 deletions(-) create mode 100644 roomserver/storage/postgres/registration_tokens_table.go diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 60d6ec7dd..558e011e3 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -5,8 +5,10 @@ import ( "encoding/json" "errors" "fmt" + "math/rand" "net/http" "regexp" + "strings" "time" "github.com/gorilla/mux" @@ -25,6 +27,17 @@ import ( "github.com/matrix-org/dendrite/userapi/api" ) +func generateRandomToken(length int) string { + allowedChars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_" + rand.Seed(time.Now().UnixNano()) + var sb strings.Builder + for i := 0; i < length; i++ { + randomIndex := rand.Intn(len(allowedChars)) + sb.WriteByte(allowedChars[randomIndex]) + } + return sb.String() +} + func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { if !cfg.RegistrationRequiresToken { return util.MatrixErrorResponse( @@ -47,13 +60,31 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, r "Failed to decode request body:", ) } + token := request.Token - if len(token) == 0 || len(token) > 64 { + usesAllowed := request.UsesAllowed + expiryTime := request.ExpiryTime + length := request.Length + + if len(token) == 0 { + // Token not present in request body. Hence, generate a random token. + if !(length > 0 && length <= 64) { + return util.MatrixErrorResponse( + http.StatusBadRequest, + string(spec.ErrorInvalidParam), + "length must be greater than zero and not greater than 64") + } + token = generateRandomToken(int(length)) + } + + if len(token) > 64 { + //Token present in request body, but is too long. return util.MatrixErrorResponse( http.StatusBadRequest, string(spec.ErrorInvalidParam), - "token must not be empty and must not be longer than 64") + "token must not be longer than 64") } + isTokenValid, _ := regexp.MatchString("^[[:ascii:][:digit:]_]*$", token) if !isTokenValid { return util.MatrixErrorResponse( @@ -61,16 +92,8 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, r string(spec.ErrorInvalidParam), "token must consist only of characters matched by the regex [A-Za-z0-9-_]") } - length := request.Length - if !(length > 0 && length <= 64) { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorInvalidParam), - "length must be greater than zero and not greater than 64") - } - // TODO: Generate Random Token - // token = GenerateRandomToken(length) - usesAllowed := request.UsesAllowed + // At this point, we have a valid token, either through request body or through random generation. + if usesAllowed < 0 { return util.MatrixErrorResponse( http.StatusBadRequest, @@ -78,7 +101,6 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, r "uses_allowed must be a non-negative integer or null") } - expiryTime := request.ExpiryTime if expiryTime != 0 && expiryTime < time.Now().UnixNano()/int64(time.Millisecond) { return util.MatrixErrorResponse( http.StatusBadRequest, diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index f78886035..292d91f23 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -46,8 +46,14 @@ func (r *Admin) PerformAdminCreateRegistrationToken( ctx context.Context, token string, usesAllowed, pending, completed int32, expiryTime int64) (bool, error) { - //TODO: Implement logic to save token in DB. - //Return false, if token already exists, else true. + exists, err := r.DB.RegistrationTokenExists(ctx, token) + if err != nil { + return false, err + } + if exists { + fmt.Println(fmt.Sprintf("token: %s already exists", token)) + return false, fmt.Errorf("token: %s already exists", token) + } return true, nil } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 7d22df008..4cf8f3b3a 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -27,6 +27,7 @@ import ( ) type Database interface { + RegistrationTokenExists(ctx context.Context, token string) (bool, error) // Do we support processing input events for more than one room at a time? SupportsConcurrentRoomInputs() bool // RoomInfo returns room information for the given room ID, or nil if there is no room. diff --git a/roomserver/storage/postgres/registration_tokens_table.go b/roomserver/storage/postgres/registration_tokens_table.go new file mode 100644 index 000000000..1f69f42d8 --- /dev/null +++ b/roomserver/storage/postgres/registration_tokens_table.go @@ -0,0 +1,27 @@ +package postgres + +import ( + "context" + "database/sql" + "fmt" +) + +const registrationTokensSchema = ` +CREATE TABLE IF NOT EXISTS roomserver_registration_tokens ( + token TEXT PRIMARY KEY, + pending BIGINT, + completed BIGINT, + uses_allowed BIGINT, + expiry_time BIGINT +); +` + +func CreateRegistrationTokensTable(db *sql.DB) error { + _, err := db.Exec(registrationTokensSchema) + return err +} + +func RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) { + fmt.Println("here!!") + return true, nil +} diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 19cde5410..5836ab153 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -92,6 +92,9 @@ func executeMigration(ctx context.Context, db *sql.DB) error { } func (d *Database) create(db *sql.DB) error { + if err := CreateRegistrationTokensTable(db); err != nil { + return err + } if err := CreateEventStateKeysTable(db); err != nil { return err } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index cefa58a3d..3e316b882 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -46,15 +46,20 @@ type Database struct { // EventDatabase contains all tables needed to work with events type EventDatabase struct { - DB *sql.DB - Cache caching.RoomServerCaches - Writer sqlutil.Writer - EventsTable tables.Events - EventJSONTable tables.EventJSON - EventTypesTable tables.EventTypes - EventStateKeysTable tables.EventStateKeys - PrevEventsTable tables.PreviousEvents - RedactionsTable tables.Redactions + DB *sql.DB + Cache caching.RoomServerCaches + Writer sqlutil.Writer + EventsTable tables.Events + EventJSONTable tables.EventJSON + EventTypesTable tables.EventTypes + EventStateKeysTable tables.EventStateKeys + PrevEventsTable tables.PreviousEvents + RedactionsTable tables.Redactions + RegistrationTokensTable tables.RegistrationTokens +} + +func (d *Database) RegistrationTokenExists(ctx context.Context, token string) (bool, error) { + return d.RegistrationTokensTable.RegistrationTokenExists(ctx, nil, token) } func (d *Database) SupportsConcurrentRoomInputs() bool { diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 333483b32..471b341eb 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -19,6 +19,10 @@ type EventJSONPair struct { EventJSON []byte } +type RegistrationTokens interface { + RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) +} + type EventJSON interface { // Insert the event JSON. On conflict, replace the event JSON with the new value (for redactions). InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error