diff --git a/clientapi/auth/storage/accounts/sqlite3/filter_table.go b/clientapi/auth/storage/accounts/sqlite3/filter_table.go index a6f4f7bb7..691ead775 100644 --- a/clientapi/auth/storage/accounts/sqlite3/filter_table.go +++ b/clientapi/auth/storage/accounts/sqlite3/filter_table.go @@ -28,13 +28,11 @@ CREATE TABLE IF NOT EXISTS account_filter ( -- The filter filter TEXT NOT NULL, -- The ID - id SERIAL, + id INTEGER PRIMARY KEY AUTOINCREMENT, -- The localpart of the Matrix user ID associated to this filter localpart TEXT NOT NULL, - PRIMARY KEY(id, localpart), - - UNIQUE (id) + UNIQUE (id, localpart) ); CREATE INDEX IF NOT EXISTS account_filter_localpart ON account_filter(localpart); @@ -49,10 +47,14 @@ const selectFilterIDByContentSQL = "" + const insertFilterSQL = "" + "INSERT INTO account_filter (filter, localpart) VALUES ($1, $2)" +const selectLastInsertedFilterIDSQL = "" + + "SELECT id FROM account_filter WHERE rowid = last_insert_rowid()" + type filterStatements struct { - selectFilterStmt *sql.Stmt - selectFilterIDByContentStmt *sql.Stmt - insertFilterStmt *sql.Stmt + selectFilterStmt *sql.Stmt + selectLastInsertedFilterIDStmt *sql.Stmt + selectFilterIDByContentStmt *sql.Stmt + insertFilterStmt *sql.Stmt } func (s *filterStatements) prepare(db *sql.DB) (err error) { @@ -63,6 +65,9 @@ func (s *filterStatements) prepare(db *sql.DB) (err error) { if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { return } + if s.selectLastInsertedFilterIDStmt, err = db.Prepare(selectLastInsertedFilterIDSQL); err != nil { + return + } if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil { return } @@ -123,7 +128,12 @@ func (s *filterStatements) insertFilter( } // Otherwise insert the filter and return the new ID - err = s.insertFilterStmt.QueryRowContext(ctx, filterJSON, localpart). - Scan(&filterID) + if _, err = s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart); err != nil { + return "", err + } + row := s.selectLastInsertedFilterIDStmt.QueryRowContext(ctx) + if err := row.Scan(&filterID); err != nil { + return "", err + } return } diff --git a/clientapi/auth/storage/devices/sqlite3/devices_table.go b/clientapi/auth/storage/devices/sqlite3/devices_table.go index b12a0b869..d4349c99f 100644 --- a/clientapi/auth/storage/devices/sqlite3/devices_table.go +++ b/clientapi/auth/storage/devices/sqlite3/devices_table.go @@ -32,34 +32,23 @@ const devicesSchema = ` -- Stores data about devices. CREATE TABLE IF NOT EXISTS device_devices ( - -- The access token granted to this device. This has to be the primary key - -- so we can distinguish which device is making a given request. - access_token TEXT NOT NULL PRIMARY KEY, - -- The auto-allocated unique ID of the session identified by the access token. - -- This can be used as a secure substitution of the access token in situations - -- where data is associated with access tokens (e.g. transaction storage), - -- so we don't have to store users' access tokens everywhere. - session_id BIGINT, - -- The device identifier. This only needs to uniquely identify a device for a given user, not globally. - -- access_tokens will be clobbered based on the device ID for a user. - device_id TEXT NOT NULL, - -- The Matrix user ID localpart for this device. This is preferable to storing the full user_id - -- as it is smaller, makes it clearer that we only manage devices for our own users, and may make - -- migration to different domain names easier. - localpart TEXT NOT NULL, - -- When this devices was first recognised on the network, as a unix timestamp (ms resolution). - created_ts BIGINT NOT NULL, - -- The display name, human friendlier than device_id and updatable + access_token TEXT PRIMARY KEY, + session_id INTEGER, + device_id TEXT , + localpart TEXT , + created_ts BIGINT, display_name TEXT, - -- TODO: device keys, device display names, last used ts and IP address?, token restrictions (if 3rd-party OAuth app) UNIQUE (localpart, device_id) ); ` const insertDeviceSQL = "" + - "INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5);" + - "SELECT last_insert_rowid() AS session_id" + "INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + +const selectDevicesCountSQL = "" + + "SELECT COUNT(access_token) FROM device_devices" const selectDeviceByTokenSQL = "" + "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" @@ -81,6 +70,7 @@ const deleteDevicesByLocalpartSQL = "" + type devicesStatements struct { insertDeviceStmt *sql.Stmt + selectDevicesCountStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt selectDeviceByIDStmt *sql.Stmt selectDevicesByLocalpartStmt *sql.Stmt @@ -98,6 +88,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { return } + if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil { + return + } if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil { return } @@ -129,8 +122,13 @@ func (s *devicesStatements) insertDevice( ) (*authtypes.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 var sessionID int64 - stmt := common.TxStmt(txn, s.insertDeviceStmt) - if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName).Scan(&sessionID); err != nil { + countStmt := common.TxStmt(txn, s.selectDevicesCountStmt) + insertStmt := common.TxStmt(txn, s.insertDeviceStmt) + if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil { + return nil, err + } + sessionID++ + if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil { return nil, err } return &authtypes.Device{