Associate transactions with session IDs instead of device IDs (#789)

This commit is contained in:
Alex Chen 2019-08-24 00:55:40 +08:00 committed by GitHub
parent 5eb63f1d1e
commit 43308d2f3f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 55 additions and 39 deletions

View file

@ -21,5 +21,9 @@ type Device struct {
// The access_token granted to this device. // The access_token granted to this device.
// This uniquely identifies the device from all other devices and clients. // This uniquely identifies the device from all other devices and clients.
AccessToken string AccessToken string
// The unique ID of the session identified by the access token.
// Can be used as a secure substitution in places where data needs to be
// associated with access tokens.
SessionID int64
// TODO: display name, last used timestamp, keys, etc // TODO: display name, last used timestamp, keys, etc
} }

View file

@ -27,11 +27,19 @@ import (
) )
const devicesSchema = ` const devicesSchema = `
-- This sequence is used for automatic allocation of session_id.
CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
-- Stores data about devices. -- Stores data about devices.
CREATE TABLE IF NOT EXISTS device_devices ( CREATE TABLE IF NOT EXISTS device_devices (
-- The access token granted to this device. This has to be the primary key -- The access token granted to this device. This has to be the primary key
-- so we can distinguish which device is making a given request. -- so we can distinguish which device is making a given request.
access_token TEXT NOT NULL PRIMARY KEY, access_token TEXT NOT NULL PRIMARY KEY,
-- The auto-allocated unique ID of the session identified by the access token.
-- This can be used as a secure substitution of the access token in situations
-- where data is associated with access tokens (e.g. transaction storage),
-- so we don't have to store users' access tokens everywhere.
session_id BIGINT NOT NULL DEFAULT nextval('device_session_id_seq'),
-- The device identifier. This only needs to uniquely identify a device for a given user, not globally. -- The device identifier. This only needs to uniquely identify a device for a given user, not globally.
-- access_tokens will be clobbered based on the device ID for a user. -- access_tokens will be clobbered based on the device ID for a user.
device_id TEXT NOT NULL, device_id TEXT NOT NULL,
@ -51,10 +59,11 @@ CREATE UNIQUE INDEX IF NOT EXISTS device_localpart_id_idx ON device_devices(loca
` `
const insertDeviceSQL = "" + const insertDeviceSQL = "" +
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5)" "INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5)" +
" RETURNING session_id"
const selectDeviceByTokenSQL = "" + const selectDeviceByTokenSQL = "" +
"SELECT device_id, localpart FROM device_devices WHERE access_token = $1" "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" + const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
@ -120,14 +129,16 @@ func (s *devicesStatements) insertDevice(
displayName *string, displayName *string,
) (*authtypes.Device, error) { ) (*authtypes.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
stmt := common.TxStmt(txn, s.insertDeviceStmt) stmt := common.TxStmt(txn, s.insertDeviceStmt)
if _, err := stmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName); err != nil { if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName).Scan(&sessionID); err != nil {
return nil, err return nil, err
} }
return &authtypes.Device{ return &authtypes.Device{
ID: id, ID: id,
UserID: userutil.MakeUserID(localpart, s.serverName), UserID: userutil.MakeUserID(localpart, s.serverName),
AccessToken: accessToken, AccessToken: accessToken,
SessionID: sessionID,
}, nil }, nil
} }
@ -161,7 +172,7 @@ func (s *devicesStatements) selectDeviceByToken(
var dev authtypes.Device var dev authtypes.Device
var localpart string var localpart string
stmt := s.selectDeviceByTokenStmt stmt := s.selectDeviceByTokenStmt
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.ID, &localpart) err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
if err == nil { if err == nil {
dev.UserID = userutil.MakeUserID(localpart, s.serverName) dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.AccessToken = accessToken dev.AccessToken = accessToken

View file

@ -60,18 +60,18 @@ func SendEvent(
return *resErr return *resErr
} }
var txnAndDeviceID *api.TransactionID var txnAndSessionID *api.TransactionID
if txnID != nil { if txnID != nil {
txnAndDeviceID = &api.TransactionID{ txnAndSessionID = &api.TransactionID{
TransactionID: *txnID, TransactionID: *txnID,
DeviceID: device.ID, SessionID: device.SessionID,
} }
} }
// pass the new event to the roomserver and receive the correct event ID // pass the new event to the roomserver and receive the correct event ID
// event ID in case of duplicate transaction is discarded // event ID in case of duplicate transaction is discarded
eventID, err := producer.SendEvents( eventID, err := producer.SendEvents(
req.Context(), []gomatrixserverlib.Event{*e}, cfg.Matrix.ServerName, txnAndDeviceID, req.Context(), []gomatrixserverlib.Event{*e}, cfg.Matrix.ServerName, txnAndSessionID,
) )
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)

View file

@ -75,9 +75,9 @@ type InputRoomEvent struct {
} }
// TransactionID contains the transaction ID sent by a client when sending an // TransactionID contains the transaction ID sent by a client when sending an
// event, along with the ID of that device. // event, along with the ID of the client session.
type TransactionID struct { type TransactionID struct {
DeviceID string `json:"device_id"` SessionID int64 `json:"session_id"`
TransactionID string `json:"id"` TransactionID string `json:"id"`
} }

View file

@ -32,7 +32,7 @@ type RoomEventDatabase interface {
StoreEvent( StoreEvent(
ctx context.Context, ctx context.Context,
event gomatrixserverlib.Event, event gomatrixserverlib.Event,
txnAndDeviceID *api.TransactionID, txnAndSessionID *api.TransactionID,
authEventNIDs []types.EventNID, authEventNIDs []types.EventNID,
) (types.RoomNID, types.StateAtEvent, error) ) (types.RoomNID, types.StateAtEvent, error)
// Look up the state entries for a list of string event IDs // Look up the state entries for a list of string event IDs
@ -67,7 +67,7 @@ type RoomEventDatabase interface {
// Returns an empty string if no such event exists. // Returns an empty string if no such event exists.
GetTransactionEventID( GetTransactionEventID(
ctx context.Context, transactionID string, ctx context.Context, transactionID string,
deviceID string, userID string, sessionID int64, userID string,
) (string, error) ) (string, error)
} }
@ -100,7 +100,7 @@ func processRoomEvent(
if input.TransactionID != nil { if input.TransactionID != nil {
tdID := input.TransactionID tdID := input.TransactionID
eventID, err = db.GetTransactionEventID( eventID, err = db.GetTransactionEventID(
ctx, tdID.TransactionID, tdID.DeviceID, input.Event.Sender(), ctx, tdID.TransactionID, tdID.SessionID, input.Event.Sender(),
) )
// On error OR event with the transaction already processed/processesing // On error OR event with the transaction already processed/processesing
if err != nil || eventID != "" { if err != nil || eventID != "" {

View file

@ -47,7 +47,7 @@ func Open(dataSourceName string) (*Database, error) {
// StoreEvent implements input.EventDatabase // StoreEvent implements input.EventDatabase
func (d *Database) StoreEvent( func (d *Database) StoreEvent(
ctx context.Context, event gomatrixserverlib.Event, ctx context.Context, event gomatrixserverlib.Event,
txnAndDeviceID *api.TransactionID, authEventNIDs []types.EventNID, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID,
) (types.RoomNID, types.StateAtEvent, error) { ) (types.RoomNID, types.StateAtEvent, error) {
var ( var (
roomNID types.RoomNID roomNID types.RoomNID
@ -58,10 +58,10 @@ func (d *Database) StoreEvent(
err error err error
) )
if txnAndDeviceID != nil { if txnAndSessionID != nil {
if err = d.statements.insertTransaction( if err = d.statements.insertTransaction(
ctx, txnAndDeviceID.TransactionID, ctx, txnAndSessionID.TransactionID,
txnAndDeviceID.DeviceID, event.Sender(), event.EventID(), txnAndSessionID.SessionID, event.Sender(), event.EventID(),
); err != nil { ); err != nil {
return 0, types.StateAtEvent{}, err return 0, types.StateAtEvent{}, err
} }
@ -322,9 +322,9 @@ func (d *Database) GetLatestEventsForUpdate(
// GetTransactionEventID implements input.EventDatabase // GetTransactionEventID implements input.EventDatabase
func (d *Database) GetTransactionEventID( func (d *Database) GetTransactionEventID(
ctx context.Context, transactionID string, ctx context.Context, transactionID string,
deviceID string, userID string, sessionID int64, userID string,
) (string, error) { ) (string, error) {
eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, deviceID, userID) eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, sessionID, userID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", nil return "", nil
} }

View file

@ -23,8 +23,8 @@ const transactionsSchema = `
CREATE TABLE IF NOT EXISTS roomserver_transactions ( CREATE TABLE IF NOT EXISTS roomserver_transactions (
-- The transaction ID of the event. -- The transaction ID of the event.
transaction_id TEXT NOT NULL, transaction_id TEXT NOT NULL,
-- The device ID of the originating transaction. -- The session ID of the originating transaction.
device_id TEXT NOT NULL, session_id BIGINT NOT NULL,
-- User ID of the sender who authored the event -- User ID of the sender who authored the event
user_id TEXT NOT NULL, user_id TEXT NOT NULL,
-- Event ID corresponding to the transaction -- Event ID corresponding to the transaction
@ -32,16 +32,16 @@ CREATE TABLE IF NOT EXISTS roomserver_transactions (
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
-- A transaction ID is unique for a user and device -- A transaction ID is unique for a user and device
-- This automatically creates an index. -- This automatically creates an index.
PRIMARY KEY (transaction_id, device_id, user_id) PRIMARY KEY (transaction_id, session_id, user_id)
); );
` `
const insertTransactionSQL = "" + const insertTransactionSQL = "" +
"INSERT INTO roomserver_transactions (transaction_id, device_id, user_id, event_id)" + "INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id)" +
" VALUES ($1, $2, $3, $4)" " VALUES ($1, $2, $3, $4)"
const selectTransactionEventIDSQL = "" + const selectTransactionEventIDSQL = "" +
"SELECT event_id FROM roomserver_transactions" + "SELECT event_id FROM roomserver_transactions" +
" WHERE transaction_id = $1 AND device_id = $2 AND user_id = $3" " WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3"
type transactionStatements struct { type transactionStatements struct {
insertTransactionStmt *sql.Stmt insertTransactionStmt *sql.Stmt
@ -63,12 +63,12 @@ func (s *transactionStatements) prepare(db *sql.DB) (err error) {
func (s *transactionStatements) insertTransaction( func (s *transactionStatements) insertTransaction(
ctx context.Context, ctx context.Context,
transactionID string, transactionID string,
deviceID string, sessionID int64,
userID string, userID string,
eventID string, eventID string,
) (err error) { ) (err error) {
_, err = s.insertTransactionStmt.ExecContext( _, err = s.insertTransactionStmt.ExecContext(
ctx, transactionID, deviceID, userID, eventID, ctx, transactionID, sessionID, userID, eventID,
) )
return return
} }
@ -76,11 +76,11 @@ func (s *transactionStatements) insertTransaction(
func (s *transactionStatements) selectTransactionEventID( func (s *transactionStatements) selectTransactionEventID(
ctx context.Context, ctx context.Context,
transactionID string, transactionID string,
deviceID string, sessionID int64,
userID string, userID string,
) (eventID string, err error) { ) (eventID string, err error) {
err = s.selectTransactionEventIDStmt.QueryRowContext( err = s.selectTransactionEventIDStmt.QueryRowContext(
ctx, transactionID, deviceID, userID, ctx, transactionID, sessionID, userID,
).Scan(&eventID) ).Scan(&eventID)
return return
} }

View file

@ -54,7 +54,7 @@ CREATE TABLE IF NOT EXISTS syncapi_output_room_events (
-- if there is no delta. -- if there is no delta.
add_state_ids TEXT[], add_state_ids TEXT[],
remove_state_ids TEXT[], remove_state_ids TEXT[],
device_id TEXT, -- The local device that sent the event, if any session_id BIGINT, -- The client session that sent the event, if any
transaction_id TEXT -- The transaction id used to send the event, if any transaction_id TEXT -- The transaction id used to send the event, if any
); );
-- for event selection -- for event selection
@ -63,14 +63,14 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_output_room_ev
const insertEventSQL = "" + const insertEventSQL = "" +
"INSERT INTO syncapi_output_room_events (" + "INSERT INTO syncapi_output_room_events (" +
"room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, device_id, transaction_id" + "room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id" +
") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id" ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id"
const selectEventsSQL = "" + const selectEventsSQL = "" +
"SELECT id, event_json FROM syncapi_output_room_events WHERE event_id = ANY($1)" "SELECT id, event_json FROM syncapi_output_room_events WHERE event_id = ANY($1)"
const selectRecentEventsSQL = "" + const selectRecentEventsSQL = "" +
"SELECT id, event_json, device_id, transaction_id FROM syncapi_output_room_events" + "SELECT id, event_json, session_id, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" + " WHERE room_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id DESC LIMIT $4" " ORDER BY id DESC LIMIT $4"
@ -221,9 +221,10 @@ func (s *outputRoomEventsStatements) insertEvent(
event *gomatrixserverlib.Event, addState, removeState []string, event *gomatrixserverlib.Event, addState, removeState []string,
transactionID *api.TransactionID, transactionID *api.TransactionID,
) (streamPos int64, err error) { ) (streamPos int64, err error) {
var deviceID, txnID *string var txnID *string
var sessionID *int64
if transactionID != nil { if transactionID != nil {
deviceID = &transactionID.DeviceID sessionID = &transactionID.SessionID
txnID = &transactionID.TransactionID txnID = &transactionID.TransactionID
} }
@ -246,7 +247,7 @@ func (s *outputRoomEventsStatements) insertEvent(
containsURL, containsURL,
pq.StringArray(addState), pq.StringArray(addState),
pq.StringArray(removeState), pq.StringArray(removeState),
deviceID, sessionID,
txnID, txnID,
).Scan(&streamPos) ).Scan(&streamPos)
return return
@ -296,11 +297,11 @@ func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) {
var ( var (
streamPos int64 streamPos int64
eventBytes []byte eventBytes []byte
deviceID *string sessionID *int64
txnID *string txnID *string
transactionID *api.TransactionID transactionID *api.TransactionID
) )
if err := rows.Scan(&streamPos, &eventBytes, &deviceID, &txnID); err != nil { if err := rows.Scan(&streamPos, &eventBytes, &sessionID, &txnID); err != nil {
return nil, err return nil, err
} }
// TODO: Handle redacted events // TODO: Handle redacted events
@ -309,9 +310,9 @@ func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) {
return nil, err return nil, err
} }
if deviceID != nil && txnID != nil { if sessionID != nil && txnID != nil {
transactionID = &api.TransactionID{ transactionID = &api.TransactionID{
DeviceID: *deviceID, SessionID: *sessionID,
TransactionID: *txnID, TransactionID: *txnID,
} }
} }

View file

@ -893,7 +893,7 @@ func streamEventsToEvents(device *authtypes.Device, in []streamEvent) []gomatrix
for i := 0; i < len(in); i++ { for i := 0; i < len(in); i++ {
out[i] = in[i].Event out[i] = in[i].Event
if device != nil && in[i].transactionID != nil { if device != nil && in[i].transactionID != nil {
if device.UserID == in[i].Sender() && device.ID == in[i].transactionID.DeviceID { if device.UserID == in[i].Sender() && device.SessionID == in[i].transactionID.SessionID {
err := out[i].SetUnsignedField( err := out[i].SetUnsignedField(
"transaction_id", in[i].transactionID.TransactionID, "transaction_id", in[i].transactionID.TransactionID,
) )