Merge branch 'master' into neilalexander/stateresv2

This commit is contained in:
Neil Alexander 2020-03-10 11:46:36 +00:00
commit ea2fa4a401
13 changed files with 119 additions and 40 deletions

6
.gitignore vendored
View file

@ -43,3 +43,9 @@ _testmain.go
# Default configuration file # Default configuration file
dendrite.yaml dendrite.yaml
# Database files
*.db
# Log files
*.log*

View file

@ -30,6 +30,7 @@ type Database interface {
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
SetDisplayName(ctx context.Context, localpart string, displayName string) error SetDisplayName(ctx context.Context, localpart string, displayName string) error
CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*authtypes.Account, error) CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*authtypes.Account, error)
CreateGuestAccount(ctx context.Context) (*authtypes.Account, error)
UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error
GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error) GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error)
GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error) GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error)

View file

@ -72,9 +72,9 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
} }
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) insertAccountData(
ctx context.Context, localpart, roomID, dataType, content string, ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string,
) (err error) { ) (err error) {
stmt := s.insertAccountDataStmt stmt := txn.Stmt(s.insertAccountDataStmt)
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content) _, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
return return
} }

View file

@ -91,10 +91,10 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) insertAccount( func (s *accountsStatements) insertAccount(
ctx context.Context, localpart, hash, appserviceID string, ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
) (*authtypes.Account, error) { ) (*authtypes.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt stmt := txn.Stmt(s.insertAccountStmt)
var err error var err error
if appserviceID == "" { if appserviceID == "" {
@ -146,8 +146,12 @@ func (s *accountsStatements) selectAccountByLocalpart(
} }
func (s *accountsStatements) selectNewNumericLocalpart( func (s *accountsStatements) selectNewNumericLocalpart(
ctx context.Context, ctx context.Context, txn *sql.Tx,
) (id int64, err error) { ) (id int64, err error) {
err = s.selectNewNumericLocalpartStmt.QueryRowContext(ctx).Scan(&id) stmt := s.selectNewNumericLocalpartStmt
if txn != nil {
stmt = txn.Stmt(stmt)
}
err = stmt.QueryRowContext(ctx).Scan(&id)
return return
} }

View file

@ -73,9 +73,9 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
} }
func (s *profilesStatements) insertProfile( func (s *profilesStatements) insertProfile(
ctx context.Context, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) (err error) { ) (err error) {
_, err = s.insertProfileStmt.ExecContext(ctx, localpart, "", "") _, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return return
} }

View file

@ -18,6 +18,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"errors" "errors"
"strconv"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
@ -118,11 +119,37 @@ func (d *Database) SetDisplayName(
return d.profiles.setDisplayName(ctx, localpart, displayName) return d.profiles.setDisplayName(ctx, localpart, displayName)
} }
// CreateGuestAccount makes a new guest account and creates an empty profile
// for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Account, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var numLocalpart int64
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
if err != nil {
return err
}
localpart := strconv.FormatInt(numLocalpart, 10)
acc, err = d.createAccount(ctx, txn, localpart, "", "")
return err
})
return acc, err
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, nil. // account already exists, it will return nil, nil.
func (d *Database) CreateAccount( func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string, ctx context.Context, localpart, plaintextPassword, appserviceID string,
) (acc *authtypes.Account, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
return err
})
return
}
func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
) (*authtypes.Account, error) { ) (*authtypes.Account, error) {
var err error var err error
@ -134,13 +161,14 @@ func (d *Database) CreateAccount(
return nil, err return nil, err
} }
} }
if err := d.profiles.insertProfile(ctx, localpart); err != nil { if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil {
if common.IsUniqueConstraintViolationErr(err) { if common.IsUniqueConstraintViolationErr(err) {
return nil, nil return nil, nil
} }
return nil, err return nil, err
} }
if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{
"global": { "global": {
"content": [], "content": [],
"override": [], "override": [],
@ -151,7 +179,7 @@ func (d *Database) CreateAccount(
}`); err != nil { }`); err != nil {
return nil, err return nil, err
} }
return d.accounts.insertAccount(ctx, localpart, hash, appserviceID) return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
} }
// SaveMembership saves the user matching a given localpart as a member of a given // SaveMembership saves the user matching a given localpart as a member of a given
@ -258,7 +286,9 @@ func (d *Database) newMembership(
func (d *Database) SaveAccountData( func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string, ctx context.Context, localpart, roomID, dataType, content string,
) error { ) error {
return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content) return common.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
})
} }
// GetAccountData returns account data related to a given localpart // GetAccountData returns account data related to a given localpart
@ -288,7 +318,7 @@ func (d *Database) GetAccountDataByType(
func (d *Database) GetNewNumericLocalpart( func (d *Database) GetNewNumericLocalpart(
ctx context.Context, ctx context.Context,
) (int64, error) { ) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx) return d.accounts.selectNewNumericLocalpart(ctx, nil)
} }
func hashPassword(plaintext string) (hash string, err error) { func hashPassword(plaintext string) (hash string, err error) {

View file

@ -72,10 +72,9 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
} }
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) insertAccountData(
ctx context.Context, localpart, roomID, dataType, content string, ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string,
) (err error) { ) (err error) {
stmt := s.insertAccountDataStmt _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
return return
} }

View file

@ -89,16 +89,16 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) insertAccount( func (s *accountsStatements) insertAccount(
ctx context.Context, localpart, hash, appserviceID string, ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
) (*authtypes.Account, error) { ) (*authtypes.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt stmt := s.insertAccountStmt
var err error var err error
if appserviceID == "" { if appserviceID == "" {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil) _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
} else { } else {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@ -144,8 +144,12 @@ func (s *accountsStatements) selectAccountByLocalpart(
} }
func (s *accountsStatements) selectNewNumericLocalpart( func (s *accountsStatements) selectNewNumericLocalpart(
ctx context.Context, ctx context.Context, txn *sql.Tx,
) (id int64, err error) { ) (id int64, err error) {
err = s.selectNewNumericLocalpartStmt.QueryRowContext(ctx).Scan(&id) stmt := s.selectNewNumericLocalpartStmt
if txn != nil {
stmt = txn.Stmt(stmt)
}
err = stmt.QueryRowContext(ctx).Scan(&id)
return return
} }

View file

@ -73,9 +73,9 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
} }
func (s *profilesStatements) insertProfile( func (s *profilesStatements) insertProfile(
ctx context.Context, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) (err error) { ) (err error) {
_, err = s.insertProfileStmt.ExecContext(ctx, localpart, "", "") _, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return return
} }

View file

@ -18,6 +18,8 @@ import (
"context" "context"
"database/sql" "database/sql"
"errors" "errors"
"strconv"
"sync"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
@ -39,6 +41,8 @@ type Database struct {
threepids threepidStatements threepids threepidStatements
filter filterStatements filter filterStatements
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
createGuestAccountMu sync.Mutex
} }
// NewDatabase creates a new accounts and profiles database // NewDatabase creates a new accounts and profiles database
@ -76,7 +80,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
if err = f.prepare(db); err != nil { if err = f.prepare(db); err != nil {
return nil, err return nil, err
} }
return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil return &Database{db, partitions, a, p, m, ac, t, f, serverName, sync.Mutex{}}, nil
} }
// GetAccountByPassword returns the account associated with the given localpart and password. // GetAccountByPassword returns the account associated with the given localpart and password.
@ -118,14 +122,46 @@ func (d *Database) SetDisplayName(
return d.profiles.setDisplayName(ctx, localpart, displayName) return d.profiles.setDisplayName(ctx, localpart, displayName)
} }
// CreateGuestAccount makes a new guest account and creates an empty profile
// for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Account, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
// We need to lock so we sequentially create numeric localparts. If we don't, two calls to
// this function will cause the same number to be selected and one will fail with 'database is locked'
// when the first txn upgrades to a write txn.
// We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed.
d.createGuestAccountMu.Lock()
defer d.createGuestAccountMu.Unlock()
var numLocalpart int64
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
if err != nil {
return err
}
localpart := strconv.FormatInt(numLocalpart, 10)
acc, err = d.createAccount(ctx, txn, localpart, "", "")
return err
})
return acc, err
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, nil. // account already exists, it will return nil, nil.
func (d *Database) CreateAccount( func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string, ctx context.Context, localpart, plaintextPassword, appserviceID string,
) (acc *authtypes.Account, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
return err
})
return
}
func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
) (*authtypes.Account, error) { ) (*authtypes.Account, error) {
var err error var err error
// Generate a password hash if this is not a password-less user // Generate a password hash if this is not a password-less user
hash := "" hash := ""
if plaintextPassword != "" { if plaintextPassword != "" {
@ -134,13 +170,14 @@ func (d *Database) CreateAccount(
return nil, err return nil, err
} }
} }
if err := d.profiles.insertProfile(ctx, localpart); err != nil { if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil {
if common.IsUniqueConstraintViolationErr(err) { if common.IsUniqueConstraintViolationErr(err) {
return nil, nil return nil, nil
} }
return nil, err return nil, err
} }
if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{
"global": { "global": {
"content": [], "content": [],
"override": [], "override": [],
@ -151,7 +188,7 @@ func (d *Database) CreateAccount(
}`); err != nil { }`); err != nil {
return nil, err return nil, err
} }
return d.accounts.insertAccount(ctx, localpart, hash, appserviceID) return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
} }
// SaveMembership saves the user matching a given localpart as a member of a given // SaveMembership saves the user matching a given localpart as a member of a given
@ -258,7 +295,9 @@ func (d *Database) newMembership(
func (d *Database) SaveAccountData( func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string, ctx context.Context, localpart, roomID, dataType, content string,
) error { ) error {
return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content) return common.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
})
} }
// GetAccountData returns account data related to a given localpart // GetAccountData returns account data related to a given localpart
@ -288,7 +327,7 @@ func (d *Database) GetAccountDataByType(
func (d *Database) GetNewNumericLocalpart( func (d *Database) GetNewNumericLocalpart(
ctx context.Context, ctx context.Context,
) (int64, error) { ) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx) return d.accounts.selectNewNumericLocalpart(ctx, nil)
} }
func hashPassword(plaintext string) (hash string, err error) { func hashPassword(plaintext string) (hash string, err error) {

View file

@ -516,16 +516,7 @@ func handleGuestRegistration(
accountDB accounts.Database, accountDB accounts.Database,
deviceDB devices.Database, deviceDB devices.Database,
) util.JSONResponse { ) util.JSONResponse {
acc, err := accountDB.CreateGuestAccount(req.Context())
//Generate numeric local part for guest user
id, err := accountDB.GetNewNumericLocalpart(req.Context())
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetNewNumericLocalpart failed")
return jsonerror.InternalServerError()
}
localpart := strconv.FormatInt(id, 10)
acc, err := accountDB.CreateAccount(req.Context(), localpart, "", "")
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,

View file

@ -77,6 +77,7 @@ func SendEvent(
util.GetLogger(req.Context()).WithError(err).Error("producer.SendEvents failed") util.GetLogger(req.Context()).WithError(err).Error("producer.SendEvents failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
util.GetLogger(req.Context()).WithField("event_id", eventID).Info("Sent event")
res := util.JSONResponse{ res := util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,

View file

@ -25,6 +25,10 @@ func MakeAuthAPI(
if err != nil { if err != nil {
return *err return *err
} }
// add the user ID to the logger
logger := util.GetLogger((req.Context()))
logger = logger.WithField("user_id", device.UserID)
req = req.WithContext(util.ContextWithLogger(req.Context(), logger))
return f(req, device) return f(req, device)
} }