Add account type (#2171)

* Add account_type for sqlite3

* Add account_type for postgres

* Remove CreateGuestAccount from interface

* Add new AccountTypes & update test

* Use newly added AccountType for account creation

* Add migrations

* Reuse type

* Add AccounnType to Device, so it can be verified on requests

* Rename migration, add missing update for appservices

* Rename sqlite3 migration

* Add missing AccountType to return value

* Update sqlite migration
Change allowance check on /admin/whois

* Fix migration, add IS NULL

* Move accountType to completeRegistration

* Fix migrations

* Add passing test
This commit is contained in:
S7evinK 2022-02-16 18:55:38 +01:00 committed by GitHub
parent e9b672a34e
commit 5a39512f5f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 230 additions and 117 deletions

View file

@ -22,6 +22,8 @@ import (
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/sirupsen/logrus"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/appservice/consumers" "github.com/matrix-org/dendrite/appservice/consumers"
"github.com/matrix-org/dendrite/appservice/inthttp" "github.com/matrix-org/dendrite/appservice/inthttp"
@ -34,7 +36,6 @@ import (
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/sirupsen/logrus"
) )
// AddInternalRoutes registers HTTP handlers for internal API calls // AddInternalRoutes registers HTTP handlers for internal API calls
@ -121,7 +122,7 @@ func generateAppServiceAccount(
) error { ) error {
var accRes userapi.PerformAccountCreationResponse var accRes userapi.PerformAccountCreationResponse
err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{ err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeUser, AccountType: userapi.AccountTypeAppService,
Localpart: as.SenderLocalpart, Localpart: as.SenderLocalpart,
AppServiceID: as.ID, AppServiceID: as.ID,
OnConflict: userapi.ConflictUpdate, OnConflict: userapi.ConflictUpdate,

View file

@ -47,8 +47,8 @@ func GetAdminWhois(
req *http.Request, userAPI api.UserInternalAPI, device *api.Device, req *http.Request, userAPI api.UserInternalAPI, device *api.Device,
userID string, userID string,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { allowed := device.AccountType == api.AccountTypeAdmin || userID == device.UserID
// TODO: Still allow if user is admin if !allowed {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("userID does not match the current user"), JSON: jsonerror.Forbidden("userID does not match the current user"),

View file

@ -32,6 +32,12 @@ import (
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/tokens"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
@ -39,11 +45,6 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/tokens"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
) )
var ( var (
@ -701,7 +702,7 @@ func handleApplicationServiceRegistration(
// application service registration is entirely separate. // application service registration is entirely separate.
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(),
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService,
) )
} }
@ -720,7 +721,7 @@ func checkAndCompleteFlow(
// This flow was completed, registration can continue // This flow was completed, registration can continue
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(),
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser,
) )
} }
@ -745,6 +746,7 @@ func completeRegistration(
username, password, appserviceID, ipAddr, userAgent string, username, password, appserviceID, ipAddr, userAgent string,
inhibitLogin eventutil.WeakBoolean, inhibitLogin eventutil.WeakBoolean,
displayName, deviceID *string, displayName, deviceID *string,
accType userapi.AccountType,
) util.JSONResponse { ) util.JSONResponse {
if username == "" { if username == "" {
return util.JSONResponse{ return util.JSONResponse{
@ -759,13 +761,12 @@ func completeRegistration(
JSON: jsonerror.BadJSON("missing password"), JSON: jsonerror.BadJSON("missing password"),
} }
} }
var accRes userapi.PerformAccountCreationResponse var accRes userapi.PerformAccountCreationResponse
err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{
AppServiceID: appserviceID, AppServiceID: appserviceID,
Localpart: username, Localpart: username,
Password: password, Password: password,
AccountType: userapi.AccountTypeUser, AccountType: accType,
OnConflict: userapi.ConflictAbort, OnConflict: userapi.ConflictAbort,
}, &accRes) }, &accRes)
if err != nil { if err != nil {
@ -963,5 +964,10 @@ func handleSharedSecretRegistration(userAPI userapi.UserInternalAPI, sr *SharedS
return *resErr return *resErr
} }
deviceID := "shared_secret_registration" deviceID := "shared_secret_registration"
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), false, &ssrr.User, &deviceID)
accType := userapi.AccountTypeUser
if ssrr.Admin {
accType = userapi.AccountTypeAdmin
}
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), false, &ssrr.User, &deviceID, accType)
} }

View file

@ -23,12 +23,14 @@ import (
"os" "os"
"strings" "strings"
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"golang.org/x/term" "golang.org/x/term"
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
) )
const usage = `Usage: %s const usage = `Usage: %s
@ -57,6 +59,7 @@ var (
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
askPass = flag.Bool("ask-pass", false, "Ask for the password to use") askPass = flag.Bool("ask-pass", false, "Ask for the password to use")
isAdmin = flag.Bool("admin", false, "Create an admin account")
) )
func main() { func main() {
@ -81,7 +84,11 @@ func main() {
logrus.Fatalln("Failed to connect to the database:", err.Error()) logrus.Fatalln("Failed to connect to the database:", err.Error())
} }
_, err = accountDB.CreateAccount(context.Background(), *username, pass, "") accType := api.AccountTypeUser
if *isAdmin {
accType = api.AccountTypeAdmin
}
_, err = accountDB.CreateAccount(context.Background(), *username, pass, "", accType)
if err != nil { if err != nil {
logrus.Fatalln("Failed to create the account:", err.Error()) logrus.Fatalln("Failed to create the account:", err.Error())
} }

View file

@ -592,3 +592,4 @@ Forward extremities remain so even after the next events are populated as outlie
If a device list update goes missing, the server resyncs on the next one If a device list update goes missing, the server resyncs on the next one
uploading self-signing key notifies over federation uploading self-signing key notifies over federation
uploading signed devices gets propagated over federation uploading signed devices gets propagated over federation
Device list doesn't change if remote server is down

View file

@ -18,8 +18,9 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
) )
// UserInternalAPI is the internal API for information about users and devices. // UserInternalAPI is the internal API for information about users and devices.
@ -353,6 +354,7 @@ type Device struct {
// If the device is for an appservice user, // If the device is for an appservice user,
// this is the appservice ID. // this is the appservice ID.
AppserviceID string AppserviceID string
AccountType AccountType
} }
// Account represents a Matrix account on this home server. // Account represents a Matrix account on this home server.
@ -361,7 +363,7 @@ type Account struct {
Localpart string Localpart string
ServerName gomatrixserverlib.ServerName ServerName gomatrixserverlib.ServerName
AppServiceID string AppServiceID string
// TODO: Other flags like IsAdmin, IsGuest AccountType AccountType
// TODO: Associations (e.g. with application services) // TODO: Associations (e.g. with application services)
} }
@ -417,4 +419,8 @@ const (
AccountTypeUser AccountType = 1 AccountTypeUser AccountType = 1
// AccountTypeGuest indicates this is a guest account // AccountTypeGuest indicates this is a guest account
AccountTypeGuest AccountType = 2 AccountTypeGuest AccountType = 2
// AccountTypeAdmin indicates this is an admin account
AccountTypeAdmin AccountType = 3
// AccountTypeAppService indicates this is an appservice account
AccountTypeAppService AccountType = 4
) )

View file

@ -21,6 +21,10 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/appservice/types" "github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -29,9 +33,6 @@ import (
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/dendrite/userapi/storage/devices"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
type UserInternalAPI struct { type UserInternalAPI struct {
@ -58,16 +59,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
} }
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
if req.AccountType == api.AccountTypeGuest { acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
acc, err := a.AccountDB.CreateGuestAccount(ctx)
if err != nil {
return err
}
res.AccountCreated = true
res.Account = acc
return nil
}
acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID)
if err != nil { if err != nil {
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
switch req.OnConflict { switch req.OnConflict {
@ -86,10 +78,17 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
Localpart: req.Localpart, Localpart: req.Localpart,
ServerName: a.ServerName, ServerName: a.ServerName,
UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName),
AccountType: req.AccountType,
} }
return nil return nil
} }
if req.AccountType == api.AccountTypeGuest {
res.AccountCreated = true
res.Account = acc
return nil
}
if err = a.AccountDB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { if err = a.AccountDB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
return err return err
} }
@ -375,6 +374,15 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
} }
return err return err
} }
localPart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
return err
}
acc, err := a.AccountDB.GetAccountByLocalpart(ctx, localPart)
if err != nil {
return err
}
device.AccountType = acc.AccountType
res.Device = device res.Device = device
return nil return nil
} }
@ -401,6 +409,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
// AS dummy device has AS's token. // AS dummy device has AS's token.
AccessToken: token, AccessToken: token,
AppserviceID: appService.ID, AppserviceID: appService.ID,
AccountType: api.AccountTypeAppService,
} }
localpart, err := userutil.ParseUsernameParam(appServiceUserID, &a.ServerName) localpart, err := userutil.ParseUsernameParam(appServiceUserID, &a.ServerName)

View file

@ -32,8 +32,7 @@ type Database interface {
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists. // account already exists, it will return nil, ErrUserExists.
CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*api.Account, error) CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
CreateGuestAccount(ctx context.Context) (*api.Account, error)
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
// GetAccountDataByType returns account data matching a given // GetAccountDataByType returns account data matching a given

View file

@ -19,10 +19,11 @@ import (
"database/sql" "database/sql"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -39,16 +40,18 @@ CREATE TABLE IF NOT EXISTS account_accounts (
-- Identifies which application service this account belongs to, if any. -- Identifies which application service this account belongs to, if any.
appservice_id TEXT, appservice_id TEXT,
-- If the account is currently active -- If the account is currently active
is_deactivated BOOLEAN DEFAULT FALSE is_deactivated BOOLEAN DEFAULT FALSE,
-- The account_type (user = 1, guest = 2, admin = 3, appservice = 4)
account_type SMALLINT NOT NULL
-- TODO: -- TODO:
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? -- upgraded_ts, devices, any email reset stuff?
); );
-- Create sequence for autogenerated numeric usernames -- Create sequence for autogenerated numeric usernames
CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1; CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
` `
const insertAccountSQL = "" + const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
const updatePasswordSQL = "" + const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
@ -57,7 +60,7 @@ const deactivateAccountSQL = "" +
"UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1" "UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1"
const selectAccountByLocalpartSQL = "" + const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" "SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1"
const selectPasswordHashSQL = "" + const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE" "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
@ -96,16 +99,16 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) insertAccount( func (s *accountsStatements) insertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) { ) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
var err error var err error
if appserviceID == "" { if accountType != api.AccountTypeAppService {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil) _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
} else { } else {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@ -116,6 +119,7 @@ func (s *accountsStatements) insertAccount(
UserID: userutil.MakeUserID(localpart, s.serverName), UserID: userutil.MakeUserID(localpart, s.serverName),
ServerName: s.serverName, ServerName: s.serverName,
AppServiceID: appserviceID, AppServiceID: appserviceID,
AccountType: accountType,
}, nil }, nil
} }
@ -147,7 +151,7 @@ func (s *accountsStatements) selectAccountByLocalpart(
var acc api.Account var acc api.Account
stmt := s.selectAccountByLocalpartStmt stmt := s.selectAccountByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db") log.WithError(err).Error("Unable to retrieve user from the db")

View file

@ -4,12 +4,14 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose" "github.com/pressly/goose"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
func LoadFromGoose() { func LoadFromGoose() {
goose.AddMigration(UpIsActive, DownIsActive) goose.AddMigration(UpIsActive, DownIsActive)
goose.AddMigration(UpAddAccountType, DownAddAccountType)
} }
func LoadIsActive(m *sqlutil.Migrations) { func LoadIsActive(m *sqlutil.Migrations) {

View file

@ -0,0 +1,34 @@
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
func LoadAddAccountType(m *sqlutil.Migrations) {
m.AddMigration(UpAddAccountType, DownAddAccountType)
}
func UpAddAccountType(tx *sql.Tx) error {
// initially set every account to useraccount, change appservice and guest accounts afterwards
// (user = 1, guest = 2, admin = 3, appservice = 4)
_, err := tx.Exec(`ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1;
UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> '';
UPDATE account_accounts SET account_type = 2 WHERE localpart ~ '^[0-9]+$';
ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`,
)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownAddAccountType(tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN account_type;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -23,13 +23,14 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas" "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
// Import the postgres database driver. // Import the postgres database driver.
_ "github.com/lib/pq" _ "github.com/lib/pq"
@ -73,6 +74,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
} }
m := sqlutil.NewMigrations() m := sqlutil.NewMigrations()
deltas.LoadIsActive(m) deltas.LoadIsActive(m)
deltas.LoadAddAccountType(m)
if err = m.RunDeltas(db, dbProperties); err != nil { if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err return nil, err
} }
@ -155,37 +157,32 @@ func (d *Database) SetPassword(
return d.accounts.updatePassword(ctx, localpart, hash) return d.accounts.updatePassword(ctx, localpart, hash)
} }
// CreateGuestAccount makes a new guest account and creates an empty profile
// for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var numLocalpart int64
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
if err != nil {
return err
}
localpart := strconv.FormatInt(numLocalpart, 10)
acc, err = d.createAccount(ctx, txn, localpart, "", "")
return err
})
return acc, err
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, sqlutil.ErrUserExists. // account already exists, it will return nil, sqlutil.ErrUserExists.
func (d *Database) CreateAccount( func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string, ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
) (acc *api.Account, err error) { ) (acc *api.Account, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) // For guest accounts, we create a new numeric local part
if accountType == api.AccountTypeGuest {
var numLocalpart int64
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
if err != nil {
return err
}
localpart = strconv.FormatInt(numLocalpart, 10)
plaintextPassword = ""
appserviceID = ""
}
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType)
return err return err
}) })
return return
} }
func (d *Database) createAccount( func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
) (*api.Account, error) { ) (*api.Account, error) {
var account *api.Account var account *api.Account
var err error var err error
@ -197,7 +194,7 @@ func (d *Database) createAccount(
return nil, err return nil, err
} }
} }
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil { if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
if sqlutil.IsUniqueConstraintViolationErr(err) { if sqlutil.IsUniqueConstraintViolationErr(err) {
return nil, sqlutil.ErrUserExists return nil, sqlutil.ErrUserExists
} }

View file

@ -19,10 +19,11 @@ import (
"database/sql" "database/sql"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -39,14 +40,16 @@ CREATE TABLE IF NOT EXISTS account_accounts (
-- Identifies which application service this account belongs to, if any. -- Identifies which application service this account belongs to, if any.
appservice_id TEXT, appservice_id TEXT,
-- If the account is currently active -- If the account is currently active
is_deactivated BOOLEAN DEFAULT 0 is_deactivated BOOLEAN DEFAULT 0,
-- The account_type (user = 1, guest = 2, admin = 3, appservice = 4)
account_type INTEGER NOT NULL
-- TODO: -- TODO:
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? -- upgraded_ts, devices, any email reset stuff?
); );
` `
const insertAccountSQL = "" + const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
const updatePasswordSQL = "" + const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
@ -55,7 +58,7 @@ const deactivateAccountSQL = "" +
"UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1" "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
const selectAccountByLocalpartSQL = "" + const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" "SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1"
const selectPasswordHashSQL = "" + const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0" "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
@ -96,16 +99,16 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) insertAccount( func (s *accountsStatements) insertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) { ) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt stmt := s.insertAccountStmt
var err error var err error
if appserviceID == "" { if accountType != api.AccountTypeAppService {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil) _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
} else { } else {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@ -147,7 +150,7 @@ func (s *accountsStatements) selectAccountByLocalpart(
var acc api.Account var acc api.Account
stmt := s.selectAccountByLocalpartStmt stmt := s.selectAccountByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db") log.WithError(err).Error("Unable to retrieve user from the db")

View file

@ -4,12 +4,14 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose" "github.com/pressly/goose"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
func LoadFromGoose() { func LoadFromGoose() {
goose.AddMigration(UpIsActive, DownIsActive) goose.AddMigration(UpIsActive, DownIsActive)
goose.AddMigration(UpAddAccountType, DownAddAccountType)
} }
func LoadIsActive(m *sqlutil.Migrations) { func LoadIsActive(m *sqlutil.Migrations) {

View file

@ -0,0 +1,54 @@
package deltas
import (
"database/sql"
"fmt"
"github.com/pressly/goose"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
func init() {
goose.AddMigration(UpAddAccountType, DownAddAccountType)
}
func LoadAddAccountType(m *sqlutil.Migrations) {
m.AddMigration(UpAddAccountType, DownAddAccountType)
}
func UpAddAccountType(tx *sql.Tx) error {
// initially set every account to useraccount, change appservice and guest accounts afterwards
// (user = 1, guest = 2, admin = 3, appservice = 4)
_, err := tx.Exec(`ALTER TABLE account_accounts RENAME TO account_accounts_tmp;
CREATE TABLE account_accounts (
localpart TEXT NOT NULL PRIMARY KEY,
created_ts BIGINT NOT NULL,
password_hash TEXT,
appservice_id TEXT,
is_deactivated BOOLEAN DEFAULT 0,
account_type INTEGER NOT NULL
);
INSERT
INTO account_accounts (
localpart, created_ts, password_hash, appservice_id, account_type
) SELECT
localpart, created_ts, password_hash, appservice_id, 1
FROM account_accounts_tmp
;
UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> '';
UPDATE account_accounts SET account_type = 2 WHERE localpart GLOB '[0-9]*';
DROP TABLE account_accounts_tmp;`)
if err != nil {
return fmt.Errorf("failed to add column: %w", err)
}
return nil
}
func DownAddAccountType(tx *sql.Tx) error {
_, err := tx.Exec(`ALTER TABLE account_accounts DROP COLUMN account_type;`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -24,13 +24,14 @@ import (
"sync" "sync"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas" "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
) )
// Database represents an account database // Database represents an account database
@ -77,6 +78,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
} }
m := sqlutil.NewMigrations() m := sqlutil.NewMigrations()
deltas.LoadIsActive(m) deltas.LoadIsActive(m)
deltas.LoadAddAccountType(m)
if err = m.RunDeltas(db, dbProperties); err != nil { if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err return nil, err
} }
@ -170,38 +172,11 @@ func (d *Database) SetPassword(
}) })
} }
// CreateGuestAccount makes a new guest account and creates an empty profile
// for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) {
// We need to lock so we sequentially create numeric localparts. If we don't, two calls to
// this function will cause the same number to be selected and one will fail with 'database is locked'
// when the first txn upgrades to a write txn. We also need to lock the account creation else we can
// race with CreateAccount
// We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed.
d.profilesMu.Lock()
d.accountDatasMu.Lock()
d.accountsMu.Lock()
defer d.profilesMu.Unlock()
defer d.accountDatasMu.Unlock()
defer d.accountsMu.Unlock()
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var numLocalpart int64
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
if err != nil {
return err
}
localpart := strconv.FormatInt(numLocalpart, 10)
acc, err = d.createAccount(ctx, txn, localpart, "", "")
return err
})
return acc, err
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists. // account already exists, it will return nil, ErrUserExists.
func (d *Database) CreateAccount( func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string, ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
) (acc *api.Account, err error) { ) (acc *api.Account, err error) {
// Create one account at a time else we can get 'database is locked'. // Create one account at a time else we can get 'database is locked'.
d.profilesMu.Lock() d.profilesMu.Lock()
@ -211,7 +186,18 @@ func (d *Database) CreateAccount(
defer d.accountDatasMu.Unlock() defer d.accountDatasMu.Unlock()
defer d.accountsMu.Unlock() defer d.accountsMu.Unlock()
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) // For guest accounts, we create a new numeric local part
if accountType == api.AccountTypeGuest {
var numLocalpart int64
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
if err != nil {
return err
}
localpart = strconv.FormatInt(numLocalpart, 10)
plaintextPassword = ""
appserviceID = ""
}
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType)
return err return err
}) })
return return
@ -220,7 +206,7 @@ func (d *Database) CreateAccount(
// WARNING! This function assumes that the relevant mutexes have already // WARNING! This function assumes that the relevant mutexes have already
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). // been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
func (d *Database) createAccount( func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
) (*api.Account, error) { ) (*api.Account, error) {
var err error var err error
var account *api.Account var account *api.Account
@ -232,7 +218,7 @@ func (d *Database) createAccount(
return nil, err return nil, err
} }
} }
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil { if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
return nil, sqlutil.ErrUserExists return nil, sqlutil.ErrUserExists
} }
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil { if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {

View file

@ -20,10 +20,11 @@ package accounts
import ( import (
"fmt" "fmt"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/accounts/postgres" "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres"
"github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3" "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
) )
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)

View file

@ -23,6 +23,9 @@ import (
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/internal/test"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
@ -30,8 +33,6 @@ import (
"github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/inthttp"
"github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/dendrite/userapi/storage/devices"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
) )
const ( const (
@ -73,7 +74,7 @@ func TestQueryProfile(t *testing.T) {
aliceAvatarURL := "mxc://example.com/alice" aliceAvatarURL := "mxc://example.com/alice"
aliceDisplayName := "Alice" aliceDisplayName := "Alice"
userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{}) userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{})
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "") _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", api.AccountTypeUser)
if err != nil { if err != nil {
t.Fatalf("failed to make account: %s", err) t.Fatalf("failed to make account: %s", err)
} }
@ -151,7 +152,7 @@ func TestLoginToken(t *testing.T) {
t.Run("tokenLoginFlow", func(t *testing.T) { t.Run("tokenLoginFlow", func(t *testing.T) {
userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{}) userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{})
_, err := accountDB.CreateAccount(ctx, "auser", "apassword", "") _, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", api.AccountTypeUser)
if err != nil { if err != nil {
t.Fatalf("failed to make account: %s", err) t.Fatalf("failed to make account: %s", err)
} }