mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-12 01:13:10 -06:00
Include medium in database requests and function calls
This commit is contained in:
parent
163b213b63
commit
e0cad6dc20
|
|
@ -248,9 +248,9 @@ var Err3PIDInUse = errors.New("This third-party identifier is already in use")
|
|||
// and a local Matrix user (identified by the user's ID's local part).
|
||||
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
func (d *Database) SaveThreePIDAssociation(threepid string, localpart string) (err error) {
|
||||
func (d *Database) SaveThreePIDAssociation(threepid string, localpart string, medium string) (err error) {
|
||||
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
user, err := d.threepids.selectLocalpartForThreePID(txn, threepid)
|
||||
user, err := d.threepids.selectLocalpartForThreePID(txn, threepid, medium)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -259,7 +259,7 @@ func (d *Database) SaveThreePIDAssociation(threepid string, localpart string) (e
|
|||
return Err3PIDInUse
|
||||
}
|
||||
|
||||
return d.threepids.insertThreePID(txn, threepid, localpart)
|
||||
return d.threepids.insertThreePID(txn, threepid, medium, localpart)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -267,8 +267,8 @@ func (d *Database) SaveThreePIDAssociation(threepid string, localpart string) (e
|
|||
// identifier.
|
||||
// If no association exists involving this third-party identifier, returns nothing.
|
||||
// If there was a problem talking to the database, returns an error.
|
||||
func (d *Database) RemoveThreePIDAssociation(threepid string) (err error) {
|
||||
return d.threepids.deleteThreePID(threepid)
|
||||
func (d *Database) RemoveThreePIDAssociation(threepid string, medium string) (err error) {
|
||||
return d.threepids.deleteThreePID(threepid, medium)
|
||||
}
|
||||
|
||||
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
|
||||
|
|
@ -276,8 +276,8 @@ func (d *Database) RemoveThreePIDAssociation(threepid string) (err error) {
|
|||
// If no association involves the given third-party idenfitier, returns an empty
|
||||
// string.
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
func (d *Database) GetLocalpartForThreePID(threepid string) (localpart string, err error) {
|
||||
return d.threepids.selectLocalpartForThreePID(nil, threepid)
|
||||
func (d *Database) GetLocalpartForThreePID(threepid string, medium string) (localpart string, err error) {
|
||||
return d.threepids.selectLocalpartForThreePID(nil, threepid, medium)
|
||||
}
|
||||
|
||||
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
|
||||
|
|
|
|||
|
|
@ -35,16 +35,16 @@ CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localp
|
|||
`
|
||||
|
||||
const selectLocalpartForThreePIDSQL = "" +
|
||||
"SELECT localpart FROM account_threepid WHERE threepid = $1"
|
||||
"SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
||||
|
||||
const selectThreePIDsForLocalpartSQL = "" +
|
||||
"SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
|
||||
|
||||
const insertThreePIDSQL = "" +
|
||||
"INSERT INTO account_threepid (threepid, localpart) VALUES ($1, $2)"
|
||||
"INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)"
|
||||
|
||||
const deleteThreePIDSQL = "" +
|
||||
"DELETE FROM account_threepid WHERE threepid = $1"
|
||||
"DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
||||
|
||||
type threepidStatements struct {
|
||||
selectLocalpartForThreePIDStmt *sql.Stmt
|
||||
|
|
@ -74,14 +74,14 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *threepidStatements) selectLocalpartForThreePID(txn *sql.Tx, threepid string) (localpart string, err error) {
|
||||
func (s *threepidStatements) selectLocalpartForThreePID(txn *sql.Tx, threepid string, medium string) (localpart string, err error) {
|
||||
var stmt *sql.Stmt
|
||||
if txn != nil {
|
||||
stmt = txn.Stmt(s.selectLocalpartForThreePIDStmt)
|
||||
} else {
|
||||
stmt = s.selectLocalpartForThreePIDStmt
|
||||
}
|
||||
err = stmt.QueryRow(threepid).Scan(&localpart)
|
||||
err = stmt.QueryRow(threepid, medium).Scan(&localpart)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", nil
|
||||
}
|
||||
|
|
@ -107,12 +107,12 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(localpart string) (thre
|
|||
return
|
||||
}
|
||||
|
||||
func (s *threepidStatements) insertThreePID(txn *sql.Tx, threepid string, localpart string) (err error) {
|
||||
_, err = txn.Stmt(s.insertThreePIDStmt).Exec(threepid, localpart)
|
||||
func (s *threepidStatements) insertThreePID(txn *sql.Tx, threepid string, medium string, localpart string) (err error) {
|
||||
_, err = txn.Stmt(s.insertThreePIDStmt).Exec(threepid, medium, localpart)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *threepidStatements) deleteThreePID(threepid string) (err error) {
|
||||
_, err = s.deleteThreePIDStmt.Exec(threepid)
|
||||
func (s *threepidStatements) deleteThreePID(threepid string, medium string) (err error) {
|
||||
_, err = s.deleteThreePIDStmt.Exec(threepid, medium)
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,10 +40,10 @@ type threePID struct {
|
|||
Address string `json:"address"`
|
||||
}
|
||||
|
||||
// Request3PIDToken implements:
|
||||
// RequestEmailToken implements:
|
||||
// POST /account/3pid/email/requestToken
|
||||
// POST /register/email/requestToken
|
||||
func Request3PIDToken(req *http.Request, accountDB *accounts.Database) util.JSONResponse {
|
||||
func RequestEmailToken(req *http.Request, accountDB *accounts.Database) util.JSONResponse {
|
||||
var body threepid.EmailAssociationRequest
|
||||
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
|
||||
return *reqErr
|
||||
|
|
@ -53,7 +53,7 @@ func Request3PIDToken(req *http.Request, accountDB *accounts.Database) util.JSON
|
|||
var err error
|
||||
|
||||
// Check if the 3PID is already in use locally
|
||||
localpart, err := accountDB.GetLocalpartForThreePID(body.Email)
|
||||
localpart, err := accountDB.GetLocalpartForThreePID(body.Email, "email")
|
||||
if err != nil {
|
||||
return httputil.LogThenError(req, err)
|
||||
}
|
||||
|
|
@ -88,7 +88,7 @@ func CheckAndSave3PIDAssociation(
|
|||
return *reqErr
|
||||
}
|
||||
|
||||
verified, address, err := threepid.CheckAssociation(body.Creds)
|
||||
verified, address, medium, err := threepid.CheckAssociation(body.Creds)
|
||||
if err != nil {
|
||||
return httputil.LogThenError(req, err)
|
||||
}
|
||||
|
|
@ -108,7 +108,7 @@ func CheckAndSave3PIDAssociation(
|
|||
return httputil.LogThenError(req, err)
|
||||
}
|
||||
|
||||
if err = accountDB.SaveThreePIDAssociation(address, localpart); err != nil {
|
||||
if err = accountDB.SaveThreePIDAssociation(address, localpart, medium); err != nil {
|
||||
return httputil.LogThenError(req, err)
|
||||
}
|
||||
|
||||
|
|
@ -152,7 +152,7 @@ func Forget3PID(req *http.Request, accountDB *accounts.Database) util.JSONRespon
|
|||
return *reqErr
|
||||
}
|
||||
|
||||
if err := accountDB.RemoveThreePIDAssociation(body.Address); err != nil {
|
||||
if err := accountDB.RemoveThreePIDAssociation(body.Address, body.Medium); err != nil {
|
||||
return httputil.LogThenError(req, err)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -253,7 +253,7 @@ func Setup(
|
|||
|
||||
r0mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken",
|
||||
common.MakeAPI("account_3pid_request_token", func(req *http.Request) util.JSONResponse {
|
||||
return readers.Request3PIDToken(req, accountDB)
|
||||
return readers.RequestEmailToken(req, accountDB)
|
||||
}),
|
||||
).Methods("POST", "OPTIONS")
|
||||
|
||||
|
|
|
|||
|
|
@ -88,14 +88,14 @@ func CreateSession(req EmailAssociationRequest) (string, error) {
|
|||
// identity server.
|
||||
// Returns a boolean set to true if the association has been validated, false if not.
|
||||
// If the association has been validated, also returns the related third-party
|
||||
// identifier.
|
||||
// identifier and its medium.
|
||||
// Returns an error if there was a problem sending the request or decoding the
|
||||
// response, or if the identity server responded with a non-OK status.
|
||||
func CheckAssociation(creds Credentials) (bool, string, error) {
|
||||
func CheckAssociation(creds Credentials) (bool, string, string, error) {
|
||||
url := fmt.Sprintf("https://%s/_matrix/identity/api/v1/3pid/getValidated3pid?sid=%s&client_secret=%s", creds.IDServer, creds.SID, creds.Secret)
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
return false, "", "", err
|
||||
}
|
||||
|
||||
var respBody struct {
|
||||
|
|
@ -107,14 +107,14 @@ func CheckAssociation(creds Credentials) (bool, string, error) {
|
|||
}
|
||||
|
||||
if err = json.NewDecoder(resp.Body).Decode(&respBody); err != nil {
|
||||
return false, "", err
|
||||
return false, "", "", err
|
||||
}
|
||||
|
||||
if respBody.ErrCode == "M_SESSION_NOT_VALIDATED" {
|
||||
return false, "", nil
|
||||
return false, "", "", nil
|
||||
} else if len(respBody.ErrCode) > 0 {
|
||||
return false, "", errors.New(respBody.Error)
|
||||
return false, "", "", errors.New(respBody.Error)
|
||||
}
|
||||
|
||||
return true, respBody.Address, nil
|
||||
return true, respBody.Address, respBody.Medium, nil
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue