From 62dd0afc0b4b998d3d62c465fabb76e2928ded3e Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 7 Nov 2022 15:29:51 +0000 Subject: [PATCH] Three PID --- userapi/api/api.go | 11 ++++++++--- userapi/internal/api.go | 7 ++++--- userapi/storage/interface.go | 6 +++--- userapi/storage/postgres/threepid_table.go | 23 ++++++++++++---------- userapi/storage/shared/storage.go | 15 ++++++++------ userapi/storage/sqlite3/threepid_table.go | 17 +++++++++------- userapi/storage/storage_test.go | 11 ++++++----- userapi/storage/tables/interface.go | 6 +++--- 8 files changed, 56 insertions(+), 40 deletions(-) diff --git a/userapi/api/api.go b/userapi/api/api.go index 8b82abd09..d3f5aefc8 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -650,11 +650,13 @@ type QueryLocalpartForThreePIDRequest struct { } type QueryLocalpartForThreePIDResponse struct { - Localpart string + Localpart string + ServerName gomatrixserverlib.ServerName } type QueryThreePIDsForLocalpartRequest struct { - Localpart string + Localpart string + ServerName gomatrixserverlib.ServerName } type QueryThreePIDsForLocalpartResponse struct { @@ -664,5 +666,8 @@ type QueryThreePIDsForLocalpartResponse struct { type PerformForgetThreePIDRequest QueryLocalpartForThreePIDRequest type PerformSaveThreePIDAssociationRequest struct { - ThreePID, Localpart, Medium string + ThreePID string + Localpart string + ServerName gomatrixserverlib.ServerName + Medium string } diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 20356751d..2618be094 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -928,16 +928,17 @@ func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUp } func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error { - localpart, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium) + localpart, domain, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium) if err != nil { return err } res.Localpart = localpart + res.ServerName = domain return nil } func (a *UserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error { - r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart) + r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart, req.ServerName) if err != nil { return err } @@ -950,7 +951,7 @@ func (a *UserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.Pe } func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error { - return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.Medium) + return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium) } const pushRulesAccountDataType = "m.push_rules" diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index eb4a80272..c22b7658f 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -114,10 +114,10 @@ type Pusher interface { } type ThreePID interface { - SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error) + SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName gomatrixserverlib.ServerName, medium string) (err error) RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) - GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error) - GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) + GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error) + GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error) } type Notification interface { diff --git a/userapi/storage/postgres/threepid_table.go b/userapi/storage/postgres/threepid_table.go index 593e53148..8578e018e 100644 --- a/userapi/storage/postgres/threepid_table.go +++ b/userapi/storage/postgres/threepid_table.go @@ -20,6 +20,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -42,13 +43,13 @@ CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart); ` const selectLocalpartForThreePIDSQL = "" + - "SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2" + "SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2" const selectThreePIDsForLocalpartSQL = "" + - "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1" + "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2" const insertThreePIDSQL = "" + - "INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)" + "INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)" const deleteThreePIDSQL = "" + "DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2" @@ -76,19 +77,20 @@ func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) { func (s *threepidStatements) SelectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, -) (localpart string, err error) { +) (localpart string, serverName gomatrixserverlib.ServerName, err error) { stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) - err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) + err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName) if err == sql.ErrNoRows { - return "", nil + return "", "", nil } return } func (s *threepidStatements) SelectThreePIDsForLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (threepids []authtypes.ThreePID, err error) { - rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) + rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName) if err != nil { return } @@ -110,10 +112,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart( } func (s *threepidStatements) InsertThreePID( - ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, + ctx context.Context, txn *sql.Tx, threepid, medium, + localpart string, serverName gomatrixserverlib.ServerName, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) - _, err = stmt.ExecContext(ctx, threepid, medium, localpart) + _, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName) return } diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index bb3479b0e..2a0edd81f 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -288,10 +288,12 @@ var Err3PIDInUse = errors.New("this third-party identifier is already in use") // If the third-party identifier is already part of an association, returns Err3PIDInUse. // Returns an error if there was a problem talking to the database. func (d *Database) SaveThreePIDAssociation( - ctx context.Context, threepid, localpart, medium string, + ctx context.Context, threepid string, + localpart string, serverName gomatrixserverlib.ServerName, + medium string, ) (err error) { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - user, err := d.ThreePIDs.SelectLocalpartForThreePID( + user, domain, err := d.ThreePIDs.SelectLocalpartForThreePID( ctx, txn, threepid, medium, ) if err != nil { @@ -302,7 +304,7 @@ func (d *Database) SaveThreePIDAssociation( return Err3PIDInUse } - return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart) + return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart, domain) }) } @@ -325,7 +327,7 @@ func (d *Database) RemoveThreePIDAssociation( // Returns an error if there was a problem talking to the database. func (d *Database) GetLocalpartForThreePID( ctx context.Context, threepid string, medium string, -) (localpart string, err error) { +) (localpart string, serverName gomatrixserverlib.ServerName, err error) { return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium) } @@ -334,9 +336,10 @@ func (d *Database) GetLocalpartForThreePID( // If no association is known for this user, returns an empty slice. // Returns an error if there was an issue talking to the database. func (d *Database) GetThreePIDsForLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (threepids []authtypes.ThreePID, err error) { - return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart) + return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName) } // CheckAccountAvailability checks if the username/localpart is already present diff --git a/userapi/storage/sqlite3/threepid_table.go b/userapi/storage/sqlite3/threepid_table.go index 683d8ecad..29b0eb832 100644 --- a/userapi/storage/sqlite3/threepid_table.go +++ b/userapi/storage/sqlite3/threepid_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -80,19 +81,20 @@ func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) { func (s *threepidStatements) SelectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, -) (localpart string, err error) { +) (localpart string, serverName gomatrixserverlib.ServerName, err error) { stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) - err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) + err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName) if err == sql.ErrNoRows { - return "", nil + return "", "", nil } return } func (s *threepidStatements) SelectThreePIDsForLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (threepids []authtypes.ThreePID, err error) { - rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) + rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName) if err != nil { return } @@ -114,10 +116,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart( } func (s *threepidStatements) InsertThreePID( - ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, + ctx context.Context, txn *sql.Tx, threepid, medium, + localpart string, serverName gomatrixserverlib.ServerName, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) - _, err = stmt.ExecContext(ctx, threepid, medium, localpart) + _, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName) return err } diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 00781309e..79af7b853 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -461,7 +461,7 @@ func Test_Pusher(t *testing.T) { func Test_ThreePID(t *testing.T) { alice := test.NewUser(t) - aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { @@ -469,15 +469,16 @@ func Test_ThreePID(t *testing.T) { defer close() threePID := util.RandomString(8) medium := util.RandomString(8) - err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, medium) + err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, aliceDomain, medium) assert.NoError(t, err, "unable to save threepid association") // get the stored threepid - gotLocalpart, err := db.GetLocalpartForThreePID(ctx, threePID, medium) + gotLocalpart, gotDomain, err := db.GetLocalpartForThreePID(ctx, threePID, medium) assert.NoError(t, err, "unable to get localpart for threepid") assert.Equal(t, aliceLocalpart, gotLocalpart) + assert.Equal(t, aliceDomain, gotDomain) - threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart) + threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get threepids for localpart") assert.Equal(t, 1, len(threepids)) assert.Equal(t, authtypes.ThreePID{ @@ -490,7 +491,7 @@ func Test_ThreePID(t *testing.T) { assert.NoError(t, err, "unexpected error") // verify it was deleted - threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart) + threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get threepids for localpart") assert.Equal(t, 0, len(threepids)) }) diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 1cea5ff5e..4ae2961f4 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -92,9 +92,9 @@ type ProfileTable interface { } type ThreePIDTable interface { - SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, err error) - SelectThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) - InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error) + SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error) + SelectThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error) + InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, serverName gomatrixserverlib.ServerName) (err error) DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) }