Use bcrypt correctly

This commit is contained in:
Kegan Dougal 2017-05-18 16:22:27 +01:00
parent 9b48a67e14
commit fa076d3a07
2 changed files with 25 additions and 10 deletions

View file

@ -40,15 +40,19 @@ CREATE TABLE IF NOT EXISTS accounts (
const insertAccountSQL = "" +
"INSERT INTO accounts(localpart, created_ts, password_hash) VALUES ($1, $2, $3)"
const selectAccountByPasswordHashSQL = "" +
"SELECT localpart WHERE localpart = $1 AND password_hash = $2"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart WHERE localpart = $1"
const selectPasswordHashSQL = "" +
"SELECT password_hash WHERE localpart = $1"
// TODO: Update password
type accountsStatements struct {
insertAccountStmt *sql.Stmt
selectAccountByPasswordHashStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
insertAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
@ -59,7 +63,10 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
return
}
if s.selectAccountByPasswordHashStmt, err = db.Prepare(selectAccountByPasswordHashSQL); err != nil {
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
return
}
if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil {
return
}
s.serverName = server
@ -80,9 +87,14 @@ func (s *accountsStatements) insertAccount(localpart, hash string) (acc *types.A
return
}
func (s *accountsStatements) selectAccountByPasswordHash(localpart, hash string) (*types.Account, error) {
func (s *accountsStatements) selectPasswordHash(localpart string) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRow(localpart).Scan(&hash)
return
}
func (s *accountsStatements) selectAccountByLocalpart(localpart string) (*types.Account, error) {
var acc types.Account
err := s.selectAccountByPasswordHashStmt.QueryRow(localpart, hash).Scan(&acc.Localpart)
err := s.selectAccountByLocalpartStmt.QueryRow(localpart).Scan(&acc.Localpart)
if err != nil {
acc.UserID = makeUserID(localpart, s.serverName)
}

View file

@ -44,11 +44,14 @@ func NewAccountDatabase(dataSourceName string, serverName gomatrixserverlib.Serv
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given credentials.
func (d *AccountDatabase) GetAccountByPassword(localpart, plaintextPassword string) (*types.Account, error) {
hash, err := hashPassword(plaintextPassword)
hash, err := d.accounts.selectPasswordHash(localpart)
if err != nil {
return nil, err
}
return d.accounts.selectAccountByPasswordHash(localpart, hash)
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
return nil, err
}
return d.accounts.selectAccountByLocalpart(localpart)
}
// CreateAccount makes a new account with the given login name and password. If no password is supplied,