Devices table
This commit is contained in:
parent
b01510c49a
commit
d058e052fc
|
@ -192,25 +192,25 @@ func (s *OutputRoomEventConsumer) storeMessageStats(ctx context.Context, eventTy
|
||||||
func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoomID, newRoomID string, localMembers []*localMembership, roomSize int) error {
|
func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoomID, newRoomID string, localMembers []*localMembership, roomSize int) error {
|
||||||
for _, membership := range localMembers {
|
for _, membership := range localMembers {
|
||||||
// Copy any existing push rules from old -> new room
|
// Copy any existing push rules from old -> new room
|
||||||
if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil {
|
if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// preserve m.direct room state
|
// preserve m.direct room state
|
||||||
if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, roomSize); err != nil {
|
if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain, roomSize); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// copy existing m.tag entries, if any
|
// copy existing m.tag entries, if any
|
||||||
if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil {
|
if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string) error {
|
func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string, serverName gomatrixserverlib.ServerName) error {
|
||||||
pushRules, err := s.db.QueryPushRules(ctx, localpart)
|
pushRules, err := s.db.QueryPushRules(ctx, localpart, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to query pushrules for user: %w", err)
|
return fmt.Errorf("failed to query pushrules for user: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -229,7 +229,7 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err = s.db.SaveAccountData(ctx, localpart, "", "m.push_rules", rules); err != nil {
|
if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.push_rules", rules); err != nil {
|
||||||
return fmt.Errorf("failed to update pushrules: %w", err)
|
return fmt.Errorf("failed to update pushrules: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -237,13 +237,13 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID,
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateMDirect copies the "is_direct" flag from oldRoomID to newROomID
|
// updateMDirect copies the "is_direct" flag from oldRoomID to newROomID
|
||||||
func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, roomSize int) error {
|
func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName, roomSize int) error {
|
||||||
// this is most likely not a DM, so skip updating m.direct state
|
// this is most likely not a DM, so skip updating m.direct state
|
||||||
if roomSize > 2 {
|
if roomSize > 2 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// Get direct message state
|
// Get direct message state
|
||||||
directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, "", "m.direct")
|
directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, serverName, "", "m.direct")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get m.direct from database: %w", err)
|
return fmt.Errorf("failed to get m.direct from database: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -267,7 +267,7 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if err = s.db.SaveAccountData(ctx, localpart, "", "m.direct", data); err != nil {
|
if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.direct", data); err != nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -279,15 +279,15 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID,
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string) error {
|
func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName) error {
|
||||||
tag, err := s.db.GetAccountDataByType(ctx, localpart, oldRoomID, "m.tag")
|
tag, err := s.db.GetAccountDataByType(ctx, localpart, serverName, oldRoomID, "m.tag")
|
||||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if tag == nil {
|
if tag == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return s.db.SaveAccountData(ctx, localpart, newRoomID, "m.tag", tag)
|
return s.db.SaveAccountData(ctx, localpart, serverName, newRoomID, "m.tag", tag)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error {
|
func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error {
|
||||||
|
|
|
@ -230,7 +230,7 @@ func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.Pe
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if req.LogoutDevices {
|
if req.LogoutDevices {
|
||||||
if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, ""); err != nil {
|
if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, req.ServerName, ""); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -243,7 +243,9 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
|
||||||
if serverName == "" {
|
if serverName == "" {
|
||||||
serverName = a.Config.Matrix.ServerName
|
serverName = a.Config.Matrix.ServerName
|
||||||
}
|
}
|
||||||
_ = serverName
|
if !a.Config.Matrix.IsLocalServerName(serverName) {
|
||||||
|
return fmt.Errorf("server name %s is not local", serverName)
|
||||||
|
}
|
||||||
util.GetLogger(ctx).WithFields(logrus.Fields{
|
util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||||
"localpart": req.Localpart,
|
"localpart": req.Localpart,
|
||||||
"device_id": req.DeviceID,
|
"device_id": req.DeviceID,
|
||||||
|
@ -274,12 +276,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
|
||||||
deletedDeviceIDs := req.DeviceIDs
|
deletedDeviceIDs := req.DeviceIDs
|
||||||
if len(req.DeviceIDs) == 0 {
|
if len(req.DeviceIDs) == 0 {
|
||||||
var devices []api.Device
|
var devices []api.Device
|
||||||
devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
|
devices, err = a.DB.RemoveAllDevices(ctx, local, domain, req.ExceptDeviceID)
|
||||||
for _, d := range devices {
|
for _, d := range devices {
|
||||||
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
|
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs)
|
err = a.DB.RemoveDevices(ctx, local, domain, req.DeviceIDs)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -333,23 +335,26 @@ func (a *UserInternalAPI) PerformLastSeenUpdate(
|
||||||
req *api.PerformLastSeenUpdateRequest,
|
req *api.PerformLastSeenUpdateRequest,
|
||||||
res *api.PerformLastSeenUpdateResponse,
|
res *api.PerformLastSeenUpdateResponse,
|
||||||
) error {
|
) error {
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID)
|
localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
|
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
|
||||||
}
|
}
|
||||||
if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil {
|
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||||
|
return fmt.Errorf("server name %s is not local", domain)
|
||||||
|
}
|
||||||
|
if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, domain, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil {
|
||||||
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
|
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
|
func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
|
localpart, domain, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID)
|
dev, err := a.DB.GetDeviceByID(ctx, localpart, domain, req.DeviceID)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
res.DeviceExists = false
|
res.DeviceExists = false
|
||||||
return nil
|
return nil
|
||||||
|
@ -357,6 +362,9 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
|
||||||
util.GetLogger(ctx).WithError(err).Error("deviceDB.GetDeviceByID failed")
|
util.GetLogger(ctx).WithError(err).Error("deviceDB.GetDeviceByID failed")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||||
|
return fmt.Errorf("server name %s is not local", domain)
|
||||||
|
}
|
||||||
res.DeviceExists = true
|
res.DeviceExists = true
|
||||||
|
|
||||||
if dev.UserID != req.RequestingUserID {
|
if dev.UserID != req.RequestingUserID {
|
||||||
|
@ -364,7 +372,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
|
err = a.DB.UpdateDevice(ctx, localpart, domain, req.DeviceID, req.DisplayName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
|
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
|
||||||
return err
|
return err
|
||||||
|
@ -455,7 +463,7 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
|
||||||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||||
return fmt.Errorf("cannot query devices of remote users (server name %s)", domain)
|
return fmt.Errorf("cannot query devices of remote users (server name %s)", domain)
|
||||||
}
|
}
|
||||||
devs, err := a.DB.GetDevicesByLocalpart(ctx, local)
|
devs, err := a.DB.GetDevicesByLocalpart(ctx, local, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,8 +61,8 @@ type AccountData interface {
|
||||||
|
|
||||||
type Device interface {
|
type Device interface {
|
||||||
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
|
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
|
||||||
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
|
GetDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error)
|
||||||
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
|
GetDevicesByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Device, error)
|
||||||
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
|
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
|
||||||
// CreateDevice makes a new device associated with the given user ID localpart.
|
// CreateDevice makes a new device associated with the given user ID localpart.
|
||||||
// If there is already a device with the same device ID for this user, that access token will be revoked
|
// If there is already a device with the same device ID for this user, that access token will be revoked
|
||||||
|
@ -71,11 +71,11 @@ type Device interface {
|
||||||
// If no device ID is given one is generated.
|
// If no device ID is given one is generated.
|
||||||
// Returns the device on success.
|
// Returns the device on success.
|
||||||
CreateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
|
CreateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
|
||||||
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
|
UpdateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error
|
||||||
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error
|
UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error
|
||||||
RemoveDevices(ctx context.Context, localpart string, devices []string) error
|
RemoveDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error
|
||||||
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
|
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
|
||||||
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
|
RemoveAllDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) (devices []api.Device, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type KeyBackup interface {
|
type KeyBackup interface {
|
||||||
|
|
|
@ -75,7 +75,7 @@ const insertDeviceSQL = "" +
|
||||||
" RETURNING session_id"
|
" RETURNING session_id"
|
||||||
|
|
||||||
const selectDeviceByTokenSQL = "" +
|
const selectDeviceByTokenSQL = "" +
|
||||||
"SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
|
"SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1"
|
||||||
|
|
||||||
const selectDeviceByIDSQL = "" +
|
const selectDeviceByIDSQL = "" +
|
||||||
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
|
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
|
||||||
|
@ -90,16 +90,16 @@ const deleteDeviceSQL = "" +
|
||||||
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3"
|
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3"
|
||||||
|
|
||||||
const deleteDevicesByLocalpartSQL = "" +
|
const deleteDevicesByLocalpartSQL = "" +
|
||||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
|
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3"
|
||||||
|
|
||||||
const deleteDevicesSQL = "" +
|
const deleteDevicesSQL = "" +
|
||||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id = ANY($2)"
|
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = ANY($3)"
|
||||||
|
|
||||||
const selectDevicesByIDSQL = "" +
|
const selectDevicesByIDSQL = "" +
|
||||||
"SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
|
"SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
|
||||||
|
|
||||||
const updateDeviceLastSeen = "" +
|
const updateDeviceLastSeen = "" +
|
||||||
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
|
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
|
||||||
|
|
||||||
type devicesStatements struct {
|
type devicesStatements struct {
|
||||||
insertDeviceStmt *sql.Stmt
|
insertDeviceStmt *sql.Stmt
|
||||||
|
@ -184,7 +184,9 @@ func (s *devicesStatements) DeleteDevice(
|
||||||
// deleteDevices removes a single or multiple devices by ids and user localpart.
|
// deleteDevices removes a single or multiple devices by ids and user localpart.
|
||||||
// Returns an error if the execution failed.
|
// Returns an error if the execution failed.
|
||||||
func (s *devicesStatements) DeleteDevices(
|
func (s *devicesStatements) DeleteDevices(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
devices []string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt)
|
||||||
_, err := stmt.ExecContext(ctx, localpart, pq.Array(devices))
|
_, err := stmt.ExecContext(ctx, localpart, pq.Array(devices))
|
||||||
|
@ -194,18 +196,22 @@ func (s *devicesStatements) DeleteDevices(
|
||||||
// deleteDevicesByLocalpart removes all devices for the
|
// deleteDevicesByLocalpart removes all devices for the
|
||||||
// given user localpart.
|
// given user localpart.
|
||||||
func (s *devicesStatements) DeleteDevicesByLocalpart(
|
func (s *devicesStatements) DeleteDevicesByLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
exceptDeviceID string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
||||||
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
|
_, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) UpdateDeviceName(
|
func (s *devicesStatements) UpdateDeviceName(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
deviceID string, displayName *string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
||||||
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
|
_, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -214,10 +220,11 @@ func (s *devicesStatements) SelectDeviceByToken(
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
var dev api.Device
|
var dev api.Device
|
||||||
var localpart string
|
var localpart string
|
||||||
|
var serverName gomatrixserverlib.ServerName
|
||||||
stmt := s.selectDeviceByTokenStmt
|
stmt := s.selectDeviceByTokenStmt
|
||||||
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
|
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||||
dev.AccessToken = accessToken
|
dev.AccessToken = accessToken
|
||||||
}
|
}
|
||||||
return &dev, err
|
return &dev, err
|
||||||
|
@ -226,16 +233,18 @@ func (s *devicesStatements) SelectDeviceByToken(
|
||||||
// selectDeviceByID retrieves a device from the database with the given user
|
// selectDeviceByID retrieves a device from the database with the given user
|
||||||
// localpart and deviceID
|
// localpart and deviceID
|
||||||
func (s *devicesStatements) SelectDeviceByID(
|
func (s *devicesStatements) SelectDeviceByID(
|
||||||
ctx context.Context, localpart, deviceID string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
deviceID string,
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
var dev api.Device
|
var dev api.Device
|
||||||
var displayName, ip sql.NullString
|
var displayName, ip sql.NullString
|
||||||
var lastseenTS sql.NullInt64
|
var lastseenTS sql.NullInt64
|
||||||
stmt := s.selectDeviceByIDStmt
|
stmt := s.selectDeviceByIDStmt
|
||||||
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip)
|
err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
dev.ID = deviceID
|
dev.ID = deviceID
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||||
if displayName.Valid {
|
if displayName.Valid {
|
||||||
dev.DisplayName = displayName.String
|
dev.DisplayName = displayName.String
|
||||||
}
|
}
|
||||||
|
@ -258,10 +267,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
||||||
var devices []api.Device
|
var devices []api.Device
|
||||||
var dev api.Device
|
var dev api.Device
|
||||||
var localpart string
|
var localpart string
|
||||||
|
var serverName gomatrixserverlib.ServerName
|
||||||
var lastseents sql.NullInt64
|
var lastseents sql.NullInt64
|
||||||
var displayName sql.NullString
|
var displayName sql.NullString
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
|
if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if displayName.Valid {
|
if displayName.Valid {
|
||||||
|
@ -270,17 +280,19 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
||||||
if lastseents.Valid {
|
if lastseents.Valid {
|
||||||
dev.LastSeenTS = lastseents.Int64
|
dev.LastSeenTS = lastseents.Int64
|
||||||
}
|
}
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||||
devices = append(devices, dev)
|
devices = append(devices, dev)
|
||||||
}
|
}
|
||||||
return devices, rows.Err()
|
return devices, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) SelectDevicesByLocalpart(
|
func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
exceptDeviceID string,
|
||||||
) ([]api.Device, error) {
|
) ([]api.Device, error) {
|
||||||
devices := []api.Device{}
|
devices := []api.Device{}
|
||||||
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
|
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return devices, err
|
return devices, err
|
||||||
|
@ -311,16 +323,16 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||||
dev.UserAgent = useragent.String
|
dev.UserAgent = useragent.String
|
||||||
}
|
}
|
||||||
|
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||||
devices = append(devices, dev)
|
devices = append(devices, dev)
|
||||||
}
|
}
|
||||||
|
|
||||||
return devices, rows.Err()
|
return devices, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error {
|
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
|
||||||
lastSeenTs := time.Now().UnixNano() / 1000000
|
lastSeenTs := time.Now().UnixNano() / 1000000
|
||||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
|
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
|
||||||
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID)
|
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -546,16 +546,19 @@ func (d *Database) GetDeviceByAccessToken(
|
||||||
// GetDeviceByID returns the device matching the given ID.
|
// GetDeviceByID returns the device matching the given ID.
|
||||||
// Returns sql.ErrNoRows if no matching device was found.
|
// Returns sql.ErrNoRows if no matching device was found.
|
||||||
func (d *Database) GetDeviceByID(
|
func (d *Database) GetDeviceByID(
|
||||||
ctx context.Context, localpart, deviceID string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
deviceID string,
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
return d.Devices.SelectDeviceByID(ctx, localpart, deviceID)
|
return d.Devices.SelectDeviceByID(ctx, localpart, serverName, deviceID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDevicesByLocalpart returns the devices matching the given localpart.
|
// GetDevicesByLocalpart returns the devices matching the given localpart.
|
||||||
func (d *Database) GetDevicesByLocalpart(
|
func (d *Database) GetDevicesByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) ([]api.Device, error) {
|
) ([]api.Device, error) {
|
||||||
return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, "")
|
return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, serverName, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||||
|
@ -576,7 +579,7 @@ func (d *Database) CreateDevice(
|
||||||
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
var err error
|
var err error
|
||||||
// Revoke existing tokens for this device
|
// Revoke existing tokens for this device
|
||||||
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -621,10 +624,12 @@ func generateDeviceID() (string, error) {
|
||||||
// UpdateDevice updates the given device with the display name.
|
// UpdateDevice updates the given device with the display name.
|
||||||
// Returns SQL error if there are problems and nil on success.
|
// Returns SQL error if there are problems and nil on success.
|
||||||
func (d *Database) UpdateDevice(
|
func (d *Database) UpdateDevice(
|
||||||
ctx context.Context, localpart, deviceID string, displayName *string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
deviceID string, displayName *string,
|
||||||
) error {
|
) error {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.Devices.UpdateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
return d.Devices.UpdateDeviceName(ctx, txn, localpart, serverName, deviceID, displayName)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -633,10 +638,12 @@ func (d *Database) UpdateDevice(
|
||||||
// If the devices don't exist, it will not return an error
|
// If the devices don't exist, it will not return an error
|
||||||
// If something went wrong during the deletion, it will return the SQL error.
|
// If something went wrong during the deletion, it will return the SQL error.
|
||||||
func (d *Database) RemoveDevices(
|
func (d *Database) RemoveDevices(
|
||||||
ctx context.Context, localpart string, devices []string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
devices []string,
|
||||||
) error {
|
) error {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
if err := d.Devices.DeleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
|
if err := d.Devices.DeleteDevices(ctx, txn, localpart, serverName, devices); err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -647,14 +654,16 @@ func (d *Database) RemoveDevices(
|
||||||
// database matching the given user ID localpart.
|
// database matching the given user ID localpart.
|
||||||
// If something went wrong during the deletion, it will return the SQL error.
|
// If something went wrong during the deletion, it will return the SQL error.
|
||||||
func (d *Database) RemoveAllDevices(
|
func (d *Database) RemoveAllDevices(
|
||||||
ctx context.Context, localpart, exceptDeviceID string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
exceptDeviceID string,
|
||||||
) (devices []api.Device, err error) {
|
) (devices []api.Device, err error) {
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
|
devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
|
if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID); err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -663,9 +672,9 @@ func (d *Database) RemoveAllDevices(
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateDeviceLastSeen updates a last seen timestamp and the ip address.
|
// UpdateDeviceLastSeen updates a last seen timestamp and the ip address.
|
||||||
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error {
|
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr, userAgent)
|
return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, serverName, deviceID, ipAddr, userAgent)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -59,31 +59,31 @@ const selectDevicesCountSQL = "" +
|
||||||
"SELECT COUNT(access_token) FROM userapi_devices"
|
"SELECT COUNT(access_token) FROM userapi_devices"
|
||||||
|
|
||||||
const selectDeviceByTokenSQL = "" +
|
const selectDeviceByTokenSQL = "" +
|
||||||
"SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
|
"SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1"
|
||||||
|
|
||||||
const selectDeviceByIDSQL = "" +
|
const selectDeviceByIDSQL = "" +
|
||||||
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
|
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3"
|
||||||
|
|
||||||
const selectDevicesByLocalpartSQL = "" +
|
const selectDevicesByLocalpartSQL = "" +
|
||||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
|
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
|
||||||
|
|
||||||
const updateDeviceNameSQL = "" +
|
const updateDeviceNameSQL = "" +
|
||||||
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4"
|
||||||
|
|
||||||
const deleteDeviceSQL = "" +
|
const deleteDeviceSQL = "" +
|
||||||
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
|
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3"
|
||||||
|
|
||||||
const deleteDevicesByLocalpartSQL = "" +
|
const deleteDevicesByLocalpartSQL = "" +
|
||||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
|
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3"
|
||||||
|
|
||||||
const deleteDevicesSQL = "" +
|
const deleteDevicesSQL = "" +
|
||||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id IN ($2)"
|
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id IN ($3)"
|
||||||
|
|
||||||
const selectDevicesByIDSQL = "" +
|
const selectDevicesByIDSQL = "" +
|
||||||
"SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
|
"SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
|
||||||
|
|
||||||
const updateDeviceLastSeen = "" +
|
const updateDeviceLastSeen = "" +
|
||||||
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
|
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
|
||||||
|
|
||||||
type devicesStatements struct {
|
type devicesStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
@ -153,7 +153,7 @@ func (s *devicesStatements) InsertDevice(
|
||||||
}
|
}
|
||||||
return &api.Device{
|
return &api.Device{
|
||||||
ID: id,
|
ID: id,
|
||||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
UserID: userutil.MakeUserID(localpart, serverName),
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
SessionID: sessionID,
|
SessionID: sessionID,
|
||||||
LastSeenTS: createdTimeMS,
|
LastSeenTS: createdTimeMS,
|
||||||
|
@ -163,24 +163,28 @@ func (s *devicesStatements) InsertDevice(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) DeleteDevice(
|
func (s *devicesStatements) DeleteDevice(
|
||||||
ctx context.Context, txn *sql.Tx, id, localpart string,
|
ctx context.Context, txn *sql.Tx, id string,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
||||||
_, err := stmt.ExecContext(ctx, id, localpart)
|
_, err := stmt.ExecContext(ctx, id, localpart, serverName)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) DeleteDevices(
|
func (s *devicesStatements) DeleteDevices(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
devices []string,
|
||||||
) error {
|
) error {
|
||||||
orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
|
orig := strings.Replace(deleteDevicesSQL, "($3)", sqlutil.QueryVariadicOffset(len(devices), 2), 1)
|
||||||
prep, err := s.db.Prepare(orig)
|
prep, err := s.db.Prepare(orig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
stmt := sqlutil.TxStmt(txn, prep)
|
stmt := sqlutil.TxStmt(txn, prep)
|
||||||
params := make([]interface{}, len(devices)+1)
|
params := make([]interface{}, len(devices)+2)
|
||||||
params[0] = localpart
|
params[0] = localpart
|
||||||
|
params[1] = serverName
|
||||||
for i, v := range devices {
|
for i, v := range devices {
|
||||||
params[i+1] = v
|
params[i+1] = v
|
||||||
}
|
}
|
||||||
|
@ -189,18 +193,22 @@ func (s *devicesStatements) DeleteDevices(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) DeleteDevicesByLocalpart(
|
func (s *devicesStatements) DeleteDevicesByLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
exceptDeviceID string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
||||||
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
|
_, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) UpdateDeviceName(
|
func (s *devicesStatements) UpdateDeviceName(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
deviceID string, displayName *string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
||||||
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
|
_, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -209,10 +217,11 @@ func (s *devicesStatements) SelectDeviceByToken(
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
var dev api.Device
|
var dev api.Device
|
||||||
var localpart string
|
var localpart string
|
||||||
|
var serverName gomatrixserverlib.ServerName
|
||||||
stmt := s.selectDeviceByTokenStmt
|
stmt := s.selectDeviceByTokenStmt
|
||||||
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
|
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||||
dev.AccessToken = accessToken
|
dev.AccessToken = accessToken
|
||||||
}
|
}
|
||||||
return &dev, err
|
return &dev, err
|
||||||
|
@ -221,16 +230,18 @@ func (s *devicesStatements) SelectDeviceByToken(
|
||||||
// selectDeviceByID retrieves a device from the database with the given user
|
// selectDeviceByID retrieves a device from the database with the given user
|
||||||
// localpart and deviceID
|
// localpart and deviceID
|
||||||
func (s *devicesStatements) SelectDeviceByID(
|
func (s *devicesStatements) SelectDeviceByID(
|
||||||
ctx context.Context, localpart, deviceID string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
deviceID string,
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
var dev api.Device
|
var dev api.Device
|
||||||
var displayName, ip sql.NullString
|
var displayName, ip sql.NullString
|
||||||
stmt := s.selectDeviceByIDStmt
|
stmt := s.selectDeviceByIDStmt
|
||||||
var lastseenTS sql.NullInt64
|
var lastseenTS sql.NullInt64
|
||||||
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip)
|
err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
dev.ID = deviceID
|
dev.ID = deviceID
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||||
if displayName.Valid {
|
if displayName.Valid {
|
||||||
dev.DisplayName = displayName.String
|
dev.DisplayName = displayName.String
|
||||||
}
|
}
|
||||||
|
@ -245,10 +256,12 @@ func (s *devicesStatements) SelectDeviceByID(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) SelectDevicesByLocalpart(
|
func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
exceptDeviceID string,
|
||||||
) ([]api.Device, error) {
|
) ([]api.Device, error) {
|
||||||
devices := []api.Device{}
|
devices := []api.Device{}
|
||||||
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
|
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return devices, err
|
return devices, err
|
||||||
|
@ -278,7 +291,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||||
dev.UserAgent = useragent.String
|
dev.UserAgent = useragent.String
|
||||||
}
|
}
|
||||||
|
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||||
devices = append(devices, dev)
|
devices = append(devices, dev)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -300,10 +313,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
||||||
var devices []api.Device
|
var devices []api.Device
|
||||||
var dev api.Device
|
var dev api.Device
|
||||||
var localpart string
|
var localpart string
|
||||||
|
var serverName gomatrixserverlib.ServerName
|
||||||
var displayName sql.NullString
|
var displayName sql.NullString
|
||||||
var lastseents sql.NullInt64
|
var lastseents sql.NullInt64
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
|
if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if displayName.Valid {
|
if displayName.Valid {
|
||||||
|
@ -312,15 +326,15 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
||||||
if lastseents.Valid {
|
if lastseents.Valid {
|
||||||
dev.LastSeenTS = lastseents.Int64
|
dev.LastSeenTS = lastseents.Int64
|
||||||
}
|
}
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||||
devices = append(devices, dev)
|
devices = append(devices, dev)
|
||||||
}
|
}
|
||||||
return devices, rows.Err()
|
return devices, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error {
|
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
|
||||||
lastSeenTs := time.Now().UnixNano() / 1000000
|
lastSeenTs := time.Now().UnixNano() / 1000000
|
||||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
|
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
|
||||||
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID)
|
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -164,7 +164,7 @@ func Test_Devices(t *testing.T) {
|
||||||
deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "")
|
deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "")
|
||||||
assert.NoError(t, err, "unable to create deviceWithoutID")
|
assert.NoError(t, err, "unable to create deviceWithoutID")
|
||||||
|
|
||||||
gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID)
|
gotDevice, err := db.GetDeviceByID(ctx, localpart, domain, deviceID)
|
||||||
assert.NoError(t, err, "unable to get device by id")
|
assert.NoError(t, err, "unable to get device by id")
|
||||||
assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields
|
assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields
|
||||||
|
|
||||||
|
@ -176,12 +176,12 @@ func Test_Devices(t *testing.T) {
|
||||||
accessToken = util.RandomString(16)
|
accessToken = util.RandomString(16)
|
||||||
deviceWithoutID, err := db.CreateDevice(ctx, localpart, domain, nil, accessToken, nil, "", "")
|
deviceWithoutID, err := db.CreateDevice(ctx, localpart, domain, nil, accessToken, nil, "", "")
|
||||||
assert.NoError(t, err, "unable to create deviceWithoutID")
|
assert.NoError(t, err, "unable to create deviceWithoutID")
|
||||||
gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID)
|
gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, domain, deviceWithoutID.ID)
|
||||||
assert.NoError(t, err, "unable to get device by id")
|
assert.NoError(t, err, "unable to get device by id")
|
||||||
assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields
|
assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields
|
||||||
|
|
||||||
// Get devices
|
// Get devices
|
||||||
devices, err := db.GetDevicesByLocalpart(ctx, localpart)
|
devices, err := db.GetDevicesByLocalpart(ctx, localpart, domain)
|
||||||
assert.NoError(t, err, "unable to get devices by localpart")
|
assert.NoError(t, err, "unable to get devices by localpart")
|
||||||
assert.Equal(t, 2, len(devices))
|
assert.Equal(t, 2, len(devices))
|
||||||
deviceIDs := make([]string, 0, len(devices))
|
deviceIDs := make([]string, 0, len(devices))
|
||||||
|
@ -195,15 +195,15 @@ func Test_Devices(t *testing.T) {
|
||||||
|
|
||||||
// Update device
|
// Update device
|
||||||
newName := "new display name"
|
newName := "new display name"
|
||||||
err = db.UpdateDevice(ctx, localpart, deviceWithID.ID, &newName)
|
err = db.UpdateDevice(ctx, localpart, domain, deviceWithID.ID, &newName)
|
||||||
assert.NoError(t, err, "unable to update device displayname")
|
assert.NoError(t, err, "unable to update device displayname")
|
||||||
updatedAfterTimestamp := time.Now().Unix()
|
updatedAfterTimestamp := time.Now().Unix()
|
||||||
err = db.UpdateDeviceLastSeen(ctx, localpart, deviceWithID.ID, "127.0.0.1", "Element Web")
|
err = db.UpdateDeviceLastSeen(ctx, localpart, domain, deviceWithID.ID, "127.0.0.1", "Element Web")
|
||||||
assert.NoError(t, err, "unable to update device last seen")
|
assert.NoError(t, err, "unable to update device last seen")
|
||||||
|
|
||||||
deviceWithID.DisplayName = newName
|
deviceWithID.DisplayName = newName
|
||||||
deviceWithID.LastSeenIP = "127.0.0.1"
|
deviceWithID.LastSeenIP = "127.0.0.1"
|
||||||
gotDevice, err = db.GetDeviceByID(ctx, localpart, deviceWithID.ID)
|
gotDevice, err = db.GetDeviceByID(ctx, localpart, domain, deviceWithID.ID)
|
||||||
assert.NoError(t, err, "unable to get device by id")
|
assert.NoError(t, err, "unable to get device by id")
|
||||||
assert.Equal(t, 2, len(devices))
|
assert.Equal(t, 2, len(devices))
|
||||||
assert.Equal(t, deviceWithID.DisplayName, gotDevice.DisplayName)
|
assert.Equal(t, deviceWithID.DisplayName, gotDevice.DisplayName)
|
||||||
|
@ -216,17 +216,17 @@ func Test_Devices(t *testing.T) {
|
||||||
_, err = db.CreateDevice(ctx, localpart, domain, &newDeviceID, accessToken, nil, "", "")
|
_, err = db.CreateDevice(ctx, localpart, domain, &newDeviceID, accessToken, nil, "", "")
|
||||||
assert.NoError(t, err, "unable to create new device")
|
assert.NoError(t, err, "unable to create new device")
|
||||||
|
|
||||||
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
|
devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain)
|
||||||
assert.NoError(t, err, "unable to get device by id")
|
assert.NoError(t, err, "unable to get device by id")
|
||||||
assert.Equal(t, 3, len(devices))
|
assert.Equal(t, 3, len(devices))
|
||||||
|
|
||||||
err = db.RemoveDevices(ctx, localpart, deviceIDs)
|
err = db.RemoveDevices(ctx, localpart, domain, deviceIDs)
|
||||||
assert.NoError(t, err, "unable to remove devices")
|
assert.NoError(t, err, "unable to remove devices")
|
||||||
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
|
devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain)
|
||||||
assert.NoError(t, err, "unable to get device by id")
|
assert.NoError(t, err, "unable to get device by id")
|
||||||
assert.Equal(t, 1, len(devices))
|
assert.Equal(t, 1, len(devices))
|
||||||
|
|
||||||
deleted, err := db.RemoveAllDevices(ctx, localpart, "")
|
deleted, err := db.RemoveAllDevices(ctx, localpart, domain, "")
|
||||||
assert.NoError(t, err, "unable to remove all devices")
|
assert.NoError(t, err, "unable to remove all devices")
|
||||||
assert.Equal(t, 1, len(deleted))
|
assert.Equal(t, 1, len(deleted))
|
||||||
assert.Equal(t, newDeviceID, deleted[0].ID)
|
assert.Equal(t, newDeviceID, deleted[0].ID)
|
||||||
|
|
|
@ -44,15 +44,15 @@ type AccountsTable interface {
|
||||||
|
|
||||||
type DevicesTable interface {
|
type DevicesTable interface {
|
||||||
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
|
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
|
||||||
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error
|
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||||
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error
|
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error
|
||||||
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error
|
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) error
|
||||||
UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string) error
|
UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error
|
||||||
SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error)
|
SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error)
|
||||||
SelectDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
|
SelectDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error)
|
||||||
SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) ([]api.Device, error)
|
SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) ([]api.Device, error)
|
||||||
SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
|
SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
|
||||||
UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error
|
UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type KeyBackupTable interface {
|
type KeyBackupTable interface {
|
||||||
|
|
Loading…
Reference in a new issue