Move DB Layer to UserAPI

This commit is contained in:
santhoshivan23 2023-06-06 21:52:05 +05:30
parent 6cd6af150b
commit fe2464fd4b
16 changed files with 133 additions and 69 deletions

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/userapi/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
func generateRandomToken(length int) string {
@ -38,7 +39,7 @@ func generateRandomToken(length int) string {
return sb.String()
}
func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
if !cfg.RegistrationRequiresToken {
return util.MatrixErrorResponse(
http.StatusForbidden,
@ -67,7 +68,11 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, r
length := request.Length
if len(token) == 0 {
// Token not present in request body. Hence, generate a random token.
if length == 0 {
// length not provided in request. Assign default value of 16.
length = 16
}
// token not present in request body. Hence, generate a random token.
if !(length > 0 && length <= 64) {
return util.MatrixErrorResponse(
http.StatusBadRequest,
@ -109,7 +114,9 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, r
}
pending := 0
completed := 0
created, err := rsAPI.PerformAdminCreateRegistrationToken(req.Context(), token, usesAllowed, int32(pending), int32(completed), expiryTime)
// If usesAllowed or expiryTime is 0, it means they are not present in the request. NULL (indicating
// unlimited uses / no expiration will be persisted in DB)
created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), token, usesAllowed, expiryTime)
if err != nil {
return util.MatrixErrorResponse(
http.StatusInternalServerError,

View file

@ -164,7 +164,7 @@ func Setup(
}
dendriteAdminRouter.Handle("/admin/registrationTokens/new",
httputil.MakeAdminAPI("admin_registration_tokens_new", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminCreateNewRegistrationToken(req, cfg, rsAPI)
return AdminCreateNewRegistrationToken(req, cfg, userAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)

View file

@ -173,7 +173,6 @@ type ClientRoomserverAPI interface {
PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) (string, *util.JSONResponse)
// PerformRoomUpgrade upgrades a room to a newer version
PerformRoomUpgrade(ctx context.Context, roomID, userID string, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error)
PerformAdminCreateRegistrationToken(ctx context.Context, token string, usesAllowed, pending, completed int32, expiryTime int64) (bool, error)
PerformAdminEvacuateRoom(ctx context.Context, roomID string) (affected []string, err error)
PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error)
PerformAdminPurgeRoom(ctx context.Context, roomID string) error

View file

@ -42,21 +42,6 @@ type Admin struct {
Leaver *Leaver
}
func (r *Admin) PerformAdminCreateRegistrationToken(
ctx context.Context, token string,
usesAllowed, pending, completed int32,
expiryTime int64) (bool, error) {
exists, err := r.DB.RegistrationTokenExists(ctx, token)
if err != nil {
return false, err
}
if exists {
fmt.Println(fmt.Sprintf("token: %s already exists", token))
return false, fmt.Errorf("token: %s already exists", token)
}
return true, nil
}
// PerformAdminEvacuateRoom will remove all local users from the given room.
func (r *Admin) PerformAdminEvacuateRoom(
ctx context.Context,

View file

@ -27,7 +27,6 @@ import (
)
type Database interface {
RegistrationTokenExists(ctx context.Context, token string) (bool, error)
// Do we support processing input events for more than one room at a time?
SupportsConcurrentRoomInputs() bool
// RoomInfo returns room information for the given room ID, or nil if there is no room.

View file

@ -1,25 +0,0 @@
package postgres
import (
"context"
"database/sql"
)
const registrationTokensSchema = `
CREATE TABLE IF NOT EXISTS roomserver_registration_tokens (
token TEXT PRIMARY KEY,
pending BIGINT,
completed BIGINT,
uses_allowed BIGINT,
expiry_time BIGINT
);
`
func CreateRegistrationTokensTable(db *sql.DB) error {
_, err := db.Exec(registrationTokensSchema)
return err
}
func RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) {
return true, nil
}

View file

@ -92,9 +92,6 @@ func executeMigration(ctx context.Context, db *sql.DB) error {
}
func (d *Database) create(db *sql.DB) error {
if err := CreateRegistrationTokensTable(db); err != nil {
return err
}
if err := CreateEventStateKeysTable(db); err != nil {
return err
}

View file

@ -31,18 +31,17 @@ const redactionsArePermanent = true
type Database struct {
DB *sql.DB
EventDatabase
Cache caching.RoomServerCaches
Writer sqlutil.Writer
RoomsTable tables.Rooms
StateSnapshotTable tables.StateSnapshot
StateBlockTable tables.StateBlock
RoomAliasesTable tables.RoomAliases
InvitesTable tables.Invites
MembershipTable tables.Membership
PublishedTable tables.Published
Purge tables.Purge
GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error)
RegistrationTokensTable tables.RegistrationTokens
Cache caching.RoomServerCaches
Writer sqlutil.Writer
RoomsTable tables.Rooms
StateSnapshotTable tables.StateSnapshot
StateBlockTable tables.StateBlock
RoomAliasesTable tables.RoomAliases
InvitesTable tables.Invites
MembershipTable tables.Membership
PublishedTable tables.Published
Purge tables.Purge
GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error)
}
// EventDatabase contains all tables needed to work with events
@ -58,10 +57,6 @@ type EventDatabase struct {
RedactionsTable tables.Redactions
}
func (d *Database) RegistrationTokenExists(ctx context.Context, token string) (bool, error) {
return d.RegistrationTokensTable.RegistrationTokenExists(ctx, nil, token)
}
func (d *Database) SupportsConcurrentRoomInputs() bool {
return true
}

View file

@ -19,10 +19,6 @@ type EventJSONPair struct {
EventJSON []byte
}
type RegistrationTokens interface {
RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error)
}
type EventJSON interface {
// Insert the event JSON. On conflict, replace the event JSON with the new value (for redactions).
InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error

View file

@ -94,6 +94,7 @@ type ClientUserAPI interface {
QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error
QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error)
QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error
PerformAdminCreateRegistrationToken(ctx context.Context, token string, usesAllowed int32, expiryTime int64) (bool, error)
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error

View file

@ -63,6 +63,21 @@ type UserInternalAPI struct {
Updater *DeviceListUpdater
}
func (a *UserInternalAPI) PerformAdminCreateRegistrationToken(ctx context.Context, token string, usesAllowed int32, expiryTime int64) (bool, error) {
exists, err := a.DB.RegistrationTokenExists(ctx, token)
if err != nil {
return false, err
}
if exists {
return false, fmt.Errorf("token: %s already exists", token)
}
_, err = a.DB.InsertRegistrationToken(ctx, token, usesAllowed, expiryTime)
if err != nil {
return false, fmt.Errorf("Error creating token: %s"+err.Error(), token)
}
return true, nil
}
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {

View file

@ -30,6 +30,11 @@ import (
"github.com/matrix-org/dendrite/userapi/types"
)
type RegistrationTokens interface {
RegistrationTokenExists(ctx context.Context, token string) (bool, error)
InsertRegistrationToken(ctx context.Context, token string, usesAllowed int32, expiryTime int64) (bool, error)
}
type Profile interface {
GetProfileByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (*authtypes.Profile, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
@ -144,6 +149,7 @@ type UserDatabase interface {
Pusher
Statistics
ThreePID
RegistrationTokens
}
type KeyChangeDatabase interface {

View file

@ -0,0 +1,66 @@
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
const registrationTokensSchema = `
CREATE TABLE IF NOT EXISTS userapi_registration_tokens (
token TEXT PRIMARY KEY,
pending BIGINT,
completed BIGINT,
uses_allowed BIGINT,
expiry_time BIGINT
);
`
const selectTokenSQL = "" +
"SELECT token FROM userapi_registration_tokens WHERE token = $1"
const insertTokenSQL = "" +
"INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)"
type registrationTokenStatements struct {
selectTokenStatement *sql.Stmt
insertTokenStatment *sql.Stmt
}
func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) {
s := &registrationTokenStatements{}
_, err := db.Exec(registrationTokensSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.selectTokenStatement, selectTokenSQL},
{&s.insertTokenStatment, insertTokenSQL},
}.Prepare(db)
}
func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) {
var existingToken string
stmt := s.selectTokenStatement
err := stmt.QueryRowContext(ctx, token).Scan(&existingToken)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
}
return false, err
}
return true, nil
}
func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, token string, usesAllowed int32, expiryTime int64) (bool, error) {
stmt := sqlutil.TxStmt(tx, s.insertTokenStatment)
pending := 0
completed := 0
_, err := stmt.ExecContext(ctx, token, nil, expiryTime, pending, completed)
if err != nil {
return false, err
}
return true, nil
}

View file

@ -53,6 +53,10 @@ func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *
return nil, err
}
registationTokensTable, err := NewPostgresRegistrationTokensTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresRegistrationsTokenTable: %w", err)
}
accountsTable, err := NewPostgresAccountsTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err)
@ -125,6 +129,7 @@ func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *
ThreePIDs: threePIDTable,
Pushers: pusherTable,
Notifications: notificationsTable,
RegistrationTokens: registationTokensTable,
Stats: statsTable,
ServerName: serverName,
DB: db,

View file

@ -43,6 +43,7 @@ import (
type Database struct {
DB *sql.DB
Writer sqlutil.Writer
RegistrationTokens tables.RegistrationTokensTable
Accounts tables.AccountsTable
Profiles tables.ProfileTable
AccountDatas tables.AccountDataTable
@ -78,6 +79,18 @@ const (
loginTokenByteLength = 32
)
func (d *Database) RegistrationTokenExists(ctx context.Context, token string) (bool, error) {
return d.RegistrationTokens.RegistrationTokenExists(ctx, nil, token)
}
func (d *Database) InsertRegistrationToken(ctx context.Context, token string, usesAllowed int32, expiryTime int64) (created bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
created, err = d.RegistrationTokens.InsertRegistrationToken(ctx, txn, token, usesAllowed, expiryTime)
return err
})
return
}
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(

View file

@ -29,6 +29,11 @@ import (
"github.com/matrix-org/dendrite/userapi/types"
)
type RegistrationTokensTable interface {
RegistrationTokenExists(ctx context.Context, txn *sql.Tx, token string) (bool, error)
InsertRegistrationToken(ctx context.Context, txn *sql.Tx, token string, usesAllowed int32, expiryTime int64) (bool, error)
}
type AccountDataTable interface {
InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID, dataType string, content json.RawMessage) error
SelectAccountData(ctx context.Context, localpart string, serverName spec.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)