Move DB Layer to UserAPI

This commit is contained in:
santhoshivan23 2023-06-06 21:52:05 +05:30
parent 6cd6af150b
commit fe2464fd4b
16 changed files with 133 additions and 69 deletions

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
) )
func generateRandomToken(length int) string { func generateRandomToken(length int) string {
@ -38,7 +39,7 @@ func generateRandomToken(length int) string {
return sb.String() return sb.String()
} }
func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
if !cfg.RegistrationRequiresToken { if !cfg.RegistrationRequiresToken {
return util.MatrixErrorResponse( return util.MatrixErrorResponse(
http.StatusForbidden, http.StatusForbidden,
@ -67,7 +68,11 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, r
length := request.Length length := request.Length
if len(token) == 0 { if len(token) == 0 {
// Token not present in request body. Hence, generate a random token. if length == 0 {
// length not provided in request. Assign default value of 16.
length = 16
}
// token not present in request body. Hence, generate a random token.
if !(length > 0 && length <= 64) { if !(length > 0 && length <= 64) {
return util.MatrixErrorResponse( return util.MatrixErrorResponse(
http.StatusBadRequest, http.StatusBadRequest,
@ -109,7 +114,9 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, r
} }
pending := 0 pending := 0
completed := 0 completed := 0
created, err := rsAPI.PerformAdminCreateRegistrationToken(req.Context(), token, usesAllowed, int32(pending), int32(completed), expiryTime) // If usesAllowed or expiryTime is 0, it means they are not present in the request. NULL (indicating
// unlimited uses / no expiration will be persisted in DB)
created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), token, usesAllowed, expiryTime)
if err != nil { if err != nil {
return util.MatrixErrorResponse( return util.MatrixErrorResponse(
http.StatusInternalServerError, http.StatusInternalServerError,

View file

@ -164,7 +164,7 @@ func Setup(
} }
dendriteAdminRouter.Handle("/admin/registrationTokens/new", dendriteAdminRouter.Handle("/admin/registrationTokens/new",
httputil.MakeAdminAPI("admin_registration_tokens_new", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAdminAPI("admin_registration_tokens_new", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminCreateNewRegistrationToken(req, cfg, rsAPI) return AdminCreateNewRegistrationToken(req, cfg, userAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)

View file

@ -173,7 +173,6 @@ type ClientRoomserverAPI interface {
PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) (string, *util.JSONResponse) PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) (string, *util.JSONResponse)
// PerformRoomUpgrade upgrades a room to a newer version // PerformRoomUpgrade upgrades a room to a newer version
PerformRoomUpgrade(ctx context.Context, roomID, userID string, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error) PerformRoomUpgrade(ctx context.Context, roomID, userID string, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error)
PerformAdminCreateRegistrationToken(ctx context.Context, token string, usesAllowed, pending, completed int32, expiryTime int64) (bool, error)
PerformAdminEvacuateRoom(ctx context.Context, roomID string) (affected []string, err error) PerformAdminEvacuateRoom(ctx context.Context, roomID string) (affected []string, err error)
PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error) PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error)
PerformAdminPurgeRoom(ctx context.Context, roomID string) error PerformAdminPurgeRoom(ctx context.Context, roomID string) error

View file

@ -42,21 +42,6 @@ type Admin struct {
Leaver *Leaver Leaver *Leaver
} }
func (r *Admin) PerformAdminCreateRegistrationToken(
ctx context.Context, token string,
usesAllowed, pending, completed int32,
expiryTime int64) (bool, error) {
exists, err := r.DB.RegistrationTokenExists(ctx, token)
if err != nil {
return false, err
}
if exists {
fmt.Println(fmt.Sprintf("token: %s already exists", token))
return false, fmt.Errorf("token: %s already exists", token)
}
return true, nil
}
// PerformAdminEvacuateRoom will remove all local users from the given room. // PerformAdminEvacuateRoom will remove all local users from the given room.
func (r *Admin) PerformAdminEvacuateRoom( func (r *Admin) PerformAdminEvacuateRoom(
ctx context.Context, ctx context.Context,

View file

@ -27,7 +27,6 @@ import (
) )
type Database interface { type Database interface {
RegistrationTokenExists(ctx context.Context, token string) (bool, error)
// Do we support processing input events for more than one room at a time? // Do we support processing input events for more than one room at a time?
SupportsConcurrentRoomInputs() bool SupportsConcurrentRoomInputs() bool
// RoomInfo returns room information for the given room ID, or nil if there is no room. // RoomInfo returns room information for the given room ID, or nil if there is no room.

View file

@ -1,25 +0,0 @@
package postgres
import (
"context"
"database/sql"
)
const registrationTokensSchema = `
CREATE TABLE IF NOT EXISTS roomserver_registration_tokens (
token TEXT PRIMARY KEY,
pending BIGINT,
completed BIGINT,
uses_allowed BIGINT,
expiry_time BIGINT
);
`
func CreateRegistrationTokensTable(db *sql.DB) error {
_, err := db.Exec(registrationTokensSchema)
return err
}
func RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) {
return true, nil
}

View file

@ -92,9 +92,6 @@ func executeMigration(ctx context.Context, db *sql.DB) error {
} }
func (d *Database) create(db *sql.DB) error { func (d *Database) create(db *sql.DB) error {
if err := CreateRegistrationTokensTable(db); err != nil {
return err
}
if err := CreateEventStateKeysTable(db); err != nil { if err := CreateEventStateKeysTable(db); err != nil {
return err return err
} }

View file

@ -31,18 +31,17 @@ const redactionsArePermanent = true
type Database struct { type Database struct {
DB *sql.DB DB *sql.DB
EventDatabase EventDatabase
Cache caching.RoomServerCaches Cache caching.RoomServerCaches
Writer sqlutil.Writer Writer sqlutil.Writer
RoomsTable tables.Rooms RoomsTable tables.Rooms
StateSnapshotTable tables.StateSnapshot StateSnapshotTable tables.StateSnapshot
StateBlockTable tables.StateBlock StateBlockTable tables.StateBlock
RoomAliasesTable tables.RoomAliases RoomAliasesTable tables.RoomAliases
InvitesTable tables.Invites InvitesTable tables.Invites
MembershipTable tables.Membership MembershipTable tables.Membership
PublishedTable tables.Published PublishedTable tables.Published
Purge tables.Purge Purge tables.Purge
GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error)
RegistrationTokensTable tables.RegistrationTokens
} }
// EventDatabase contains all tables needed to work with events // EventDatabase contains all tables needed to work with events
@ -58,10 +57,6 @@ type EventDatabase struct {
RedactionsTable tables.Redactions RedactionsTable tables.Redactions
} }
func (d *Database) RegistrationTokenExists(ctx context.Context, token string) (bool, error) {
return d.RegistrationTokensTable.RegistrationTokenExists(ctx, nil, token)
}
func (d *Database) SupportsConcurrentRoomInputs() bool { func (d *Database) SupportsConcurrentRoomInputs() bool {
return true return true
} }

View file

@ -19,10 +19,6 @@ type EventJSONPair struct {
EventJSON []byte EventJSON []byte
} }
type RegistrationTokens interface {
RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error)
}
type EventJSON interface { type EventJSON interface {
// Insert the event JSON. On conflict, replace the event JSON with the new value (for redactions). // Insert the event JSON. On conflict, replace the event JSON with the new value (for redactions).
InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error

View file

@ -94,6 +94,7 @@ type ClientUserAPI interface {
QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error
QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error) QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error)
QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error
PerformAdminCreateRegistrationToken(ctx context.Context, token string, usesAllowed int32, expiryTime int64) (bool, error)
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error

View file

@ -63,6 +63,21 @@ type UserInternalAPI struct {
Updater *DeviceListUpdater Updater *DeviceListUpdater
} }
func (a *UserInternalAPI) PerformAdminCreateRegistrationToken(ctx context.Context, token string, usesAllowed int32, expiryTime int64) (bool, error) {
exists, err := a.DB.RegistrationTokenExists(ctx, token)
if err != nil {
return false, err
}
if exists {
return false, fmt.Errorf("token: %s already exists", token)
}
_, err = a.DB.InsertRegistrationToken(ctx, token, usesAllowed, expiryTime)
if err != nil {
return false, fmt.Errorf("Error creating token: %s"+err.Error(), token)
}
return true, nil
}
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil { if err != nil {

View file

@ -30,6 +30,11 @@ import (
"github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/dendrite/userapi/types"
) )
type RegistrationTokens interface {
RegistrationTokenExists(ctx context.Context, token string) (bool, error)
InsertRegistrationToken(ctx context.Context, token string, usesAllowed int32, expiryTime int64) (bool, error)
}
type Profile interface { type Profile interface {
GetProfileByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (*authtypes.Profile, error) GetProfileByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (*authtypes.Profile, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
@ -144,6 +149,7 @@ type UserDatabase interface {
Pusher Pusher
Statistics Statistics
ThreePID ThreePID
RegistrationTokens
} }
type KeyChangeDatabase interface { type KeyChangeDatabase interface {

View file

@ -0,0 +1,66 @@
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
const registrationTokensSchema = `
CREATE TABLE IF NOT EXISTS userapi_registration_tokens (
token TEXT PRIMARY KEY,
pending BIGINT,
completed BIGINT,
uses_allowed BIGINT,
expiry_time BIGINT
);
`
const selectTokenSQL = "" +
"SELECT token FROM userapi_registration_tokens WHERE token = $1"
const insertTokenSQL = "" +
"INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)"
type registrationTokenStatements struct {
selectTokenStatement *sql.Stmt
insertTokenStatment *sql.Stmt
}
func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) {
s := &registrationTokenStatements{}
_, err := db.Exec(registrationTokensSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.selectTokenStatement, selectTokenSQL},
{&s.insertTokenStatment, insertTokenSQL},
}.Prepare(db)
}
func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) {
var existingToken string
stmt := s.selectTokenStatement
err := stmt.QueryRowContext(ctx, token).Scan(&existingToken)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
}
return false, err
}
return true, nil
}
func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, token string, usesAllowed int32, expiryTime int64) (bool, error) {
stmt := sqlutil.TxStmt(tx, s.insertTokenStatment)
pending := 0
completed := 0
_, err := stmt.ExecContext(ctx, token, nil, expiryTime, pending, completed)
if err != nil {
return false, err
}
return true, nil
}

View file

@ -53,6 +53,10 @@ func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *
return nil, err return nil, err
} }
registationTokensTable, err := NewPostgresRegistrationTokensTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresRegistrationsTokenTable: %w", err)
}
accountsTable, err := NewPostgresAccountsTable(db, serverName) accountsTable, err := NewPostgresAccountsTable(db, serverName)
if err != nil { if err != nil {
return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err) return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err)
@ -125,6 +129,7 @@ func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *
ThreePIDs: threePIDTable, ThreePIDs: threePIDTable,
Pushers: pusherTable, Pushers: pusherTable,
Notifications: notificationsTable, Notifications: notificationsTable,
RegistrationTokens: registationTokensTable,
Stats: statsTable, Stats: statsTable,
ServerName: serverName, ServerName: serverName,
DB: db, DB: db,

View file

@ -43,6 +43,7 @@ import (
type Database struct { type Database struct {
DB *sql.DB DB *sql.DB
Writer sqlutil.Writer Writer sqlutil.Writer
RegistrationTokens tables.RegistrationTokensTable
Accounts tables.AccountsTable Accounts tables.AccountsTable
Profiles tables.ProfileTable Profiles tables.ProfileTable
AccountDatas tables.AccountDataTable AccountDatas tables.AccountDataTable
@ -78,6 +79,18 @@ const (
loginTokenByteLength = 32 loginTokenByteLength = 32
) )
func (d *Database) RegistrationTokenExists(ctx context.Context, token string) (bool, error) {
return d.RegistrationTokens.RegistrationTokenExists(ctx, nil, token)
}
func (d *Database) InsertRegistrationToken(ctx context.Context, token string, usesAllowed int32, expiryTime int64) (created bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
created, err = d.RegistrationTokens.InsertRegistrationToken(ctx, txn, token, usesAllowed, expiryTime)
return err
})
return
}
// GetAccountByPassword returns the account associated with the given localpart and password. // GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart. // Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword( func (d *Database) GetAccountByPassword(

View file

@ -29,6 +29,11 @@ import (
"github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/dendrite/userapi/types"
) )
type RegistrationTokensTable interface {
RegistrationTokenExists(ctx context.Context, txn *sql.Tx, token string) (bool, error)
InsertRegistrationToken(ctx context.Context, txn *sql.Tx, token string, usesAllowed int32, expiryTime int64) (bool, error)
}
type AccountDataTable interface { type AccountDataTable interface {
InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID, dataType string, content json.RawMessage) error InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID, dataType string, content json.RawMessage) error
SelectAccountData(ctx context.Context, localpart string, serverName spec.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error) SelectAccountData(ctx context.Context, localpart string, serverName spec.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)