Three PID

This commit is contained in:
Neil Alexander 2022-11-07 15:29:51 +00:00
parent e36decf025
commit 62dd0afc0b
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
8 changed files with 56 additions and 40 deletions

View file

@ -650,11 +650,13 @@ type QueryLocalpartForThreePIDRequest struct {
}
type QueryLocalpartForThreePIDResponse struct {
Localpart string
Localpart string
ServerName gomatrixserverlib.ServerName
}
type QueryThreePIDsForLocalpartRequest struct {
Localpart string
Localpart string
ServerName gomatrixserverlib.ServerName
}
type QueryThreePIDsForLocalpartResponse struct {
@ -664,5 +666,8 @@ type QueryThreePIDsForLocalpartResponse struct {
type PerformForgetThreePIDRequest QueryLocalpartForThreePIDRequest
type PerformSaveThreePIDAssociationRequest struct {
ThreePID, Localpart, Medium string
ThreePID string
Localpart string
ServerName gomatrixserverlib.ServerName
Medium string
}

View file

@ -928,16 +928,17 @@ func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUp
}
func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {
localpart, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium)
localpart, domain, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium)
if err != nil {
return err
}
res.Localpart = localpart
res.ServerName = domain
return nil
}
func (a *UserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error {
r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart)
r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart, req.ServerName)
if err != nil {
return err
}
@ -950,7 +951,7 @@ func (a *UserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.Pe
}
func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error {
return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.Medium)
return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium)
}
const pushRulesAccountDataType = "m.push_rules"

View file

@ -114,10 +114,10 @@ type Pusher interface {
}
type ThreePID interface {
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName gomatrixserverlib.ServerName, medium string) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error)
GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error)
}
type Notification interface {

View file

@ -20,6 +20,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
@ -42,13 +43,13 @@ CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart);
`
const selectLocalpartForThreePIDSQL = "" +
"SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
"SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
const selectThreePIDsForLocalpartSQL = "" +
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1"
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2"
const insertThreePIDSQL = "" +
"INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)"
"INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)"
const deleteThreePIDSQL = "" +
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
@ -76,19 +77,20 @@ func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
func (s *threepidStatements) SelectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string,
) (localpart string, err error) {
) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName)
if err == sql.ErrNoRows {
return "", nil
return "", "", nil
}
return
}
func (s *threepidStatements) SelectThreePIDsForLocalpart(
ctx context.Context, localpart string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) (threepids []authtypes.ThreePID, err error) {
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName)
if err != nil {
return
}
@ -110,10 +112,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart(
}
func (s *threepidStatements) InsertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
ctx context.Context, txn *sql.Tx, threepid, medium,
localpart string, serverName gomatrixserverlib.ServerName,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
_, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName)
return
}

View file

@ -288,10 +288,12 @@ var Err3PIDInUse = errors.New("this third-party identifier is already in use")
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string,
ctx context.Context, threepid string,
localpart string, serverName gomatrixserverlib.ServerName,
medium string,
) (err error) {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
user, err := d.ThreePIDs.SelectLocalpartForThreePID(
user, domain, err := d.ThreePIDs.SelectLocalpartForThreePID(
ctx, txn, threepid, medium,
)
if err != nil {
@ -302,7 +304,7 @@ func (d *Database) SaveThreePIDAssociation(
return Err3PIDInUse
}
return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart)
return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart, domain)
})
}
@ -325,7 +327,7 @@ func (d *Database) RemoveThreePIDAssociation(
// Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium)
}
@ -334,9 +336,10 @@ func (d *Database) GetLocalpartForThreePID(
// If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context, localpart string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) (threepids []authtypes.ThreePID, err error) {
return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart)
return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName)
}
// CheckAccountAvailability checks if the username/localpart is already present

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
@ -80,19 +81,20 @@ func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
func (s *threepidStatements) SelectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string,
) (localpart string, err error) {
) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName)
if err == sql.ErrNoRows {
return "", nil
return "", "", nil
}
return
}
func (s *threepidStatements) SelectThreePIDsForLocalpart(
ctx context.Context, localpart string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) (threepids []authtypes.ThreePID, err error) {
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName)
if err != nil {
return
}
@ -114,10 +116,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart(
}
func (s *threepidStatements) InsertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
ctx context.Context, txn *sql.Tx, threepid, medium,
localpart string, serverName gomatrixserverlib.ServerName,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
_, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName)
return err
}

View file

@ -461,7 +461,7 @@ func Test_Pusher(t *testing.T) {
func Test_ThreePID(t *testing.T) {
alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@ -469,15 +469,16 @@ func Test_ThreePID(t *testing.T) {
defer close()
threePID := util.RandomString(8)
medium := util.RandomString(8)
err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, medium)
err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, aliceDomain, medium)
assert.NoError(t, err, "unable to save threepid association")
// get the stored threepid
gotLocalpart, err := db.GetLocalpartForThreePID(ctx, threePID, medium)
gotLocalpart, gotDomain, err := db.GetLocalpartForThreePID(ctx, threePID, medium)
assert.NoError(t, err, "unable to get localpart for threepid")
assert.Equal(t, aliceLocalpart, gotLocalpart)
assert.Equal(t, aliceDomain, gotDomain)
threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart)
threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get threepids for localpart")
assert.Equal(t, 1, len(threepids))
assert.Equal(t, authtypes.ThreePID{
@ -490,7 +491,7 @@ func Test_ThreePID(t *testing.T) {
assert.NoError(t, err, "unexpected error")
// verify it was deleted
threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart)
threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get threepids for localpart")
assert.Equal(t, 0, len(threepids))
})

View file

@ -92,9 +92,9 @@ type ProfileTable interface {
}
type ThreePIDTable interface {
SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, err error)
SelectThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error)
SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error)
SelectThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error)
InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, serverName gomatrixserverlib.ServerName) (err error)
DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error)
}