Three PID
This commit is contained in:
parent
e36decf025
commit
62dd0afc0b
|
@ -651,10 +651,12 @@ type QueryLocalpartForThreePIDRequest struct {
|
||||||
|
|
||||||
type QueryLocalpartForThreePIDResponse struct {
|
type QueryLocalpartForThreePIDResponse struct {
|
||||||
Localpart string
|
Localpart string
|
||||||
|
ServerName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryThreePIDsForLocalpartRequest struct {
|
type QueryThreePIDsForLocalpartRequest struct {
|
||||||
Localpart string
|
Localpart string
|
||||||
|
ServerName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryThreePIDsForLocalpartResponse struct {
|
type QueryThreePIDsForLocalpartResponse struct {
|
||||||
|
@ -664,5 +666,8 @@ type QueryThreePIDsForLocalpartResponse struct {
|
||||||
type PerformForgetThreePIDRequest QueryLocalpartForThreePIDRequest
|
type PerformForgetThreePIDRequest QueryLocalpartForThreePIDRequest
|
||||||
|
|
||||||
type PerformSaveThreePIDAssociationRequest struct {
|
type PerformSaveThreePIDAssociationRequest struct {
|
||||||
ThreePID, Localpart, Medium string
|
ThreePID string
|
||||||
|
Localpart string
|
||||||
|
ServerName gomatrixserverlib.ServerName
|
||||||
|
Medium string
|
||||||
}
|
}
|
||||||
|
|
|
@ -928,16 +928,17 @@ func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {
|
func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {
|
||||||
localpart, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium)
|
localpart, domain, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
res.Localpart = localpart
|
res.Localpart = localpart
|
||||||
|
res.ServerName = domain
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error {
|
func (a *UserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error {
|
||||||
r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart)
|
r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart, req.ServerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -950,7 +951,7 @@ func (a *UserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.Pe
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error {
|
func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error {
|
||||||
return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.Medium)
|
return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium)
|
||||||
}
|
}
|
||||||
|
|
||||||
const pushRulesAccountDataType = "m.push_rules"
|
const pushRulesAccountDataType = "m.push_rules"
|
||||||
|
|
|
@ -114,10 +114,10 @@ type Pusher interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type ThreePID interface {
|
type ThreePID interface {
|
||||||
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
|
SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName gomatrixserverlib.ServerName, medium string) (err error)
|
||||||
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
|
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
|
||||||
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
|
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error)
|
||||||
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
|
GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Notification interface {
|
type Notification interface {
|
||||||
|
|
|
@ -20,6 +20,7 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
)
|
)
|
||||||
|
@ -42,13 +43,13 @@ CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart);
|
||||||
`
|
`
|
||||||
|
|
||||||
const selectLocalpartForThreePIDSQL = "" +
|
const selectLocalpartForThreePIDSQL = "" +
|
||||||
"SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
"SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||||
|
|
||||||
const selectThreePIDsForLocalpartSQL = "" +
|
const selectThreePIDsForLocalpartSQL = "" +
|
||||||
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1"
|
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2"
|
||||||
|
|
||||||
const insertThreePIDSQL = "" +
|
const insertThreePIDSQL = "" +
|
||||||
"INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)"
|
"INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)"
|
||||||
|
|
||||||
const deleteThreePIDSQL = "" +
|
const deleteThreePIDSQL = "" +
|
||||||
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||||
|
@ -76,19 +77,20 @@ func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
|
||||||
|
|
||||||
func (s *threepidStatements) SelectLocalpartForThreePID(
|
func (s *threepidStatements) SelectLocalpartForThreePID(
|
||||||
ctx context.Context, txn *sql.Tx, threepid string, medium string,
|
ctx context.Context, txn *sql.Tx, threepid string, medium string,
|
||||||
) (localpart string, err error) {
|
) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
|
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
|
||||||
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
|
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return "", nil
|
return "", "", nil
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (threepids []authtypes.ThreePID, err error) {
|
) (threepids []authtypes.ThreePID, err error) {
|
||||||
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
|
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -110,10 +112,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) InsertThreePID(
|
func (s *threepidStatements) InsertThreePID(
|
||||||
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
|
ctx context.Context, txn *sql.Tx, threepid, medium,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
|
||||||
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
|
_, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -288,10 +288,12 @@ var Err3PIDInUse = errors.New("this third-party identifier is already in use")
|
||||||
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
|
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
func (d *Database) SaveThreePIDAssociation(
|
func (d *Database) SaveThreePIDAssociation(
|
||||||
ctx context.Context, threepid, localpart, medium string,
|
ctx context.Context, threepid string,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
medium string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
user, err := d.ThreePIDs.SelectLocalpartForThreePID(
|
user, domain, err := d.ThreePIDs.SelectLocalpartForThreePID(
|
||||||
ctx, txn, threepid, medium,
|
ctx, txn, threepid, medium,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -302,7 +304,7 @@ func (d *Database) SaveThreePIDAssociation(
|
||||||
return Err3PIDInUse
|
return Err3PIDInUse
|
||||||
}
|
}
|
||||||
|
|
||||||
return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart)
|
return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart, domain)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -325,7 +327,7 @@ func (d *Database) RemoveThreePIDAssociation(
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
func (d *Database) GetLocalpartForThreePID(
|
func (d *Database) GetLocalpartForThreePID(
|
||||||
ctx context.Context, threepid string, medium string,
|
ctx context.Context, threepid string, medium string,
|
||||||
) (localpart string, err error) {
|
) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
|
||||||
return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium)
|
return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -334,9 +336,10 @@ func (d *Database) GetLocalpartForThreePID(
|
||||||
// If no association is known for this user, returns an empty slice.
|
// If no association is known for this user, returns an empty slice.
|
||||||
// Returns an error if there was an issue talking to the database.
|
// Returns an error if there was an issue talking to the database.
|
||||||
func (d *Database) GetThreePIDsForLocalpart(
|
func (d *Database) GetThreePIDsForLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (threepids []authtypes.ThreePID, err error) {
|
) (threepids []authtypes.ThreePID, err error) {
|
||||||
return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart)
|
return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckAccountAvailability checks if the username/localpart is already present
|
// CheckAccountAvailability checks if the username/localpart is already present
|
||||||
|
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
)
|
)
|
||||||
|
@ -80,19 +81,20 @@ func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
|
||||||
|
|
||||||
func (s *threepidStatements) SelectLocalpartForThreePID(
|
func (s *threepidStatements) SelectLocalpartForThreePID(
|
||||||
ctx context.Context, txn *sql.Tx, threepid string, medium string,
|
ctx context.Context, txn *sql.Tx, threepid string, medium string,
|
||||||
) (localpart string, err error) {
|
) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
|
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
|
||||||
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
|
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return "", nil
|
return "", "", nil
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (threepids []authtypes.ThreePID, err error) {
|
) (threepids []authtypes.ThreePID, err error) {
|
||||||
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
|
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -114,10 +116,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) InsertThreePID(
|
func (s *threepidStatements) InsertThreePID(
|
||||||
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
|
ctx context.Context, txn *sql.Tx, threepid, medium,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
|
||||||
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
|
_, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -461,7 +461,7 @@ func Test_Pusher(t *testing.T) {
|
||||||
|
|
||||||
func Test_ThreePID(t *testing.T) {
|
func Test_ThreePID(t *testing.T) {
|
||||||
alice := test.NewUser(t)
|
alice := test.NewUser(t)
|
||||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
@ -469,15 +469,16 @@ func Test_ThreePID(t *testing.T) {
|
||||||
defer close()
|
defer close()
|
||||||
threePID := util.RandomString(8)
|
threePID := util.RandomString(8)
|
||||||
medium := util.RandomString(8)
|
medium := util.RandomString(8)
|
||||||
err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, medium)
|
err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, aliceDomain, medium)
|
||||||
assert.NoError(t, err, "unable to save threepid association")
|
assert.NoError(t, err, "unable to save threepid association")
|
||||||
|
|
||||||
// get the stored threepid
|
// get the stored threepid
|
||||||
gotLocalpart, err := db.GetLocalpartForThreePID(ctx, threePID, medium)
|
gotLocalpart, gotDomain, err := db.GetLocalpartForThreePID(ctx, threePID, medium)
|
||||||
assert.NoError(t, err, "unable to get localpart for threepid")
|
assert.NoError(t, err, "unable to get localpart for threepid")
|
||||||
assert.Equal(t, aliceLocalpart, gotLocalpart)
|
assert.Equal(t, aliceLocalpart, gotLocalpart)
|
||||||
|
assert.Equal(t, aliceDomain, gotDomain)
|
||||||
|
|
||||||
threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart)
|
threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain)
|
||||||
assert.NoError(t, err, "unable to get threepids for localpart")
|
assert.NoError(t, err, "unable to get threepids for localpart")
|
||||||
assert.Equal(t, 1, len(threepids))
|
assert.Equal(t, 1, len(threepids))
|
||||||
assert.Equal(t, authtypes.ThreePID{
|
assert.Equal(t, authtypes.ThreePID{
|
||||||
|
@ -490,7 +491,7 @@ func Test_ThreePID(t *testing.T) {
|
||||||
assert.NoError(t, err, "unexpected error")
|
assert.NoError(t, err, "unexpected error")
|
||||||
|
|
||||||
// verify it was deleted
|
// verify it was deleted
|
||||||
threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart)
|
threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain)
|
||||||
assert.NoError(t, err, "unable to get threepids for localpart")
|
assert.NoError(t, err, "unable to get threepids for localpart")
|
||||||
assert.Equal(t, 0, len(threepids))
|
assert.Equal(t, 0, len(threepids))
|
||||||
})
|
})
|
||||||
|
|
|
@ -92,9 +92,9 @@ type ProfileTable interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type ThreePIDTable interface {
|
type ThreePIDTable interface {
|
||||||
SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, err error)
|
SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error)
|
||||||
SelectThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
|
SelectThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error)
|
||||||
InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error)
|
InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, serverName gomatrixserverlib.ServerName) (err error)
|
||||||
DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error)
|
DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue