diff --git a/userapi/storage/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go index 71cead0d0..07dba66b1 100644 --- a/userapi/storage/postgres/account_data_table.go +++ b/userapi/storage/postgres/account_data_table.go @@ -38,13 +38,13 @@ CREATE TABLE IF NOT EXISTS userapi_account_datas ( -- The account data content content TEXT NOT NULL, - PRIMARY KEY(localpart, room_id, type) + PRIMARY KEY(localpart, server_name, room_id, type) ); ` const insertAccountDataSQL = ` INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5) - ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content + ON CONFLICT (localpart, server_name, room_id, type) DO UPDATE SET content = EXCLUDED.content ` const selectAccountDataSQL = "" + diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index 9b30b5d8f..1512eb818 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -35,7 +35,7 @@ const accountsSchema = ` -- Stores data about accounts. CREATE TABLE IF NOT EXISTS userapi_accounts ( -- The Matrix user ID localpart for this account - localpart TEXT NOT NULL PRIMARY KEY, + localpart TEXT NOT NULL, server_name TEXT NOT NULL, -- When this account was first created, as a unix timestamp (ms resolution). created_ts BIGINT NOT NULL, @@ -46,9 +46,10 @@ CREATE TABLE IF NOT EXISTS userapi_accounts ( -- If the account is currently active is_deactivated BOOLEAN DEFAULT FALSE, -- The account_type (user = 1, guest = 2, admin = 3, appservice = 4) - account_type SMALLINT NOT NULL + account_type SMALLINT NOT NULL, -- TODO: -- upgraded_ts, devices, any email reset stuff? + PRIMARY KEY (localpart, server_name) ); ` @@ -138,8 +139,8 @@ func (s *accountsStatements) InsertAccount( return &api.Account{ Localpart: localpart, - UserID: userutil.MakeUserID(localpart, s.serverName), - ServerName: s.serverName, + UserID: userutil.MakeUserID(localpart, serverName), + ServerName: serverName, AppServiceID: appserviceID, AccountType: accountType, }, nil @@ -185,9 +186,7 @@ func (s *accountsStatements) SelectAccountByLocalpart( acc.AppServiceID = appserviceIDPtr.String } - acc.UserID = userutil.MakeUserID(localpart, s.serverName) - acc.ServerName = s.serverName - + acc.UserID = userutil.MakeUserID(acc.Localpart, acc.ServerName) return &acc, nil } diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index cd715437b..2dd216189 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -67,7 +67,7 @@ CREATE TABLE IF NOT EXISTS userapi_devices ( ); -- Device IDs must be unique for a given user. -CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, device_id); +CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, server_name, device_id); ` const insertDeviceSQL = "" + diff --git a/userapi/storage/postgres/openid_table.go b/userapi/storage/postgres/openid_table.go index 953956aa5..68d87f007 100644 --- a/userapi/storage/postgres/openid_table.go +++ b/userapi/storage/postgres/openid_table.go @@ -60,7 +60,7 @@ func (s *openIDTokenStatements) InsertOpenIDToken( expiresAtMS int64, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) - _, err = stmt.ExecContext(ctx, token, serverName, localpart, expiresAtMS) + _, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS) return } diff --git a/userapi/storage/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go index 28d24a4d0..0d0ff705c 100644 --- a/userapi/storage/postgres/profile_table.go +++ b/userapi/storage/postgres/profile_table.go @@ -30,12 +30,13 @@ const profilesSchema = ` -- Stores data about accounts profiles. CREATE TABLE IF NOT EXISTS userapi_profiles ( -- The Matrix user ID localpart for this account - localpart TEXT NOT NULL PRIMARY KEY, + localpart TEXT NOT NULL, server_name TEXT NOT NULL, -- The display name for this account display_name TEXT, -- The URL of the avatar for this account - avatar_url TEXT + avatar_url TEXT, + PRIMARY KEY (localpart, server_name) ); ` diff --git a/userapi/storage/postgres/threepid_table.go b/userapi/storage/postgres/threepid_table.go index 8578e018e..f41c43122 100644 --- a/userapi/storage/postgres/threepid_table.go +++ b/userapi/storage/postgres/threepid_table.go @@ -39,7 +39,7 @@ CREATE TABLE IF NOT EXISTS userapi_threepids ( PRIMARY KEY(threepid, medium) ); -CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart); +CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart, server_name); ` const selectLocalpartForThreePIDSQL = "" + diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index c02ad6d12..4eac1ad56 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -293,7 +293,7 @@ func (d *Database) SaveThreePIDAssociation( medium string, ) (err error) { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - user, domain, err := d.ThreePIDs.SelectLocalpartForThreePID( + user, _, err := d.ThreePIDs.SelectLocalpartForThreePID( ctx, txn, threepid, medium, ) if err != nil { @@ -304,7 +304,7 @@ func (d *Database) SaveThreePIDAssociation( return Err3PIDInUse } - return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart, domain) + return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart, serverName) }) } diff --git a/userapi/storage/sqlite3/account_data_table.go b/userapi/storage/sqlite3/account_data_table.go index b8f125bdf..01ff6a8d3 100644 --- a/userapi/storage/sqlite3/account_data_table.go +++ b/userapi/storage/sqlite3/account_data_table.go @@ -37,13 +37,13 @@ CREATE TABLE IF NOT EXISTS userapi_account_datas ( -- The account data content content TEXT NOT NULL, - PRIMARY KEY(localpart, room_id, type) + PRIMARY KEY(localpart, server_name, room_id, type) ); ` const insertAccountDataSQL = ` INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5) - ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4 + ON CONFLICT (localpart, server_name, room_id, type) DO UPDATE SET content = $4 ` const selectAccountDataSQL = "" + diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go index 6f855b1e8..861d794d9 100644 --- a/userapi/storage/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -34,7 +34,7 @@ const accountsSchema = ` -- Stores data about accounts. CREATE TABLE IF NOT EXISTS userapi_accounts ( -- The Matrix user ID localpart for this account - localpart TEXT NOT NULL PRIMARY KEY, + localpart TEXT NOT NULL, server_name TEXT NOT NULL, -- When this account was first created, as a unix timestamp (ms resolution). created_ts BIGINT NOT NULL, @@ -45,9 +45,10 @@ CREATE TABLE IF NOT EXISTS userapi_accounts ( -- If the account is currently active is_deactivated BOOLEAN DEFAULT 0, -- The account_type (user = 1, guest = 2, admin = 3, appservice = 4) - account_type INTEGER NOT NULL + account_type INTEGER NOT NULL, -- TODO: -- upgraded_ts, devices, any email reset stuff? + PRIMARY KEY (localpart, server_name) ); ` @@ -138,8 +139,8 @@ func (s *accountsStatements) InsertAccount( return &api.Account{ Localpart: localpart, - UserID: userutil.MakeUserID(localpart, s.serverName), - ServerName: s.serverName, + UserID: userutil.MakeUserID(localpart, serverName), + ServerName: serverName, AppServiceID: appserviceID, AccountType: accountType, }, nil @@ -185,9 +186,7 @@ func (s *accountsStatements) SelectAccountByLocalpart( acc.AppServiceID = appserviceIDPtr.String } - acc.UserID = userutil.MakeUserID(localpart, s.serverName) - acc.ServerName = s.serverName - + acc.UserID = userutil.MakeUserID(acc.Localpart, acc.ServerName) return &acc, nil } diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index 95d92f45d..c5db34bd7 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -47,7 +47,7 @@ CREATE TABLE IF NOT EXISTS userapi_devices ( ip TEXT, user_agent TEXT, - UNIQUE (localpart, device_id) + UNIQUE (localpart, server_name, device_id) ); ` @@ -186,7 +186,7 @@ func (s *devicesStatements) DeleteDevices( params[0] = localpart params[1] = serverName for i, v := range devices { - params[i+1] = v + params[i+2] = v } _, err = stmt.ExecContext(ctx, params...) return err diff --git a/userapi/storage/sqlite3/openid_table.go b/userapi/storage/sqlite3/openid_table.go index 0201c39ba..f06429741 100644 --- a/userapi/storage/sqlite3/openid_table.go +++ b/userapi/storage/sqlite3/openid_table.go @@ -62,7 +62,7 @@ func (s *openIDTokenStatements) InsertOpenIDToken( expiresAtMS int64, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) - _, err = stmt.ExecContext(ctx, token, serverName, localpart, expiresAtMS) + _, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS) return } diff --git a/userapi/storage/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go index 4f01f3813..7b151adf6 100644 --- a/userapi/storage/sqlite3/profile_table.go +++ b/userapi/storage/sqlite3/profile_table.go @@ -30,12 +30,13 @@ const profilesSchema = ` -- Stores data about accounts profiles. CREATE TABLE IF NOT EXISTS userapi_profiles ( -- The Matrix user ID localpart for this account - localpart TEXT NOT NULL PRIMARY KEY, + localpart TEXT NOT NULL, server_name TEXT NOT NULL, -- The display name for this account display_name TEXT, -- The URL of the avatar for this account - avatar_url TEXT + avatar_url TEXT, + PRIMARY KEY (localpart, server_name) ); ` @@ -135,6 +136,7 @@ func (s *profilesStatements) SetDisplayName( ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ Localpart: localpart, + ServerName: string(serverName), DisplayName: displayName, } old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName) diff --git a/userapi/storage/sqlite3/threepid_table.go b/userapi/storage/sqlite3/threepid_table.go index 2d3c7f2ee..2db7d5887 100644 --- a/userapi/storage/sqlite3/threepid_table.go +++ b/userapi/storage/sqlite3/threepid_table.go @@ -40,11 +40,11 @@ CREATE TABLE IF NOT EXISTS userapi_threepids ( PRIMARY KEY(threepid, medium) ); -CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart); +CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart, server_name); ` const selectLocalpartForThreePIDSQL = "" + - "SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2" + "SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2" const selectThreePIDsForLocalpartSQL = "" + "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2" diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 79af7b853..23aafff03 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -377,7 +377,10 @@ func Test_Profile(t *testing.T) { gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get profile by localpart") - wantProfile := &authtypes.Profile{Localpart: aliceLocalpart} + wantProfile := &authtypes.Profile{ + Localpart: aliceLocalpart, + ServerName: string(aliceDomain), + } assert.Equal(t, wantProfile, gotProfile) // set avatar & displayname