bugfix: Fix a race condition when creating guest accounts

It was possible to both select the same next numeric ID and then both
attempt to INSERT this into the table. This would cause a UNIQUE violation
which then presented itself as an error in sqlite because it does not
implement `common.IsUniqueConstraintViolationErr`.

The fix here is NOT to implement `common.IsUniqueConstraintViolationErr`
otherwise the 2 users would get the SAME guest account. Instead, all of
these operations should be done inside a transaction. This is what this
PR does.
This commit is contained in:
Kegan Dougal 2020-03-06 15:45:16 +00:00
parent 87283e9de7
commit abacc13a76
7 changed files with 51 additions and 26 deletions

View file

@ -30,6 +30,7 @@ type Database interface {
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
SetDisplayName(ctx context.Context, localpart string, displayName string) error SetDisplayName(ctx context.Context, localpart string, displayName string) error
CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*authtypes.Account, error) CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*authtypes.Account, error)
CreateGuestAccount(ctx context.Context) (*authtypes.Account, error)
UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error
GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error) GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error)
GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error) GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error)

View file

@ -118,6 +118,10 @@ func (d *Database) SetDisplayName(
return d.profiles.setDisplayName(ctx, localpart, displayName) return d.profiles.setDisplayName(ctx, localpart, displayName)
} }
func (d *Database) CreateGuestAccount(ctx context.Context) (*authtypes.Account, error) {
return nil, nil
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, nil. // account already exists, it will return nil, nil.

View file

@ -72,10 +72,9 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
} }
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) insertAccountData(
ctx context.Context, localpart, roomID, dataType, content string, ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string,
) (err error) { ) (err error) {
stmt := s.insertAccountDataStmt _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
return return
} }

View file

@ -89,16 +89,16 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) insertAccount( func (s *accountsStatements) insertAccount(
ctx context.Context, localpart, hash, appserviceID string, ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
) (*authtypes.Account, error) { ) (*authtypes.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt stmt := s.insertAccountStmt
var err error var err error
if appserviceID == "" { if appserviceID == "" {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil) _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
} else { } else {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@ -144,8 +144,12 @@ func (s *accountsStatements) selectAccountByLocalpart(
} }
func (s *accountsStatements) selectNewNumericLocalpart( func (s *accountsStatements) selectNewNumericLocalpart(
ctx context.Context, ctx context.Context, txn *sql.Tx,
) (id int64, err error) { ) (id int64, err error) {
err = s.selectNewNumericLocalpartStmt.QueryRowContext(ctx).Scan(&id) stmt := s.selectNewNumericLocalpartStmt
if txn != nil {
stmt = txn.Stmt(stmt)
}
err = stmt.QueryRowContext(ctx).Scan(&id)
return return
} }

View file

@ -73,9 +73,9 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
} }
func (s *profilesStatements) insertProfile( func (s *profilesStatements) insertProfile(
ctx context.Context, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) (err error) { ) (err error) {
_, err = s.insertProfileStmt.ExecContext(ctx, localpart, "", "") _, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return return
} }

View file

@ -18,6 +18,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"errors" "errors"
"strconv"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
@ -118,14 +119,36 @@ func (d *Database) SetDisplayName(
return d.profiles.setDisplayName(ctx, localpart, displayName) return d.profiles.setDisplayName(ctx, localpart, displayName)
} }
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Account, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
numLocalpart, err := d.accounts.selectNewNumericLocalpart(ctx, txn)
if err != nil {
return err
}
localpart := strconv.FormatInt(numLocalpart, 10)
acc, err = d.createAccount(ctx, txn, localpart, "", "")
return err
})
return acc, err
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, nil. // account already exists, it will return nil, nil.
func (d *Database) CreateAccount( func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string, ctx context.Context, localpart, plaintextPassword, appserviceID string,
) (acc *authtypes.Account, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
return err
})
return
}
func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
) (*authtypes.Account, error) { ) (*authtypes.Account, error) {
var err error var err error
// Generate a password hash if this is not a password-less user // Generate a password hash if this is not a password-less user
hash := "" hash := ""
if plaintextPassword != "" { if plaintextPassword != "" {
@ -134,13 +157,14 @@ func (d *Database) CreateAccount(
return nil, err return nil, err
} }
} }
if err := d.profiles.insertProfile(ctx, localpart); err != nil { if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil {
if common.IsUniqueConstraintViolationErr(err) { if common.IsUniqueConstraintViolationErr(err) {
return nil, nil return nil, nil
} }
return nil, err return nil, err
} }
if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push.rules", `{
"global": { "global": {
"content": [], "content": [],
"override": [], "override": [],
@ -151,7 +175,7 @@ func (d *Database) CreateAccount(
}`); err != nil { }`); err != nil {
return nil, err return nil, err
} }
return d.accounts.insertAccount(ctx, localpart, hash, appserviceID) return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
} }
// SaveMembership saves the user matching a given localpart as a member of a given // SaveMembership saves the user matching a given localpart as a member of a given
@ -258,7 +282,9 @@ func (d *Database) newMembership(
func (d *Database) SaveAccountData( func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string, ctx context.Context, localpart, roomID, dataType, content string,
) error { ) error {
return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content) return common.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
})
} }
// GetAccountData returns account data related to a given localpart // GetAccountData returns account data related to a given localpart
@ -288,7 +314,7 @@ func (d *Database) GetAccountDataByType(
func (d *Database) GetNewNumericLocalpart( func (d *Database) GetNewNumericLocalpart(
ctx context.Context, ctx context.Context,
) (int64, error) { ) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx) return d.accounts.selectNewNumericLocalpart(ctx, nil)
} }
func hashPassword(plaintext string) (hash string, err error) { func hashPassword(plaintext string) (hash string, err error) {

View file

@ -516,16 +516,7 @@ func handleGuestRegistration(
accountDB accounts.Database, accountDB accounts.Database,
deviceDB devices.Database, deviceDB devices.Database,
) util.JSONResponse { ) util.JSONResponse {
acc, err := accountDB.CreateGuestAccount(req.Context())
//Generate numeric local part for guest user
id, err := accountDB.GetNewNumericLocalpart(req.Context())
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetNewNumericLocalpart failed")
return jsonerror.InternalServerError()
}
localpart := strconv.FormatInt(id, 10)
acc, err := accountDB.CreateAccount(req.Context(), localpart, "", "")
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,