Fix accounts DB, device DB
This commit is contained in:
parent
db192382c6
commit
99b32109ba
|
@ -28,13 +28,11 @@ CREATE TABLE IF NOT EXISTS account_filter (
|
||||||
-- The filter
|
-- The filter
|
||||||
filter TEXT NOT NULL,
|
filter TEXT NOT NULL,
|
||||||
-- The ID
|
-- The ID
|
||||||
id SERIAL,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
-- The localpart of the Matrix user ID associated to this filter
|
-- The localpart of the Matrix user ID associated to this filter
|
||||||
localpart TEXT NOT NULL,
|
localpart TEXT NOT NULL,
|
||||||
|
|
||||||
PRIMARY KEY(id, localpart),
|
UNIQUE (id, localpart)
|
||||||
|
|
||||||
UNIQUE (id)
|
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE INDEX IF NOT EXISTS account_filter_localpart ON account_filter(localpart);
|
CREATE INDEX IF NOT EXISTS account_filter_localpart ON account_filter(localpart);
|
||||||
|
@ -49,8 +47,12 @@ const selectFilterIDByContentSQL = "" +
|
||||||
const insertFilterSQL = "" +
|
const insertFilterSQL = "" +
|
||||||
"INSERT INTO account_filter (filter, localpart) VALUES ($1, $2)"
|
"INSERT INTO account_filter (filter, localpart) VALUES ($1, $2)"
|
||||||
|
|
||||||
|
const selectLastInsertedFilterIDSQL = "" +
|
||||||
|
"SELECT id FROM account_filter WHERE rowid = last_insert_rowid()"
|
||||||
|
|
||||||
type filterStatements struct {
|
type filterStatements struct {
|
||||||
selectFilterStmt *sql.Stmt
|
selectFilterStmt *sql.Stmt
|
||||||
|
selectLastInsertedFilterIDStmt *sql.Stmt
|
||||||
selectFilterIDByContentStmt *sql.Stmt
|
selectFilterIDByContentStmt *sql.Stmt
|
||||||
insertFilterStmt *sql.Stmt
|
insertFilterStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
@ -63,6 +65,9 @@ func (s *filterStatements) prepare(db *sql.DB) (err error) {
|
||||||
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
|
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.selectLastInsertedFilterIDStmt, err = db.Prepare(selectLastInsertedFilterIDSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil {
|
if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -123,7 +128,12 @@ func (s *filterStatements) insertFilter(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise insert the filter and return the new ID
|
// Otherwise insert the filter and return the new ID
|
||||||
err = s.insertFilterStmt.QueryRowContext(ctx, filterJSON, localpart).
|
if _, err = s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart); err != nil {
|
||||||
Scan(&filterID)
|
return "", err
|
||||||
|
}
|
||||||
|
row := s.selectLastInsertedFilterIDStmt.QueryRowContext(ctx)
|
||||||
|
if err := row.Scan(&filterID); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,34 +32,23 @@ const devicesSchema = `
|
||||||
|
|
||||||
-- Stores data about devices.
|
-- Stores data about devices.
|
||||||
CREATE TABLE IF NOT EXISTS device_devices (
|
CREATE TABLE IF NOT EXISTS device_devices (
|
||||||
-- The access token granted to this device. This has to be the primary key
|
access_token TEXT PRIMARY KEY,
|
||||||
-- so we can distinguish which device is making a given request.
|
session_id INTEGER,
|
||||||
access_token TEXT NOT NULL PRIMARY KEY,
|
device_id TEXT ,
|
||||||
-- The auto-allocated unique ID of the session identified by the access token.
|
localpart TEXT ,
|
||||||
-- This can be used as a secure substitution of the access token in situations
|
created_ts BIGINT,
|
||||||
-- where data is associated with access tokens (e.g. transaction storage),
|
|
||||||
-- so we don't have to store users' access tokens everywhere.
|
|
||||||
session_id BIGINT,
|
|
||||||
-- The device identifier. This only needs to uniquely identify a device for a given user, not globally.
|
|
||||||
-- access_tokens will be clobbered based on the device ID for a user.
|
|
||||||
device_id TEXT NOT NULL,
|
|
||||||
-- The Matrix user ID localpart for this device. This is preferable to storing the full user_id
|
|
||||||
-- as it is smaller, makes it clearer that we only manage devices for our own users, and may make
|
|
||||||
-- migration to different domain names easier.
|
|
||||||
localpart TEXT NOT NULL,
|
|
||||||
-- When this devices was first recognised on the network, as a unix timestamp (ms resolution).
|
|
||||||
created_ts BIGINT NOT NULL,
|
|
||||||
-- The display name, human friendlier than device_id and updatable
|
|
||||||
display_name TEXT,
|
display_name TEXT,
|
||||||
-- TODO: device keys, device display names, last used ts and IP address?, token restrictions (if 3rd-party OAuth app)
|
|
||||||
|
|
||||||
UNIQUE (localpart, device_id)
|
UNIQUE (localpart, device_id)
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertDeviceSQL = "" +
|
const insertDeviceSQL = "" +
|
||||||
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5);" +
|
"INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id)" +
|
||||||
"SELECT last_insert_rowid() AS session_id"
|
" VALUES ($1, $2, $3, $4, $5, $6)"
|
||||||
|
|
||||||
|
const selectDevicesCountSQL = "" +
|
||||||
|
"SELECT COUNT(access_token) FROM device_devices"
|
||||||
|
|
||||||
const selectDeviceByTokenSQL = "" +
|
const selectDeviceByTokenSQL = "" +
|
||||||
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
|
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
|
||||||
|
@ -81,6 +70,7 @@ const deleteDevicesByLocalpartSQL = "" +
|
||||||
|
|
||||||
type devicesStatements struct {
|
type devicesStatements struct {
|
||||||
insertDeviceStmt *sql.Stmt
|
insertDeviceStmt *sql.Stmt
|
||||||
|
selectDevicesCountStmt *sql.Stmt
|
||||||
selectDeviceByTokenStmt *sql.Stmt
|
selectDeviceByTokenStmt *sql.Stmt
|
||||||
selectDeviceByIDStmt *sql.Stmt
|
selectDeviceByIDStmt *sql.Stmt
|
||||||
selectDevicesByLocalpartStmt *sql.Stmt
|
selectDevicesByLocalpartStmt *sql.Stmt
|
||||||
|
@ -98,6 +88,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
|
||||||
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
|
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
|
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -129,8 +122,13 @@ func (s *devicesStatements) insertDevice(
|
||||||
) (*authtypes.Device, error) {
|
) (*authtypes.Device, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
var sessionID int64
|
var sessionID int64
|
||||||
stmt := common.TxStmt(txn, s.insertDeviceStmt)
|
countStmt := common.TxStmt(txn, s.selectDevicesCountStmt)
|
||||||
if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName).Scan(&sessionID); err != nil {
|
insertStmt := common.TxStmt(txn, s.insertDeviceStmt)
|
||||||
|
if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sessionID++
|
||||||
|
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &authtypes.Device{
|
return &authtypes.Device{
|
||||||
|
|
Loading…
Reference in a new issue