Fix accounts DB, device DB

This commit is contained in:
Neil Alexander 2020-01-27 17:36:43 +00:00
parent db192382c6
commit 99b32109ba
2 changed files with 40 additions and 32 deletions

View file

@ -28,13 +28,11 @@ CREATE TABLE IF NOT EXISTS account_filter (
-- The filter
filter TEXT NOT NULL,
-- The ID
id SERIAL,
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- The localpart of the Matrix user ID associated to this filter
localpart TEXT NOT NULL,
PRIMARY KEY(id, localpart),
UNIQUE (id)
UNIQUE (id, localpart)
);
CREATE INDEX IF NOT EXISTS account_filter_localpart ON account_filter(localpart);
@ -49,10 +47,14 @@ const selectFilterIDByContentSQL = "" +
const insertFilterSQL = "" +
"INSERT INTO account_filter (filter, localpart) VALUES ($1, $2)"
const selectLastInsertedFilterIDSQL = "" +
"SELECT id FROM account_filter WHERE rowid = last_insert_rowid()"
type filterStatements struct {
selectFilterStmt *sql.Stmt
selectFilterIDByContentStmt *sql.Stmt
insertFilterStmt *sql.Stmt
selectFilterStmt *sql.Stmt
selectLastInsertedFilterIDStmt *sql.Stmt
selectFilterIDByContentStmt *sql.Stmt
insertFilterStmt *sql.Stmt
}
func (s *filterStatements) prepare(db *sql.DB) (err error) {
@ -63,6 +65,9 @@ func (s *filterStatements) prepare(db *sql.DB) (err error) {
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
return
}
if s.selectLastInsertedFilterIDStmt, err = db.Prepare(selectLastInsertedFilterIDSQL); err != nil {
return
}
if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil {
return
}
@ -123,7 +128,12 @@ func (s *filterStatements) insertFilter(
}
// Otherwise insert the filter and return the new ID
err = s.insertFilterStmt.QueryRowContext(ctx, filterJSON, localpart).
Scan(&filterID)
if _, err = s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart); err != nil {
return "", err
}
row := s.selectLastInsertedFilterIDStmt.QueryRowContext(ctx)
if err := row.Scan(&filterID); err != nil {
return "", err
}
return
}

View file

@ -32,34 +32,23 @@ const devicesSchema = `
-- Stores data about devices.
CREATE TABLE IF NOT EXISTS device_devices (
-- The access token granted to this device. This has to be the primary key
-- so we can distinguish which device is making a given request.
access_token TEXT NOT NULL PRIMARY KEY,
-- The auto-allocated unique ID of the session identified by the access token.
-- This can be used as a secure substitution of the access token in situations
-- where data is associated with access tokens (e.g. transaction storage),
-- so we don't have to store users' access tokens everywhere.
session_id BIGINT,
-- The device identifier. This only needs to uniquely identify a device for a given user, not globally.
-- access_tokens will be clobbered based on the device ID for a user.
device_id TEXT NOT NULL,
-- The Matrix user ID localpart for this device. This is preferable to storing the full user_id
-- as it is smaller, makes it clearer that we only manage devices for our own users, and may make
-- migration to different domain names easier.
localpart TEXT NOT NULL,
-- When this devices was first recognised on the network, as a unix timestamp (ms resolution).
created_ts BIGINT NOT NULL,
-- The display name, human friendlier than device_id and updatable
access_token TEXT PRIMARY KEY,
session_id INTEGER,
device_id TEXT ,
localpart TEXT ,
created_ts BIGINT,
display_name TEXT,
-- TODO: device keys, device display names, last used ts and IP address?, token restrictions (if 3rd-party OAuth app)
UNIQUE (localpart, device_id)
);
`
const insertDeviceSQL = "" +
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5);" +
"SELECT last_insert_rowid() AS session_id"
"INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id)" +
" VALUES ($1, $2, $3, $4, $5, $6)"
const selectDevicesCountSQL = "" +
"SELECT COUNT(access_token) FROM device_devices"
const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
@ -81,6 +70,7 @@ const deleteDevicesByLocalpartSQL = "" +
type devicesStatements struct {
insertDeviceStmt *sql.Stmt
selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt
selectDeviceByIDStmt *sql.Stmt
selectDevicesByLocalpartStmt *sql.Stmt
@ -98,6 +88,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
return
}
if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil {
return
}
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
return
}
@ -129,8 +122,13 @@ func (s *devicesStatements) insertDevice(
) (*authtypes.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
stmt := common.TxStmt(txn, s.insertDeviceStmt)
if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName).Scan(&sessionID); err != nil {
countStmt := common.TxStmt(txn, s.selectDevicesCountStmt)
insertStmt := common.TxStmt(txn, s.insertDeviceStmt)
if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
return nil, err
}
sessionID++
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
return nil, err
}
return &authtypes.Device{