From d058e052fc5382f8b6dea5ac794a1bffbea2f32f Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 7 Nov 2022 09:30:26 +0000 Subject: [PATCH] Devices table --- userapi/consumers/roomserver.go | 24 +++---- userapi/internal/api.go | 28 ++++++--- userapi/storage/interface.go | 12 ++-- userapi/storage/postgres/devices_table.go | 56 ++++++++++------- userapi/storage/shared/storage.go | 37 ++++++----- userapi/storage/sqlite3/devices_table.go | 76 ++++++++++++++--------- userapi/storage/storage_test.go | 20 +++--- userapi/storage/tables/interface.go | 14 ++--- 8 files changed, 155 insertions(+), 112 deletions(-) diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 2a0d58b74..c87f1c574 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -192,25 +192,25 @@ func (s *OutputRoomEventConsumer) storeMessageStats(ctx context.Context, eventTy func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoomID, newRoomID string, localMembers []*localMembership, roomSize int) error { for _, membership := range localMembers { // Copy any existing push rules from old -> new room - if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil { + if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil { return err } // preserve m.direct room state - if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, roomSize); err != nil { + if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain, roomSize); err != nil { return err } // copy existing m.tag entries, if any - if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil { + if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil { return err } } return nil } -func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string) error { - pushRules, err := s.db.QueryPushRules(ctx, localpart) +func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string, serverName gomatrixserverlib.ServerName) error { + pushRules, err := s.db.QueryPushRules(ctx, localpart, serverName) if err != nil { return fmt.Errorf("failed to query pushrules for user: %w", err) } @@ -229,7 +229,7 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, if err != nil { return err } - if err = s.db.SaveAccountData(ctx, localpart, "", "m.push_rules", rules); err != nil { + if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.push_rules", rules); err != nil { return fmt.Errorf("failed to update pushrules: %w", err) } } @@ -237,13 +237,13 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, } // updateMDirect copies the "is_direct" flag from oldRoomID to newROomID -func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, roomSize int) error { +func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName, roomSize int) error { // this is most likely not a DM, so skip updating m.direct state if roomSize > 2 { return nil } // Get direct message state - directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, "", "m.direct") + directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, serverName, "", "m.direct") if err != nil { return fmt.Errorf("failed to get m.direct from database: %w", err) } @@ -267,7 +267,7 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, if err != nil { return true } - if err = s.db.SaveAccountData(ctx, localpart, "", "m.direct", data); err != nil { + if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.direct", data); err != nil { return true } } @@ -279,15 +279,15 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, return nil } -func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string) error { - tag, err := s.db.GetAccountDataByType(ctx, localpart, oldRoomID, "m.tag") +func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName) error { + tag, err := s.db.GetAccountDataByType(ctx, localpart, serverName, oldRoomID, "m.tag") if err != nil && !errors.Is(err, sql.ErrNoRows) { return err } if tag == nil { return nil } - return s.db.SaveAccountData(ctx, localpart, newRoomID, "m.tag", tag) + return s.db.SaveAccountData(ctx, localpart, serverName, newRoomID, "m.tag", tag) } func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error { diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 300067243..5021871d9 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -230,7 +230,7 @@ func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.Pe return err } if req.LogoutDevices { - if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, ""); err != nil { + if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, req.ServerName, ""); err != nil { return err } } @@ -243,7 +243,9 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe if serverName == "" { serverName = a.Config.Matrix.ServerName } - _ = serverName + if !a.Config.Matrix.IsLocalServerName(serverName) { + return fmt.Errorf("server name %s is not local", serverName) + } util.GetLogger(ctx).WithFields(logrus.Fields{ "localpart": req.Localpart, "device_id": req.DeviceID, @@ -274,12 +276,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe deletedDeviceIDs := req.DeviceIDs if len(req.DeviceIDs) == 0 { var devices []api.Device - devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID) + devices, err = a.DB.RemoveAllDevices(ctx, local, domain, req.ExceptDeviceID) for _, d := range devices { deletedDeviceIDs = append(deletedDeviceIDs, d.ID) } } else { - err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs) + err = a.DB.RemoveDevices(ctx, local, domain, req.DeviceIDs) } if err != nil { return err @@ -333,23 +335,26 @@ func (a *UserInternalAPI) PerformLastSeenUpdate( req *api.PerformLastSeenUpdateRequest, res *api.PerformLastSeenUpdateResponse, ) error { - localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } - if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil { + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("server name %s is not local", domain) + } + if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, domain, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil { return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err) } return nil } func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { - localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', req.RequestingUserID) if err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") return err } - dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID) + dev, err := a.DB.GetDeviceByID(ctx, localpart, domain, req.DeviceID) if err == sql.ErrNoRows { res.DeviceExists = false return nil @@ -357,6 +362,9 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf util.GetLogger(ctx).WithError(err).Error("deviceDB.GetDeviceByID failed") return err } + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("server name %s is not local", domain) + } res.DeviceExists = true if dev.UserID != req.RequestingUserID { @@ -364,7 +372,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf return nil } - err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName) + err = a.DB.UpdateDevice(ctx, localpart, domain, req.DeviceID, req.DisplayName) if err != nil { util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed") return err @@ -455,7 +463,7 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice if !a.Config.Matrix.IsLocalServerName(domain) { return fmt.Errorf("cannot query devices of remote users (server name %s)", domain) } - devs, err := a.DB.GetDevicesByLocalpart(ctx, local) + devs, err := a.DB.GetDevicesByLocalpart(ctx, local, domain) if err != nil { return err } diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index d68a8b57d..4a0f66f85 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -61,8 +61,8 @@ type AccountData interface { type Device interface { GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) - GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) - GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error) + GetDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error) + GetDevicesByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Device, error) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) // CreateDevice makes a new device associated with the given user ID localpart. // If there is already a device with the same device ID for this user, that access token will be revoked @@ -71,11 +71,11 @@ type Device interface { // If no device ID is given one is generated. // Returns the device on success. CreateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) - UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error - UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error - RemoveDevices(ctx context.Context, localpart string, devices []string) error + UpdateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error + UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error + RemoveDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. - RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) + RemoveAllDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) (devices []api.Device, err error) } type KeyBackup interface { diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index 2b90a69b2..db7125b22 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -75,7 +75,7 @@ const insertDeviceSQL = "" + " RETURNING session_id" const selectDeviceByTokenSQL = "" + - "SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1" + "SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1" const selectDeviceByIDSQL = "" + "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2" @@ -90,16 +90,16 @@ const deleteDeviceSQL = "" + "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3" const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3" const deleteDevicesSQL = "" + - "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id = ANY($2)" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = ANY($3)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" + "SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + - "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" + "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6" type devicesStatements struct { insertDeviceStmt *sql.Stmt @@ -184,7 +184,9 @@ func (s *devicesStatements) DeleteDevice( // deleteDevices removes a single or multiple devices by ids and user localpart. // Returns an error if the execution failed. func (s *devicesStatements) DeleteDevices( - ctx context.Context, txn *sql.Tx, localpart string, devices []string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + devices []string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt) _, err := stmt.ExecContext(ctx, localpart, pq.Array(devices)) @@ -194,18 +196,22 @@ func (s *devicesStatements) DeleteDevices( // deleteDevicesByLocalpart removes all devices for the // given user localpart. func (s *devicesStatements) DeleteDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) - _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID) + _, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID) return err } func (s *devicesStatements) UpdateDeviceName( - ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, displayName *string, ) error { stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) - _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) + _, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID) return err } @@ -214,10 +220,11 @@ func (s *devicesStatements) SelectDeviceByToken( ) (*api.Device, error) { var dev api.Device var localpart string + var serverName gomatrixserverlib.ServerName stmt := s.selectDeviceByTokenStmt - err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart) + err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName) if err == nil { - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) dev.AccessToken = accessToken } return &dev, err @@ -226,16 +233,18 @@ func (s *devicesStatements) SelectDeviceByToken( // selectDeviceByID retrieves a device from the database with the given user // localpart and deviceID func (s *devicesStatements) SelectDeviceByID( - ctx context.Context, localpart, deviceID string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, ) (*api.Device, error) { var dev api.Device var displayName, ip sql.NullString var lastseenTS sql.NullInt64 stmt := s.selectDeviceByIDStmt - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip) + err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip) if err == nil { dev.ID = deviceID - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) if displayName.Valid { dev.DisplayName = displayName.String } @@ -258,10 +267,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s var devices []api.Device var dev api.Device var localpart string + var serverName gomatrixserverlib.ServerName var lastseents sql.NullInt64 var displayName sql.NullString for rows.Next() { - if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil { + if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil { return nil, err } if displayName.Valid { @@ -270,17 +280,19 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s if lastseents.Valid { dev.LastSeenTS = lastseents.Int64 } - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) devices = append(devices, dev) } return devices, rows.Err() } func (s *devicesStatements) SelectDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} - rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID) + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID) if err != nil { return devices, err @@ -311,16 +323,16 @@ func (s *devicesStatements) SelectDevicesByLocalpart( dev.UserAgent = useragent.String } - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) devices = append(devices, dev) } return devices, rows.Err() } -func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) - _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID) + _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID) return err } diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 88e821532..7f4ec8a1a 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -546,16 +546,19 @@ func (d *Database) GetDeviceByAccessToken( // GetDeviceByID returns the device matching the given ID. // Returns sql.ErrNoRows if no matching device was found. func (d *Database) GetDeviceByID( - ctx context.Context, localpart, deviceID string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, ) (*api.Device, error) { - return d.Devices.SelectDeviceByID(ctx, localpart, deviceID) + return d.Devices.SelectDeviceByID(ctx, localpart, serverName, deviceID) } // GetDevicesByLocalpart returns the devices matching the given localpart. func (d *Database) GetDevicesByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) ([]api.Device, error) { - return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, "") + return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, serverName, "") } func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { @@ -576,7 +579,7 @@ func (d *Database) CreateDevice( returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var err error // Revoke existing tokens for this device - if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart); err != nil { + if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil { return err } @@ -621,10 +624,12 @@ func generateDeviceID() (string, error) { // UpdateDevice updates the given device with the display name. // Returns SQL error if there are problems and nil on success. func (d *Database) UpdateDevice( - ctx context.Context, localpart, deviceID string, displayName *string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, displayName *string, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Devices.UpdateDeviceName(ctx, txn, localpart, deviceID, displayName) + return d.Devices.UpdateDeviceName(ctx, txn, localpart, serverName, deviceID, displayName) }) } @@ -633,10 +638,12 @@ func (d *Database) UpdateDevice( // If the devices don't exist, it will not return an error // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveDevices( - ctx context.Context, localpart string, devices []string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + devices []string, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if err := d.Devices.DeleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { + if err := d.Devices.DeleteDevices(ctx, txn, localpart, serverName, devices); err != sql.ErrNoRows { return err } return nil @@ -647,14 +654,16 @@ func (d *Database) RemoveDevices( // database matching the given user ID localpart. // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveAllDevices( - ctx context.Context, localpart, exceptDeviceID string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + exceptDeviceID string, ) (devices []api.Device, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) + devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID) if err != nil { return err } - if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { + if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID); err != sql.ErrNoRows { return err } return nil @@ -663,9 +672,9 @@ func (d *Database) RemoveAllDevices( } // UpdateDeviceLastSeen updates a last seen timestamp and the ip address. -func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error { +func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr, userAgent) + return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, serverName, deviceID, ipAddr, userAgent) }) } diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index 832abf36d..95d92f45d 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -59,31 +59,31 @@ const selectDevicesCountSQL = "" + "SELECT COUNT(access_token) FROM userapi_devices" const selectDeviceByTokenSQL = "" + - "SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1" + "SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1" const selectDeviceByIDSQL = "" + - "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2" + "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + - "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" + "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4" const deleteDeviceSQL = "" + - "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2" + "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3" const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3" const deleteDevicesSQL = "" + - "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id IN ($2)" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id IN ($3)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" + "SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + - "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" + "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6" type devicesStatements struct { db *sql.DB @@ -153,7 +153,7 @@ func (s *devicesStatements) InsertDevice( } return &api.Device{ ID: id, - UserID: userutil.MakeUserID(localpart, s.serverName), + UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, SessionID: sessionID, LastSeenTS: createdTimeMS, @@ -163,24 +163,28 @@ func (s *devicesStatements) InsertDevice( } func (s *devicesStatements) DeleteDevice( - ctx context.Context, txn *sql.Tx, id, localpart string, + ctx context.Context, txn *sql.Tx, id string, + localpart string, serverName gomatrixserverlib.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) - _, err := stmt.ExecContext(ctx, id, localpart) + _, err := stmt.ExecContext(ctx, id, localpart, serverName) return err } func (s *devicesStatements) DeleteDevices( - ctx context.Context, txn *sql.Tx, localpart string, devices []string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + devices []string, ) error { - orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1) + orig := strings.Replace(deleteDevicesSQL, "($3)", sqlutil.QueryVariadicOffset(len(devices), 2), 1) prep, err := s.db.Prepare(orig) if err != nil { return err } stmt := sqlutil.TxStmt(txn, prep) - params := make([]interface{}, len(devices)+1) + params := make([]interface{}, len(devices)+2) params[0] = localpart + params[1] = serverName for i, v := range devices { params[i+1] = v } @@ -189,18 +193,22 @@ func (s *devicesStatements) DeleteDevices( } func (s *devicesStatements) DeleteDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) - _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID) + _, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID) return err } func (s *devicesStatements) UpdateDeviceName( - ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, displayName *string, ) error { stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) - _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) + _, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID) return err } @@ -209,10 +217,11 @@ func (s *devicesStatements) SelectDeviceByToken( ) (*api.Device, error) { var dev api.Device var localpart string + var serverName gomatrixserverlib.ServerName stmt := s.selectDeviceByTokenStmt - err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart) + err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName) if err == nil { - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) dev.AccessToken = accessToken } return &dev, err @@ -221,16 +230,18 @@ func (s *devicesStatements) SelectDeviceByToken( // selectDeviceByID retrieves a device from the database with the given user // localpart and deviceID func (s *devicesStatements) SelectDeviceByID( - ctx context.Context, localpart, deviceID string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, ) (*api.Device, error) { var dev api.Device var displayName, ip sql.NullString stmt := s.selectDeviceByIDStmt var lastseenTS sql.NullInt64 - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip) + err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip) if err == nil { dev.ID = deviceID - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) if displayName.Valid { dev.DisplayName = displayName.String } @@ -245,10 +256,12 @@ func (s *devicesStatements) SelectDeviceByID( } func (s *devicesStatements) SelectDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} - rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID) + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID) if err != nil { return devices, err @@ -278,7 +291,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart( dev.UserAgent = useragent.String } - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) devices = append(devices, dev) } @@ -300,10 +313,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s var devices []api.Device var dev api.Device var localpart string + var serverName gomatrixserverlib.ServerName var displayName sql.NullString var lastseents sql.NullInt64 for rows.Next() { - if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil { + if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil { return nil, err } if displayName.Valid { @@ -312,15 +326,15 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s if lastseents.Valid { dev.LastSeenTS = lastseents.Int64 } - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) devices = append(devices, dev) } return devices, rows.Err() } -func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) - _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID) + _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID) return err } diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 22798f029..c7bf7dc7f 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -164,7 +164,7 @@ func Test_Devices(t *testing.T) { deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "") assert.NoError(t, err, "unable to create deviceWithoutID") - gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID) + gotDevice, err := db.GetDeviceByID(ctx, localpart, domain, deviceID) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields @@ -176,12 +176,12 @@ func Test_Devices(t *testing.T) { accessToken = util.RandomString(16) deviceWithoutID, err := db.CreateDevice(ctx, localpart, domain, nil, accessToken, nil, "", "") assert.NoError(t, err, "unable to create deviceWithoutID") - gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID) + gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, domain, deviceWithoutID.ID) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields // Get devices - devices, err := db.GetDevicesByLocalpart(ctx, localpart) + devices, err := db.GetDevicesByLocalpart(ctx, localpart, domain) assert.NoError(t, err, "unable to get devices by localpart") assert.Equal(t, 2, len(devices)) deviceIDs := make([]string, 0, len(devices)) @@ -195,15 +195,15 @@ func Test_Devices(t *testing.T) { // Update device newName := "new display name" - err = db.UpdateDevice(ctx, localpart, deviceWithID.ID, &newName) + err = db.UpdateDevice(ctx, localpart, domain, deviceWithID.ID, &newName) assert.NoError(t, err, "unable to update device displayname") updatedAfterTimestamp := time.Now().Unix() - err = db.UpdateDeviceLastSeen(ctx, localpart, deviceWithID.ID, "127.0.0.1", "Element Web") + err = db.UpdateDeviceLastSeen(ctx, localpart, domain, deviceWithID.ID, "127.0.0.1", "Element Web") assert.NoError(t, err, "unable to update device last seen") deviceWithID.DisplayName = newName deviceWithID.LastSeenIP = "127.0.0.1" - gotDevice, err = db.GetDeviceByID(ctx, localpart, deviceWithID.ID) + gotDevice, err = db.GetDeviceByID(ctx, localpart, domain, deviceWithID.ID) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, 2, len(devices)) assert.Equal(t, deviceWithID.DisplayName, gotDevice.DisplayName) @@ -216,17 +216,17 @@ func Test_Devices(t *testing.T) { _, err = db.CreateDevice(ctx, localpart, domain, &newDeviceID, accessToken, nil, "", "") assert.NoError(t, err, "unable to create new device") - devices, err = db.GetDevicesByLocalpart(ctx, localpart) + devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, 3, len(devices)) - err = db.RemoveDevices(ctx, localpart, deviceIDs) + err = db.RemoveDevices(ctx, localpart, domain, deviceIDs) assert.NoError(t, err, "unable to remove devices") - devices, err = db.GetDevicesByLocalpart(ctx, localpart) + devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, 1, len(devices)) - deleted, err := db.RemoveAllDevices(ctx, localpart, "") + deleted, err := db.RemoveAllDevices(ctx, localpart, domain, "") assert.NoError(t, err, "unable to remove all devices") assert.Equal(t, 1, len(deleted)) assert.Equal(t, newDeviceID, deleted[0].ID) diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 1747e9256..e5ee0daaf 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -44,15 +44,15 @@ type AccountsTable interface { type DevicesTable interface { InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error) - DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error - DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error - DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error - UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string) error + DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName) error + DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error + DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) error + UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error) - SelectDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) - SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) ([]api.Device, error) + SelectDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error) + SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) ([]api.Device, error) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) - UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error + UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error } type KeyBackupTable interface {