Three PID

This commit is contained in:
Neil Alexander 2022-11-07 15:29:51 +00:00
parent e36decf025
commit 62dd0afc0b
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
8 changed files with 56 additions and 40 deletions

View file

@ -650,11 +650,13 @@ type QueryLocalpartForThreePIDRequest struct {
} }
type QueryLocalpartForThreePIDResponse struct { type QueryLocalpartForThreePIDResponse struct {
Localpart string Localpart string
ServerName gomatrixserverlib.ServerName
} }
type QueryThreePIDsForLocalpartRequest struct { type QueryThreePIDsForLocalpartRequest struct {
Localpart string Localpart string
ServerName gomatrixserverlib.ServerName
} }
type QueryThreePIDsForLocalpartResponse struct { type QueryThreePIDsForLocalpartResponse struct {
@ -664,5 +666,8 @@ type QueryThreePIDsForLocalpartResponse struct {
type PerformForgetThreePIDRequest QueryLocalpartForThreePIDRequest type PerformForgetThreePIDRequest QueryLocalpartForThreePIDRequest
type PerformSaveThreePIDAssociationRequest struct { type PerformSaveThreePIDAssociationRequest struct {
ThreePID, Localpart, Medium string ThreePID string
Localpart string
ServerName gomatrixserverlib.ServerName
Medium string
} }

View file

@ -928,16 +928,17 @@ func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUp
} }
func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error { func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {
localpart, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium) localpart, domain, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium)
if err != nil { if err != nil {
return err return err
} }
res.Localpart = localpart res.Localpart = localpart
res.ServerName = domain
return nil return nil
} }
func (a *UserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error { func (a *UserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error {
r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart) r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart, req.ServerName)
if err != nil { if err != nil {
return err return err
} }
@ -950,7 +951,7 @@ func (a *UserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.Pe
} }
func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error { func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error {
return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.Medium) return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium)
} }
const pushRulesAccountDataType = "m.push_rules" const pushRulesAccountDataType = "m.push_rules"

View file

@ -114,10 +114,10 @@ type Pusher interface {
} }
type ThreePID interface { type ThreePID interface {
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error) SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName gomatrixserverlib.ServerName, medium string) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error) GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error)
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error)
} }
type Notification interface { type Notification interface {

View file

@ -20,6 +20,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
) )
@ -42,13 +43,13 @@ CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart);
` `
const selectLocalpartForThreePIDSQL = "" + const selectLocalpartForThreePIDSQL = "" +
"SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2" "SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
const selectThreePIDsForLocalpartSQL = "" + const selectThreePIDsForLocalpartSQL = "" +
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1" "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2"
const insertThreePIDSQL = "" + const insertThreePIDSQL = "" +
"INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)" "INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)"
const deleteThreePIDSQL = "" + const deleteThreePIDSQL = "" +
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2" "DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
@ -76,19 +77,20 @@ func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
func (s *threepidStatements) SelectLocalpartForThreePID( func (s *threepidStatements) SelectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string, ctx context.Context, txn *sql.Tx, threepid string, medium string,
) (localpart string, err error) { ) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", nil return "", "", nil
} }
return return
} }
func (s *threepidStatements) SelectThreePIDsForLocalpart( func (s *threepidStatements) SelectThreePIDsForLocalpart(
ctx context.Context, localpart string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) (threepids []authtypes.ThreePID, err error) { ) (threepids []authtypes.ThreePID, err error) {
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName)
if err != nil { if err != nil {
return return
} }
@ -110,10 +112,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart(
} }
func (s *threepidStatements) InsertThreePID( func (s *threepidStatements) InsertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ctx context.Context, txn *sql.Tx, threepid, medium,
localpart string, serverName gomatrixserverlib.ServerName,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium, localpart) _, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName)
return return
} }

View file

@ -288,10 +288,12 @@ var Err3PIDInUse = errors.New("this third-party identifier is already in use")
// If the third-party identifier is already part of an association, returns Err3PIDInUse. // If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation( func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string, ctx context.Context, threepid string,
localpart string, serverName gomatrixserverlib.ServerName,
medium string,
) (err error) { ) (err error) {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
user, err := d.ThreePIDs.SelectLocalpartForThreePID( user, domain, err := d.ThreePIDs.SelectLocalpartForThreePID(
ctx, txn, threepid, medium, ctx, txn, threepid, medium,
) )
if err != nil { if err != nil {
@ -302,7 +304,7 @@ func (d *Database) SaveThreePIDAssociation(
return Err3PIDInUse return Err3PIDInUse
} }
return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart) return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart, domain)
}) })
} }
@ -325,7 +327,7 @@ func (d *Database) RemoveThreePIDAssociation(
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID( func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string, ctx context.Context, threepid string, medium string,
) (localpart string, err error) { ) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium) return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium)
} }
@ -334,9 +336,10 @@ func (d *Database) GetLocalpartForThreePID(
// If no association is known for this user, returns an empty slice. // If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database. // Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart( func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context, localpart string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) (threepids []authtypes.ThreePID, err error) { ) (threepids []authtypes.ThreePID, err error) {
return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart) return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName)
} }
// CheckAccountAvailability checks if the username/localpart is already present // CheckAccountAvailability checks if the username/localpart is already present

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
) )
@ -80,19 +81,20 @@ func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
func (s *threepidStatements) SelectLocalpartForThreePID( func (s *threepidStatements) SelectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string, ctx context.Context, txn *sql.Tx, threepid string, medium string,
) (localpart string, err error) { ) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", nil return "", "", nil
} }
return return
} }
func (s *threepidStatements) SelectThreePIDsForLocalpart( func (s *threepidStatements) SelectThreePIDsForLocalpart(
ctx context.Context, localpart string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) (threepids []authtypes.ThreePID, err error) { ) (threepids []authtypes.ThreePID, err error) {
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName)
if err != nil { if err != nil {
return return
} }
@ -114,10 +116,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart(
} }
func (s *threepidStatements) InsertThreePID( func (s *threepidStatements) InsertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ctx context.Context, txn *sql.Tx, threepid, medium,
localpart string, serverName gomatrixserverlib.ServerName,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium, localpart) _, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName)
return err return err
} }

View file

@ -461,7 +461,7 @@ func Test_Pusher(t *testing.T) {
func Test_ThreePID(t *testing.T) { func Test_ThreePID(t *testing.T) {
alice := test.NewUser(t) alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@ -469,15 +469,16 @@ func Test_ThreePID(t *testing.T) {
defer close() defer close()
threePID := util.RandomString(8) threePID := util.RandomString(8)
medium := util.RandomString(8) medium := util.RandomString(8)
err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, medium) err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, aliceDomain, medium)
assert.NoError(t, err, "unable to save threepid association") assert.NoError(t, err, "unable to save threepid association")
// get the stored threepid // get the stored threepid
gotLocalpart, err := db.GetLocalpartForThreePID(ctx, threePID, medium) gotLocalpart, gotDomain, err := db.GetLocalpartForThreePID(ctx, threePID, medium)
assert.NoError(t, err, "unable to get localpart for threepid") assert.NoError(t, err, "unable to get localpart for threepid")
assert.Equal(t, aliceLocalpart, gotLocalpart) assert.Equal(t, aliceLocalpart, gotLocalpart)
assert.Equal(t, aliceDomain, gotDomain)
threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart) threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get threepids for localpart") assert.NoError(t, err, "unable to get threepids for localpart")
assert.Equal(t, 1, len(threepids)) assert.Equal(t, 1, len(threepids))
assert.Equal(t, authtypes.ThreePID{ assert.Equal(t, authtypes.ThreePID{
@ -490,7 +491,7 @@ func Test_ThreePID(t *testing.T) {
assert.NoError(t, err, "unexpected error") assert.NoError(t, err, "unexpected error")
// verify it was deleted // verify it was deleted
threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart) threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get threepids for localpart") assert.NoError(t, err, "unable to get threepids for localpart")
assert.Equal(t, 0, len(threepids)) assert.Equal(t, 0, len(threepids))
}) })

View file

@ -92,9 +92,9 @@ type ProfileTable interface {
} }
type ThreePIDTable interface { type ThreePIDTable interface {
SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, err error) SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error)
SelectThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) SelectThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error)
InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error) InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, serverName gomatrixserverlib.ServerName) (err error)
DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error)
} }