diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go index 3617647e9..44d643288 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go @@ -248,9 +248,9 @@ var Err3PIDInUse = errors.New("This third-party identifier is already in use") // and a local Matrix user (identified by the user's ID's local part). // If the third-party identifier is already part of an association, returns Err3PIDInUse. // Returns an error if there was a problem talking to the database. -func (d *Database) SaveThreePIDAssociation(threepid string, localpart string) (err error) { +func (d *Database) SaveThreePIDAssociation(threepid string, localpart string, medium string) (err error) { return common.WithTransaction(d.db, func(txn *sql.Tx) error { - user, err := d.threepids.selectLocalpartForThreePID(txn, threepid) + user, err := d.threepids.selectLocalpartForThreePID(txn, threepid, medium) if err != nil { return err } @@ -259,7 +259,7 @@ func (d *Database) SaveThreePIDAssociation(threepid string, localpart string) (e return Err3PIDInUse } - return d.threepids.insertThreePID(txn, threepid, localpart) + return d.threepids.insertThreePID(txn, threepid, medium, localpart) }) } @@ -267,8 +267,8 @@ func (d *Database) SaveThreePIDAssociation(threepid string, localpart string) (e // identifier. // If no association exists involving this third-party identifier, returns nothing. // If there was a problem talking to the database, returns an error. -func (d *Database) RemoveThreePIDAssociation(threepid string) (err error) { - return d.threepids.deleteThreePID(threepid) +func (d *Database) RemoveThreePIDAssociation(threepid string, medium string) (err error) { + return d.threepids.deleteThreePID(threepid, medium) } // GetLocalpartForThreePID looks up the localpart associated with a given third-party @@ -276,8 +276,8 @@ func (d *Database) RemoveThreePIDAssociation(threepid string) (err error) { // If no association involves the given third-party idenfitier, returns an empty // string. // Returns an error if there was a problem talking to the database. -func (d *Database) GetLocalpartForThreePID(threepid string) (localpart string, err error) { - return d.threepids.selectLocalpartForThreePID(nil, threepid) +func (d *Database) GetLocalpartForThreePID(threepid string, medium string) (localpart string, err error) { + return d.threepids.selectLocalpartForThreePID(nil, threepid, medium) } // GetThreePIDsForLocalpart looks up the third-party identifiers associated with diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/threepid_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/threepid_table.go index ba40aaebd..648102337 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/threepid_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/threepid_table.go @@ -35,16 +35,16 @@ CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localp ` const selectLocalpartForThreePIDSQL = "" + - "SELECT localpart FROM account_threepid WHERE threepid = $1" + "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2" const selectThreePIDsForLocalpartSQL = "" + "SELECT threepid, medium FROM account_threepid WHERE localpart = $1" const insertThreePIDSQL = "" + - "INSERT INTO account_threepid (threepid, localpart) VALUES ($1, $2)" + "INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)" const deleteThreePIDSQL = "" + - "DELETE FROM account_threepid WHERE threepid = $1" + "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2" type threepidStatements struct { selectLocalpartForThreePIDStmt *sql.Stmt @@ -74,14 +74,14 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) { return } -func (s *threepidStatements) selectLocalpartForThreePID(txn *sql.Tx, threepid string) (localpart string, err error) { +func (s *threepidStatements) selectLocalpartForThreePID(txn *sql.Tx, threepid string, medium string) (localpart string, err error) { var stmt *sql.Stmt if txn != nil { stmt = txn.Stmt(s.selectLocalpartForThreePIDStmt) } else { stmt = s.selectLocalpartForThreePIDStmt } - err = stmt.QueryRow(threepid).Scan(&localpart) + err = stmt.QueryRow(threepid, medium).Scan(&localpart) if err == sql.ErrNoRows { return "", nil } @@ -107,12 +107,12 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(localpart string) (thre return } -func (s *threepidStatements) insertThreePID(txn *sql.Tx, threepid string, localpart string) (err error) { - _, err = txn.Stmt(s.insertThreePIDStmt).Exec(threepid, localpart) +func (s *threepidStatements) insertThreePID(txn *sql.Tx, threepid string, medium string, localpart string) (err error) { + _, err = txn.Stmt(s.insertThreePIDStmt).Exec(threepid, medium, localpart) return } -func (s *threepidStatements) deleteThreePID(threepid string) (err error) { - _, err = s.deleteThreePIDStmt.Exec(threepid) +func (s *threepidStatements) deleteThreePID(threepid string, medium string) (err error) { + _, err = s.deleteThreePIDStmt.Exec(threepid, medium) return } diff --git a/src/github.com/matrix-org/dendrite/clientapi/readers/threepid.go b/src/github.com/matrix-org/dendrite/clientapi/readers/threepid.go index 924917676..bb78baf39 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/readers/threepid.go +++ b/src/github.com/matrix-org/dendrite/clientapi/readers/threepid.go @@ -40,10 +40,10 @@ type threePID struct { Address string `json:"address"` } -// Request3PIDToken implements: +// RequestEmailToken implements: // POST /account/3pid/email/requestToken // POST /register/email/requestToken -func Request3PIDToken(req *http.Request, accountDB *accounts.Database) util.JSONResponse { +func RequestEmailToken(req *http.Request, accountDB *accounts.Database) util.JSONResponse { var body threepid.EmailAssociationRequest if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { return *reqErr @@ -53,7 +53,7 @@ func Request3PIDToken(req *http.Request, accountDB *accounts.Database) util.JSON var err error // Check if the 3PID is already in use locally - localpart, err := accountDB.GetLocalpartForThreePID(body.Email) + localpart, err := accountDB.GetLocalpartForThreePID(body.Email, "email") if err != nil { return httputil.LogThenError(req, err) } @@ -88,7 +88,7 @@ func CheckAndSave3PIDAssociation( return *reqErr } - verified, address, err := threepid.CheckAssociation(body.Creds) + verified, address, medium, err := threepid.CheckAssociation(body.Creds) if err != nil { return httputil.LogThenError(req, err) } @@ -108,7 +108,7 @@ func CheckAndSave3PIDAssociation( return httputil.LogThenError(req, err) } - if err = accountDB.SaveThreePIDAssociation(address, localpart); err != nil { + if err = accountDB.SaveThreePIDAssociation(address, localpart, medium); err != nil { return httputil.LogThenError(req, err) } @@ -152,7 +152,7 @@ func Forget3PID(req *http.Request, accountDB *accounts.Database) util.JSONRespon return *reqErr } - if err := accountDB.RemoveThreePIDAssociation(body.Address); err != nil { + if err := accountDB.RemoveThreePIDAssociation(body.Address, body.Medium); err != nil { return httputil.LogThenError(req, err) } diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go b/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go index 744c954ca..950202c58 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go @@ -253,7 +253,7 @@ func Setup( r0mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken", common.MakeAPI("account_3pid_request_token", func(req *http.Request) util.JSONResponse { - return readers.Request3PIDToken(req, accountDB) + return readers.RequestEmailToken(req, accountDB) }), ).Methods("POST", "OPTIONS") diff --git a/src/github.com/matrix-org/dendrite/clientapi/threepid/threepid.go b/src/github.com/matrix-org/dendrite/clientapi/threepid/threepid.go index 682588c87..92cef2b2d 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/threepid/threepid.go +++ b/src/github.com/matrix-org/dendrite/clientapi/threepid/threepid.go @@ -88,14 +88,14 @@ func CreateSession(req EmailAssociationRequest) (string, error) { // identity server. // Returns a boolean set to true if the association has been validated, false if not. // If the association has been validated, also returns the related third-party -// identifier. +// identifier and its medium. // Returns an error if there was a problem sending the request or decoding the // response, or if the identity server responded with a non-OK status. -func CheckAssociation(creds Credentials) (bool, string, error) { +func CheckAssociation(creds Credentials) (bool, string, string, error) { url := fmt.Sprintf("https://%s/_matrix/identity/api/v1/3pid/getValidated3pid?sid=%s&client_secret=%s", creds.IDServer, creds.SID, creds.Secret) resp, err := http.Get(url) if err != nil { - return false, "", err + return false, "", "", err } var respBody struct { @@ -107,14 +107,14 @@ func CheckAssociation(creds Credentials) (bool, string, error) { } if err = json.NewDecoder(resp.Body).Decode(&respBody); err != nil { - return false, "", err + return false, "", "", err } if respBody.ErrCode == "M_SESSION_NOT_VALIDATED" { - return false, "", nil + return false, "", "", nil } else if len(respBody.ErrCode) > 0 { - return false, "", errors.New(respBody.Error) + return false, "", "", errors.New(respBody.Error) } - return true, respBody.Address, nil + return true, respBody.Address, respBody.Medium, nil }