mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-01-18 18:04:27 -06:00
Fix accounts DB, device DB
This commit is contained in:
parent
db192382c6
commit
99b32109ba
|
@ -28,13 +28,11 @@ CREATE TABLE IF NOT EXISTS account_filter (
|
|||
-- The filter
|
||||
filter TEXT NOT NULL,
|
||||
-- The ID
|
||||
id SERIAL,
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
-- The localpart of the Matrix user ID associated to this filter
|
||||
localpart TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY(id, localpart),
|
||||
|
||||
UNIQUE (id)
|
||||
UNIQUE (id, localpart)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS account_filter_localpart ON account_filter(localpart);
|
||||
|
@ -49,10 +47,14 @@ const selectFilterIDByContentSQL = "" +
|
|||
const insertFilterSQL = "" +
|
||||
"INSERT INTO account_filter (filter, localpart) VALUES ($1, $2)"
|
||||
|
||||
const selectLastInsertedFilterIDSQL = "" +
|
||||
"SELECT id FROM account_filter WHERE rowid = last_insert_rowid()"
|
||||
|
||||
type filterStatements struct {
|
||||
selectFilterStmt *sql.Stmt
|
||||
selectFilterIDByContentStmt *sql.Stmt
|
||||
insertFilterStmt *sql.Stmt
|
||||
selectFilterStmt *sql.Stmt
|
||||
selectLastInsertedFilterIDStmt *sql.Stmt
|
||||
selectFilterIDByContentStmt *sql.Stmt
|
||||
insertFilterStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *filterStatements) prepare(db *sql.DB) (err error) {
|
||||
|
@ -63,6 +65,9 @@ func (s *filterStatements) prepare(db *sql.DB) (err error) {
|
|||
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectLastInsertedFilterIDStmt, err = db.Prepare(selectLastInsertedFilterIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -123,7 +128,12 @@ func (s *filterStatements) insertFilter(
|
|||
}
|
||||
|
||||
// Otherwise insert the filter and return the new ID
|
||||
err = s.insertFilterStmt.QueryRowContext(ctx, filterJSON, localpart).
|
||||
Scan(&filterID)
|
||||
if _, err = s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart); err != nil {
|
||||
return "", err
|
||||
}
|
||||
row := s.selectLastInsertedFilterIDStmt.QueryRowContext(ctx)
|
||||
if err := row.Scan(&filterID); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -32,34 +32,23 @@ const devicesSchema = `
|
|||
|
||||
-- Stores data about devices.
|
||||
CREATE TABLE IF NOT EXISTS device_devices (
|
||||
-- The access token granted to this device. This has to be the primary key
|
||||
-- so we can distinguish which device is making a given request.
|
||||
access_token TEXT NOT NULL PRIMARY KEY,
|
||||
-- The auto-allocated unique ID of the session identified by the access token.
|
||||
-- This can be used as a secure substitution of the access token in situations
|
||||
-- where data is associated with access tokens (e.g. transaction storage),
|
||||
-- so we don't have to store users' access tokens everywhere.
|
||||
session_id BIGINT,
|
||||
-- The device identifier. This only needs to uniquely identify a device for a given user, not globally.
|
||||
-- access_tokens will be clobbered based on the device ID for a user.
|
||||
device_id TEXT NOT NULL,
|
||||
-- The Matrix user ID localpart for this device. This is preferable to storing the full user_id
|
||||
-- as it is smaller, makes it clearer that we only manage devices for our own users, and may make
|
||||
-- migration to different domain names easier.
|
||||
localpart TEXT NOT NULL,
|
||||
-- When this devices was first recognised on the network, as a unix timestamp (ms resolution).
|
||||
created_ts BIGINT NOT NULL,
|
||||
-- The display name, human friendlier than device_id and updatable
|
||||
access_token TEXT PRIMARY KEY,
|
||||
session_id INTEGER,
|
||||
device_id TEXT ,
|
||||
localpart TEXT ,
|
||||
created_ts BIGINT,
|
||||
display_name TEXT,
|
||||
-- TODO: device keys, device display names, last used ts and IP address?, token restrictions (if 3rd-party OAuth app)
|
||||
|
||||
UNIQUE (localpart, device_id)
|
||||
);
|
||||
`
|
||||
|
||||
const insertDeviceSQL = "" +
|
||||
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5);" +
|
||||
"SELECT last_insert_rowid() AS session_id"
|
||||
"INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6)"
|
||||
|
||||
const selectDevicesCountSQL = "" +
|
||||
"SELECT COUNT(access_token) FROM device_devices"
|
||||
|
||||
const selectDeviceByTokenSQL = "" +
|
||||
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
|
||||
|
@ -81,6 +70,7 @@ const deleteDevicesByLocalpartSQL = "" +
|
|||
|
||||
type devicesStatements struct {
|
||||
insertDeviceStmt *sql.Stmt
|
||||
selectDevicesCountStmt *sql.Stmt
|
||||
selectDeviceByTokenStmt *sql.Stmt
|
||||
selectDeviceByIDStmt *sql.Stmt
|
||||
selectDevicesByLocalpartStmt *sql.Stmt
|
||||
|
@ -98,6 +88,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
|
|||
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -129,8 +122,13 @@ func (s *devicesStatements) insertDevice(
|
|||
) (*authtypes.Device, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
var sessionID int64
|
||||
stmt := common.TxStmt(txn, s.insertDeviceStmt)
|
||||
if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName).Scan(&sessionID); err != nil {
|
||||
countStmt := common.TxStmt(txn, s.selectDevicesCountStmt)
|
||||
insertStmt := common.TxStmt(txn, s.insertDeviceStmt)
|
||||
if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessionID++
|
||||
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &authtypes.Device{
|
||||
|
|
Loading…
Reference in a new issue