Devices table

This commit is contained in:
Neil Alexander 2022-11-07 09:30:26 +00:00
parent b01510c49a
commit d058e052fc
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
8 changed files with 155 additions and 112 deletions

View file

@ -192,25 +192,25 @@ func (s *OutputRoomEventConsumer) storeMessageStats(ctx context.Context, eventTy
func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoomID, newRoomID string, localMembers []*localMembership, roomSize int) error {
for _, membership := range localMembers {
// Copy any existing push rules from old -> new room
if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil {
if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil {
return err
}
// preserve m.direct room state
if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, roomSize); err != nil {
if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain, roomSize); err != nil {
return err
}
// copy existing m.tag entries, if any
if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil {
if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil {
return err
}
}
return nil
}
func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string) error {
pushRules, err := s.db.QueryPushRules(ctx, localpart)
func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string, serverName gomatrixserverlib.ServerName) error {
pushRules, err := s.db.QueryPushRules(ctx, localpart, serverName)
if err != nil {
return fmt.Errorf("failed to query pushrules for user: %w", err)
}
@ -229,7 +229,7 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID,
if err != nil {
return err
}
if err = s.db.SaveAccountData(ctx, localpart, "", "m.push_rules", rules); err != nil {
if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.push_rules", rules); err != nil {
return fmt.Errorf("failed to update pushrules: %w", err)
}
}
@ -237,13 +237,13 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID,
}
// updateMDirect copies the "is_direct" flag from oldRoomID to newROomID
func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, roomSize int) error {
func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName, roomSize int) error {
// this is most likely not a DM, so skip updating m.direct state
if roomSize > 2 {
return nil
}
// Get direct message state
directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, "", "m.direct")
directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, serverName, "", "m.direct")
if err != nil {
return fmt.Errorf("failed to get m.direct from database: %w", err)
}
@ -267,7 +267,7 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID,
if err != nil {
return true
}
if err = s.db.SaveAccountData(ctx, localpart, "", "m.direct", data); err != nil {
if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.direct", data); err != nil {
return true
}
}
@ -279,15 +279,15 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID,
return nil
}
func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string) error {
tag, err := s.db.GetAccountDataByType(ctx, localpart, oldRoomID, "m.tag")
func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName) error {
tag, err := s.db.GetAccountDataByType(ctx, localpart, serverName, oldRoomID, "m.tag")
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
if tag == nil {
return nil
}
return s.db.SaveAccountData(ctx, localpart, newRoomID, "m.tag", tag)
return s.db.SaveAccountData(ctx, localpart, serverName, newRoomID, "m.tag", tag)
}
func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error {

View file

@ -230,7 +230,7 @@ func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.Pe
return err
}
if req.LogoutDevices {
if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, ""); err != nil {
if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, req.ServerName, ""); err != nil {
return err
}
}
@ -243,7 +243,9 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
if serverName == "" {
serverName = a.Config.Matrix.ServerName
}
_ = serverName
if !a.Config.Matrix.IsLocalServerName(serverName) {
return fmt.Errorf("server name %s is not local", serverName)
}
util.GetLogger(ctx).WithFields(logrus.Fields{
"localpart": req.Localpart,
"device_id": req.DeviceID,
@ -274,12 +276,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 {
var devices []api.Device
devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
devices, err = a.DB.RemoveAllDevices(ctx, local, domain, req.ExceptDeviceID)
for _, d := range devices {
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
}
} else {
err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs)
err = a.DB.RemoveDevices(ctx, local, domain, req.DeviceIDs)
}
if err != nil {
return err
@ -333,23 +335,26 @@ func (a *UserInternalAPI) PerformLastSeenUpdate(
req *api.PerformLastSeenUpdateRequest,
res *api.PerformLastSeenUpdateResponse,
) error {
localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID)
localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil {
if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("server name %s is not local", domain)
}
if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, domain, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil {
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
}
return nil
}
func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
localpart, domain, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
return err
}
dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID)
dev, err := a.DB.GetDeviceByID(ctx, localpart, domain, req.DeviceID)
if err == sql.ErrNoRows {
res.DeviceExists = false
return nil
@ -357,6 +362,9 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
util.GetLogger(ctx).WithError(err).Error("deviceDB.GetDeviceByID failed")
return err
}
if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("server name %s is not local", domain)
}
res.DeviceExists = true
if dev.UserID != req.RequestingUserID {
@ -364,7 +372,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
return nil
}
err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
err = a.DB.UpdateDevice(ctx, localpart, domain, req.DeviceID, req.DisplayName)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
return err
@ -455,7 +463,7 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("cannot query devices of remote users (server name %s)", domain)
}
devs, err := a.DB.GetDevicesByLocalpart(ctx, local)
devs, err := a.DB.GetDevicesByLocalpart(ctx, local, domain)
if err != nil {
return err
}

View file

@ -61,8 +61,8 @@ type AccountData interface {
type Device interface {
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
GetDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Device, error)
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
@ -71,11 +71,11 @@ type Device interface {
// If no device ID is given one is generated.
// Returns the device on success.
CreateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
UpdateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error
UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error
RemoveDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
RemoveAllDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) (devices []api.Device, err error)
}
type KeyBackup interface {

View file

@ -75,7 +75,7 @@ const insertDeviceSQL = "" +
" RETURNING session_id"
const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
"SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" +
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
@ -90,16 +90,16 @@ const deleteDeviceSQL = "" +
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3"
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3"
const deleteDevicesSQL = "" +
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id = ANY($2)"
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = ANY($3)"
const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
"SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
const updateDeviceLastSeen = "" +
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
type devicesStatements struct {
insertDeviceStmt *sql.Stmt
@ -184,7 +184,9 @@ func (s *devicesStatements) DeleteDevice(
// deleteDevices removes a single or multiple devices by ids and user localpart.
// Returns an error if the execution failed.
func (s *devicesStatements) DeleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
devices []string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt)
_, err := stmt.ExecContext(ctx, localpart, pq.Array(devices))
@ -194,18 +196,22 @@ func (s *devicesStatements) DeleteDevices(
// deleteDevicesByLocalpart removes all devices for the
// given user localpart.
func (s *devicesStatements) DeleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
exceptDeviceID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
_, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID)
return err
}
func (s *devicesStatements) UpdateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
deviceID string, displayName *string,
) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
_, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID)
return err
}
@ -214,10 +220,11 @@ func (s *devicesStatements) SelectDeviceByToken(
) (*api.Device, error) {
var dev api.Device
var localpart string
var serverName gomatrixserverlib.ServerName
stmt := s.selectDeviceByTokenStmt
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName)
if err == nil {
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.UserID = userutil.MakeUserID(localpart, serverName)
dev.AccessToken = accessToken
}
return &dev, err
@ -226,16 +233,18 @@ func (s *devicesStatements) SelectDeviceByToken(
// selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID
func (s *devicesStatements) SelectDeviceByID(
ctx context.Context, localpart, deviceID string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
deviceID string,
) (*api.Device, error) {
var dev api.Device
var displayName, ip sql.NullString
var lastseenTS sql.NullInt64
stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip)
err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip)
if err == nil {
dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.UserID = userutil.MakeUserID(localpart, serverName)
if displayName.Valid {
dev.DisplayName = displayName.String
}
@ -258,10 +267,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
var devices []api.Device
var dev api.Device
var localpart string
var serverName gomatrixserverlib.ServerName
var lastseents sql.NullInt64
var displayName sql.NullString
for rows.Next() {
if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil {
return nil, err
}
if displayName.Valid {
@ -270,17 +280,19 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
if lastseents.Valid {
dev.LastSeenTS = lastseents.Int64
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.UserID = userutil.MakeUserID(localpart, serverName)
devices = append(devices, dev)
}
return devices, rows.Err()
}
func (s *devicesStatements) SelectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
exceptDeviceID string,
) ([]api.Device, error) {
devices := []api.Device{}
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID)
if err != nil {
return devices, err
@ -311,16 +323,16 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
dev.UserAgent = useragent.String
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.UserID = userutil.MakeUserID(localpart, serverName)
devices = append(devices, dev)
}
return devices, rows.Err()
}
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error {
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
lastSeenTs := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID)
return err
}

View file

@ -546,16 +546,19 @@ func (d *Database) GetDeviceByAccessToken(
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
deviceID string,
) (*api.Device, error) {
return d.Devices.SelectDeviceByID(ctx, localpart, deviceID)
return d.Devices.SelectDeviceByID(ctx, localpart, serverName, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) ([]api.Device, error) {
return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, "")
return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, serverName, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@ -576,7 +579,7 @@ func (d *Database) CreateDevice(
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart); err != nil {
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil {
return err
}
@ -621,10 +624,12 @@ func generateDeviceID() (string, error) {
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
deviceID string, displayName *string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Devices.UpdateDeviceName(ctx, txn, localpart, deviceID, displayName)
return d.Devices.UpdateDeviceName(ctx, txn, localpart, serverName, deviceID, displayName)
})
}
@ -633,10 +638,12 @@ func (d *Database) UpdateDevice(
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
devices []string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.Devices.DeleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
if err := d.Devices.DeleteDevices(ctx, txn, localpart, serverName, devices); err != sql.ErrNoRows {
return err
}
return nil
@ -647,14 +654,16 @@ func (d *Database) RemoveDevices(
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart, exceptDeviceID string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
exceptDeviceID string,
) (devices []api.Device, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID)
if err != nil {
return err
}
if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil
@ -663,9 +672,9 @@ func (d *Database) RemoveAllDevices(
}
// UpdateDeviceLastSeen updates a last seen timestamp and the ip address.
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error {
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr, userAgent)
return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, serverName, deviceID, ipAddr, userAgent)
})
}

View file

@ -59,31 +59,31 @@ const selectDevicesCountSQL = "" +
"SELECT COUNT(access_token) FROM userapi_devices"
const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
"SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" +
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3"
const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
const updateDeviceNameSQL = "" +
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4"
const deleteDeviceSQL = "" +
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3"
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3"
const deleteDevicesSQL = "" +
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id IN ($2)"
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id IN ($3)"
const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
"SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
const updateDeviceLastSeen = "" +
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
type devicesStatements struct {
db *sql.DB
@ -153,7 +153,7 @@ func (s *devicesStatements) InsertDevice(
}
return &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, s.serverName),
UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken,
SessionID: sessionID,
LastSeenTS: createdTimeMS,
@ -163,24 +163,28 @@ func (s *devicesStatements) InsertDevice(
}
func (s *devicesStatements) DeleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string,
ctx context.Context, txn *sql.Tx, id string,
localpart string, serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
_, err := stmt.ExecContext(ctx, id, localpart)
_, err := stmt.ExecContext(ctx, id, localpart, serverName)
return err
}
func (s *devicesStatements) DeleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
devices []string,
) error {
orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
orig := strings.Replace(deleteDevicesSQL, "($3)", sqlutil.QueryVariadicOffset(len(devices), 2), 1)
prep, err := s.db.Prepare(orig)
if err != nil {
return err
}
stmt := sqlutil.TxStmt(txn, prep)
params := make([]interface{}, len(devices)+1)
params := make([]interface{}, len(devices)+2)
params[0] = localpart
params[1] = serverName
for i, v := range devices {
params[i+1] = v
}
@ -189,18 +193,22 @@ func (s *devicesStatements) DeleteDevices(
}
func (s *devicesStatements) DeleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
exceptDeviceID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
_, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID)
return err
}
func (s *devicesStatements) UpdateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
deviceID string, displayName *string,
) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
_, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID)
return err
}
@ -209,10 +217,11 @@ func (s *devicesStatements) SelectDeviceByToken(
) (*api.Device, error) {
var dev api.Device
var localpart string
var serverName gomatrixserverlib.ServerName
stmt := s.selectDeviceByTokenStmt
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName)
if err == nil {
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.UserID = userutil.MakeUserID(localpart, serverName)
dev.AccessToken = accessToken
}
return &dev, err
@ -221,16 +230,18 @@ func (s *devicesStatements) SelectDeviceByToken(
// selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID
func (s *devicesStatements) SelectDeviceByID(
ctx context.Context, localpart, deviceID string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
deviceID string,
) (*api.Device, error) {
var dev api.Device
var displayName, ip sql.NullString
stmt := s.selectDeviceByIDStmt
var lastseenTS sql.NullInt64
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip)
err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip)
if err == nil {
dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.UserID = userutil.MakeUserID(localpart, serverName)
if displayName.Valid {
dev.DisplayName = displayName.String
}
@ -245,10 +256,12 @@ func (s *devicesStatements) SelectDeviceByID(
}
func (s *devicesStatements) SelectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
exceptDeviceID string,
) ([]api.Device, error) {
devices := []api.Device{}
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID)
if err != nil {
return devices, err
@ -278,7 +291,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
dev.UserAgent = useragent.String
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.UserID = userutil.MakeUserID(localpart, serverName)
devices = append(devices, dev)
}
@ -300,10 +313,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
var devices []api.Device
var dev api.Device
var localpart string
var serverName gomatrixserverlib.ServerName
var displayName sql.NullString
var lastseents sql.NullInt64
for rows.Next() {
if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil {
return nil, err
}
if displayName.Valid {
@ -312,15 +326,15 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
if lastseents.Valid {
dev.LastSeenTS = lastseents.Int64
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.UserID = userutil.MakeUserID(localpart, serverName)
devices = append(devices, dev)
}
return devices, rows.Err()
}
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error {
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
lastSeenTs := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID)
return err
}

View file

@ -164,7 +164,7 @@ func Test_Devices(t *testing.T) {
deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "")
assert.NoError(t, err, "unable to create deviceWithoutID")
gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID)
gotDevice, err := db.GetDeviceByID(ctx, localpart, domain, deviceID)
assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields
@ -176,12 +176,12 @@ func Test_Devices(t *testing.T) {
accessToken = util.RandomString(16)
deviceWithoutID, err := db.CreateDevice(ctx, localpart, domain, nil, accessToken, nil, "", "")
assert.NoError(t, err, "unable to create deviceWithoutID")
gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID)
gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, domain, deviceWithoutID.ID)
assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields
// Get devices
devices, err := db.GetDevicesByLocalpart(ctx, localpart)
devices, err := db.GetDevicesByLocalpart(ctx, localpart, domain)
assert.NoError(t, err, "unable to get devices by localpart")
assert.Equal(t, 2, len(devices))
deviceIDs := make([]string, 0, len(devices))
@ -195,15 +195,15 @@ func Test_Devices(t *testing.T) {
// Update device
newName := "new display name"
err = db.UpdateDevice(ctx, localpart, deviceWithID.ID, &newName)
err = db.UpdateDevice(ctx, localpart, domain, deviceWithID.ID, &newName)
assert.NoError(t, err, "unable to update device displayname")
updatedAfterTimestamp := time.Now().Unix()
err = db.UpdateDeviceLastSeen(ctx, localpart, deviceWithID.ID, "127.0.0.1", "Element Web")
err = db.UpdateDeviceLastSeen(ctx, localpart, domain, deviceWithID.ID, "127.0.0.1", "Element Web")
assert.NoError(t, err, "unable to update device last seen")
deviceWithID.DisplayName = newName
deviceWithID.LastSeenIP = "127.0.0.1"
gotDevice, err = db.GetDeviceByID(ctx, localpart, deviceWithID.ID)
gotDevice, err = db.GetDeviceByID(ctx, localpart, domain, deviceWithID.ID)
assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, 2, len(devices))
assert.Equal(t, deviceWithID.DisplayName, gotDevice.DisplayName)
@ -216,17 +216,17 @@ func Test_Devices(t *testing.T) {
_, err = db.CreateDevice(ctx, localpart, domain, &newDeviceID, accessToken, nil, "", "")
assert.NoError(t, err, "unable to create new device")
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain)
assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, 3, len(devices))
err = db.RemoveDevices(ctx, localpart, deviceIDs)
err = db.RemoveDevices(ctx, localpart, domain, deviceIDs)
assert.NoError(t, err, "unable to remove devices")
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain)
assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, 1, len(devices))
deleted, err := db.RemoveAllDevices(ctx, localpart, "")
deleted, err := db.RemoveAllDevices(ctx, localpart, domain, "")
assert.NoError(t, err, "unable to remove all devices")
assert.Equal(t, 1, len(deleted))
assert.Equal(t, newDeviceID, deleted[0].ID)

View file

@ -44,15 +44,15 @@ type AccountsTable interface {
type DevicesTable interface {
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error
UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string) error
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName) error
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) error
UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error
SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error)
SelectDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) ([]api.Device, error)
SelectDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error)
SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) ([]api.Device, error)
SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error
UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error
}
type KeyBackupTable interface {