Replace UserInUse check by psql query error code validation

This commit is contained in:
Crom (Thibaut CHARLES) 2017-12-11 14:09:15 +01:00
parent 2758f668f6
commit afe51d07e8
No known key found for this signature in database
GPG key ID: 45A3D5F880B9E6D0
2 changed files with 14 additions and 15 deletions

View file

@ -24,7 +24,7 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
// Import the postgres database driver. // Import the postgres database driver.
_ "github.com/lib/pq" "github.com/lib/pq"
) )
// Database represents an account database // Database represents an account database
@ -118,7 +118,8 @@ func (d *Database) SetDisplayName(
} }
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, nil.
func (d *Database) CreateAccount( func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword string, ctx context.Context, localpart, plaintextPassword string,
) (*authtypes.Account, error) { ) (*authtypes.Account, error) {
@ -127,6 +128,12 @@ func (d *Database) CreateAccount(
return nil, err return nil, err
} }
if err := d.profiles.insertProfile(ctx, localpart); err != nil { if err := d.profiles.insertProfile(ctx, localpart); err != nil {
if err, ok := err.(*pq.Error); ok {
if err.Code.Class() == "23" {
// 23 => unique_violation => Account already exists
return nil, nil
}
}
return nil, err return nil, err
} }
return d.accounts.insertAccount(ctx, localpart, hash) return d.accounts.insertAccount(ctx, localpart, hash)

View file

@ -457,25 +457,17 @@ func completeRegistration(
} }
} }
avail, err := accountDB.CheckAccountAvailability(ctx, username)
if err == nil && !avail {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.UserInUse("Desired user ID is already taken."),
}
} else if err != nil {
return util.JSONResponse{
Code: 500,
JSON: jsonerror.Unknown("Failed to check account availability: " + err.Error()),
}
}
acc, err := accountDB.CreateAccount(ctx, username, password) acc, err := accountDB.CreateAccount(ctx, username, password)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: 500, Code: 500,
JSON: jsonerror.Unknown("failed to create account: " + err.Error()), JSON: jsonerror.Unknown("failed to create account: " + err.Error()),
} }
} else if acc == nil {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.UserInUse("Desired user ID is already taken."),
}
} }
token, err := auth.GenerateAccessToken() token, err := auth.GenerateAccessToken()