diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts_table.go index 1bed2b069..aa98d8dd0 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts_table.go @@ -40,15 +40,19 @@ CREATE TABLE IF NOT EXISTS accounts ( const insertAccountSQL = "" + "INSERT INTO accounts(localpart, created_ts, password_hash) VALUES ($1, $2, $3)" -const selectAccountByPasswordHashSQL = "" + - "SELECT localpart WHERE localpart = $1 AND password_hash = $2" +const selectAccountByLocalpartSQL = "" + + "SELECT localpart WHERE localpart = $1" + +const selectPasswordHashSQL = "" + + "SELECT password_hash WHERE localpart = $1" // TODO: Update password type accountsStatements struct { - insertAccountStmt *sql.Stmt - selectAccountByPasswordHashStmt *sql.Stmt - serverName gomatrixserverlib.ServerName + insertAccountStmt *sql.Stmt + selectAccountByLocalpartStmt *sql.Stmt + selectPasswordHashStmt *sql.Stmt + serverName gomatrixserverlib.ServerName } func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { @@ -59,7 +63,10 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { return } - if s.selectAccountByPasswordHashStmt, err = db.Prepare(selectAccountByPasswordHashSQL); err != nil { + if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { + return + } + if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil { return } s.serverName = server @@ -80,9 +87,14 @@ func (s *accountsStatements) insertAccount(localpart, hash string) (acc *types.A return } -func (s *accountsStatements) selectAccountByPasswordHash(localpart, hash string) (*types.Account, error) { +func (s *accountsStatements) selectPasswordHash(localpart string) (hash string, err error) { + err = s.selectPasswordHashStmt.QueryRow(localpart).Scan(&hash) + return +} + +func (s *accountsStatements) selectAccountByLocalpart(localpart string) (*types.Account, error) { var acc types.Account - err := s.selectAccountByPasswordHashStmt.QueryRow(localpart, hash).Scan(&acc.Localpart) + err := s.selectAccountByLocalpartStmt.QueryRow(localpart).Scan(&acc.Localpart) if err != nil { acc.UserID = makeUserID(localpart, s.serverName) } diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/storage.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/storage.go index a236da1d1..22001ca8a 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/storage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/storage.go @@ -44,11 +44,14 @@ func NewAccountDatabase(dataSourceName string, serverName gomatrixserverlib.Serv // GetAccountByPassword returns the account associated with the given localpart and password. // Returns sql.ErrNoRows if no account exists which matches the given credentials. func (d *AccountDatabase) GetAccountByPassword(localpart, plaintextPassword string) (*types.Account, error) { - hash, err := hashPassword(plaintextPassword) + hash, err := d.accounts.selectPasswordHash(localpart) if err != nil { return nil, err } - return d.accounts.selectAccountByPasswordHash(localpart, hash) + if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { + return nil, err + } + return d.accounts.selectAccountByLocalpart(localpart) } // CreateAccount makes a new account with the given login name and password. If no password is supplied,