mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-11-22 14:21:55 -06:00
Fix SQLite session_id
(#2977)
This fixes an issue with device_id/session_ids. If a `device_id` is reused, we would reuse the same `session_id`, since we delete one device and insert a new one directly, resulting in the query to get a new `session_id` to return the previous session_id. (`SELECT count(access_token)`)
This commit is contained in:
parent
11d9b9db0e
commit
f0805071d5
|
@ -81,7 +81,7 @@ const selectDeviceByIDSQL = "" +
|
||||||
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3"
|
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3"
|
||||||
|
|
||||||
const selectDevicesByLocalpartSQL = "" +
|
const selectDevicesByLocalpartSQL = "" +
|
||||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
|
"SELECT device_id, display_name, last_seen_ts, ip, user_agent, session_id FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
|
||||||
|
|
||||||
const updateDeviceNameSQL = "" +
|
const updateDeviceNameSQL = "" +
|
||||||
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4"
|
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4"
|
||||||
|
@ -96,7 +96,7 @@ const deleteDevicesSQL = "" +
|
||||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = ANY($3)"
|
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = ANY($3)"
|
||||||
|
|
||||||
const selectDevicesByIDSQL = "" +
|
const selectDevicesByIDSQL = "" +
|
||||||
"SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
|
"SELECT device_id, localpart, server_name, display_name, last_seen_ts, session_id FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
|
||||||
|
|
||||||
const updateDeviceLastSeen = "" +
|
const updateDeviceLastSeen = "" +
|
||||||
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
|
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
|
||||||
|
@ -171,6 +171,14 @@ func (s *devicesStatements) InsertDevice(
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
accessToken string, displayName *string, ipAddr, userAgent string,
|
||||||
|
sessionID int64,
|
||||||
|
) (*api.Device, error) {
|
||||||
|
return s.InsertDevice(ctx, txn, id, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
|
||||||
|
}
|
||||||
|
|
||||||
// deleteDevice removes a single device by id and user localpart.
|
// deleteDevice removes a single device by id and user localpart.
|
||||||
func (s *devicesStatements) DeleteDevice(
|
func (s *devicesStatements) DeleteDevice(
|
||||||
ctx context.Context, txn *sql.Tx, id string,
|
ctx context.Context, txn *sql.Tx, id string,
|
||||||
|
@ -271,7 +279,7 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
||||||
var lastseents sql.NullInt64
|
var lastseents sql.NullInt64
|
||||||
var displayName sql.NullString
|
var displayName sql.NullString
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil {
|
if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents, &dev.SessionID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if displayName.Valid {
|
if displayName.Valid {
|
||||||
|
@ -303,7 +311,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||||
var lastseents sql.NullInt64
|
var lastseents sql.NullInt64
|
||||||
var id, displayname, ip, useragent sql.NullString
|
var id, displayname, ip, useragent sql.NullString
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent)
|
err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent, &dev.SessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return devices, err
|
return devices, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -588,6 +588,31 @@ func (d *Database) CreateDevice(
|
||||||
deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string,
|
deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string,
|
||||||
) (dev *api.Device, returnErr error) {
|
) (dev *api.Device, returnErr error) {
|
||||||
if deviceID != nil {
|
if deviceID != nil {
|
||||||
|
_, ok := d.Writer.(*sqlutil.ExclusiveWriter)
|
||||||
|
if ok { // we're using most likely using SQLite, so do things a little different
|
||||||
|
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
devices, err := d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, serverName, "")
|
||||||
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// No devices yet, only create a new one
|
||||||
|
if len(devices) == 0 {
|
||||||
|
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sessionID := devices[0].SessionID + 1
|
||||||
|
|
||||||
|
// Revoke existing tokens for this device
|
||||||
|
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Create a new device with the session ID incremented
|
||||||
|
dev, err = d.Devices.InsertDeviceWithSessionID(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent, sessionID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
} else {
|
||||||
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
var err error
|
var err error
|
||||||
// Revoke existing tokens for this device
|
// Revoke existing tokens for this device
|
||||||
|
@ -598,6 +623,7 @@ func (d *Database) CreateDevice(
|
||||||
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
|
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// We generate device IDs in a loop in case its already taken.
|
// We generate device IDs in a loop in case its already taken.
|
||||||
// We cap this at going round 5 times to ensure we don't spin forever
|
// We cap this at going round 5 times to ensure we don't spin forever
|
||||||
|
@ -618,7 +644,7 @@ func (d *Database) CreateDevice(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return dev, returnErr
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateDeviceID creates a new device id. Returns an error if failed to generate
|
// generateDeviceID creates a new device id. Returns an error if failed to generate
|
||||||
|
|
|
@ -65,7 +65,7 @@ const selectDeviceByIDSQL = "" +
|
||||||
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3"
|
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3"
|
||||||
|
|
||||||
const selectDevicesByLocalpartSQL = "" +
|
const selectDevicesByLocalpartSQL = "" +
|
||||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
|
"SELECT device_id, display_name, last_seen_ts, ip, user_agent, session_id FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
|
||||||
|
|
||||||
const updateDeviceNameSQL = "" +
|
const updateDeviceNameSQL = "" +
|
||||||
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4"
|
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4"
|
||||||
|
@ -80,7 +80,7 @@ const deleteDevicesSQL = "" +
|
||||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id IN ($3)"
|
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id IN ($3)"
|
||||||
|
|
||||||
const selectDevicesByIDSQL = "" +
|
const selectDevicesByIDSQL = "" +
|
||||||
"SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
|
"SELECT device_id, localpart, server_name, display_name, last_seen_ts, session_id FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
|
||||||
|
|
||||||
const updateDeviceLastSeen = "" +
|
const updateDeviceLastSeen = "" +
|
||||||
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
|
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
|
||||||
|
@ -162,6 +162,27 @@ func (s *devicesStatements) InsertDevice(
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
accessToken string, displayName *string, ipAddr, userAgent string,
|
||||||
|
sessionID int64,
|
||||||
|
) (*api.Device, error) {
|
||||||
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
|
insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
|
||||||
|
if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &api.Device{
|
||||||
|
ID: id,
|
||||||
|
UserID: userutil.MakeUserID(localpart, serverName),
|
||||||
|
AccessToken: accessToken,
|
||||||
|
SessionID: sessionID,
|
||||||
|
LastSeenTS: createdTimeMS,
|
||||||
|
LastSeenIP: ipAddr,
|
||||||
|
UserAgent: userAgent,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) DeleteDevice(
|
func (s *devicesStatements) DeleteDevice(
|
||||||
ctx context.Context, txn *sql.Tx, id string,
|
ctx context.Context, txn *sql.Tx, id string,
|
||||||
localpart string, serverName gomatrixserverlib.ServerName,
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
@ -271,7 +292,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||||
var lastseents sql.NullInt64
|
var lastseents sql.NullInt64
|
||||||
var id, displayname, ip, useragent sql.NullString
|
var id, displayname, ip, useragent sql.NullString
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent)
|
err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent, &dev.SessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return devices, err
|
return devices, err
|
||||||
}
|
}
|
||||||
|
@ -317,7 +338,7 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
||||||
var displayName sql.NullString
|
var displayName sql.NullString
|
||||||
var lastseents sql.NullInt64
|
var lastseents sql.NullInt64
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil {
|
if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents, &dev.SessionID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if displayName.Valid {
|
if displayName.Valid {
|
||||||
|
|
|
@ -44,6 +44,7 @@ type AccountsTable interface {
|
||||||
|
|
||||||
type DevicesTable interface {
|
type DevicesTable interface {
|
||||||
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
|
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
|
||||||
|
InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string, sessionID int64) (*api.Device, error)
|
||||||
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName) error
|
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||||
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error
|
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error
|
||||||
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) error
|
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) error
|
||||||
|
|
Loading…
Reference in a new issue