From 0cbb13b93481e68d00feec1a94ce74987662c027 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 26 May 2020 17:11:34 +0100 Subject: [PATCH] Convert transactions table --- roomserver/storage/postgres/sql.go | 1 - roomserver/storage/postgres/storage.go | 22 +++++++------------ .../storage/postgres/transactions_table.go | 17 ++++++++------ roomserver/storage/shared/storage.go | 13 +++++++++++ roomserver/storage/sqlite3/sql.go | 1 - roomserver/storage/sqlite3/storage.go | 20 ++++++----------- .../storage/sqlite3/transactions_table.go | 19 ++++++++-------- roomserver/storage/tables/interface.go | 5 +++++ 8 files changed, 53 insertions(+), 45 deletions(-) diff --git a/roomserver/storage/postgres/sql.go b/roomserver/storage/postgres/sql.go index 3daaa9225..914f269c5 100644 --- a/roomserver/storage/postgres/sql.go +++ b/roomserver/storage/postgres/sql.go @@ -44,7 +44,6 @@ func (s *statements) prepare(db *sql.DB) error { s.roomAliasesStatements.prepare, s.inviteStatements.prepare, s.membershipStatements.prepare, - s.transactionStatements.prepare, } { if err = prepare(db); err != nil { return err diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 5c3834409..bfb84fb71 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -41,6 +41,7 @@ type Database struct { eventStateKeys tables.EventStateKeys eventJSON tables.EventJSON rooms tables.Rooms + transactions tables.Transactions db *sql.DB } @@ -74,6 +75,10 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, if err != nil { return nil, err } + d.transactions, err = NewPostgresTransactionsTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ DB: d.db, EventTypesTable: d.eventTypes, @@ -81,6 +86,7 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, EventJSONTable: d.eventJSON, EventsTable: d.events, RoomsTable: d.rooms, + TransactionsTable: d.transactions, } return &d, nil } @@ -100,8 +106,8 @@ func (d *Database) StoreEvent( ) if txnAndSessionID != nil { - if err = d.statements.insertTransaction( - ctx, txnAndSessionID.TransactionID, + if err = d.transactions.InsertTransaction( + ctx, nil, txnAndSessionID.TransactionID, txnAndSessionID.SessionID, event.Sender(), event.EventID(), ); err != nil { return 0, types.StateAtEvent{}, err @@ -349,18 +355,6 @@ func (d *Database) GetLatestEventsForUpdate( }, nil } -// GetTransactionEventID implements input.EventDatabase -func (d *Database) GetTransactionEventID( - ctx context.Context, transactionID string, - sessionID int64, userID string, -) (string, error) { - eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, sessionID, userID) - if err == sql.ErrNoRows { - return "", nil - } - return eventID, err -} - type roomRecentEventsUpdater struct { transaction d *Database diff --git a/roomserver/storage/postgres/transactions_table.go b/roomserver/storage/postgres/transactions_table.go index 87c1cacae..7f7ef76ac 100644 --- a/roomserver/storage/postgres/transactions_table.go +++ b/roomserver/storage/postgres/transactions_table.go @@ -18,6 +18,8 @@ package postgres import ( "context" "database/sql" + + "github.com/matrix-org/dendrite/roomserver/storage/tables" ) const transactionsSchema = ` @@ -51,20 +53,21 @@ type transactionStatements struct { selectTransactionEventIDStmt *sql.Stmt } -func (s *transactionStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(transactionsSchema) +func NewPostgresTransactionsTable(db *sql.DB) (tables.Transactions, error) { + s := &transactionStatements{} + _, err := db.Exec(transactionsSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertTransactionStmt, insertTransactionSQL}, {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL}, }.prepare(db) } -func (s *transactionStatements) insertTransaction( - ctx context.Context, +func (s *transactionStatements) InsertTransaction( + ctx context.Context, txn *sql.Tx, transactionID string, sessionID int64, userID string, @@ -76,7 +79,7 @@ func (s *transactionStatements) insertTransaction( return } -func (s *transactionStatements) selectTransactionEventID( +func (s *transactionStatements) SelectTransactionEventID( ctx context.Context, transactionID string, sessionID int64, diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 06627d47c..51cab9fe4 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -17,6 +17,7 @@ type Database struct { EventTypesTable tables.EventTypes EventStateKeysTable tables.EventStateKeys RoomsTable tables.Rooms + TransactionsTable tables.Transactions } // EventTypeNIDs implements state.RoomStateDatabase @@ -147,3 +148,15 @@ func (d *Database) GetRoomVersionForRoomNID( ctx, nil, roomNID, ) } + +// GetTransactionEventID implements input.EventDatabase +func (d *Database) GetTransactionEventID( + ctx context.Context, transactionID string, + sessionID int64, userID string, +) (string, error) { + eventID, err := d.TransactionsTable.SelectTransactionEventID(ctx, transactionID, sessionID, userID) + if err == sql.ErrNoRows { + return "", nil + } + return eventID, err +} diff --git a/roomserver/storage/sqlite3/sql.go b/roomserver/storage/sqlite3/sql.go index 20fcec469..fe899174a 100644 --- a/roomserver/storage/sqlite3/sql.go +++ b/roomserver/storage/sqlite3/sql.go @@ -44,7 +44,6 @@ func (s *statements) prepare(db *sql.DB) error { s.roomAliasesStatements.prepare, s.inviteStatements.prepare, s.membershipStatements.prepare, - s.transactionStatements.prepare, } { if err = prepare(db); err != nil { return err diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 3759e03f7..15a32274f 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -42,6 +42,7 @@ type Database struct { eventTypes tables.EventTypes eventStateKeys tables.EventStateKeys rooms tables.Rooms + transactions tables.Transactions db *sql.DB } @@ -94,12 +95,17 @@ func Open(dataSourceName string) (*Database, error) { if err != nil { return nil, err } + d.transactions, err = NewSqliteTransactionsTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ EventsTable: d.events, EventTypesTable: d.eventTypes, EventStateKeysTable: d.eventStateKeys, EventJSONTable: d.eventJSON, RoomsTable: d.rooms, + TransactionsTable: d.transactions, } return &d, nil } @@ -120,7 +126,7 @@ func (d *Database) StoreEvent( err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { if txnAndSessionID != nil { - if err = d.statements.insertTransaction( + if err = d.transactions.InsertTransaction( ctx, txn, txnAndSessionID.TransactionID, txnAndSessionID.SessionID, event.Sender(), event.EventID(), ); err != nil { @@ -402,18 +408,6 @@ func (d *Database) GetLatestEventsForUpdate( }, nil } -// GetTransactionEventID implements input.EventDatabase -func (d *Database) GetTransactionEventID( - ctx context.Context, transactionID string, - sessionID int64, userID string, -) (string, error) { - eventID, err := d.statements.selectTransactionEventID(ctx, nil, transactionID, sessionID, userID) - if err == sql.ErrNoRows { - return "", nil - } - return eventID, err -} - type roomRecentEventsUpdater struct { transaction d *Database diff --git a/roomserver/storage/sqlite3/transactions_table.go b/roomserver/storage/sqlite3/transactions_table.go index d22c73845..37ea15c0f 100644 --- a/roomserver/storage/sqlite3/transactions_table.go +++ b/roomserver/storage/sqlite3/transactions_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" ) const transactionsSchema = ` @@ -46,19 +47,20 @@ type transactionStatements struct { selectTransactionEventIDStmt *sql.Stmt } -func (s *transactionStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(transactionsSchema) +func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) { + s := &transactionStatements{} + _, err := db.Exec(transactionsSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertTransactionStmt, insertTransactionSQL}, {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL}, }.prepare(db) } -func (s *transactionStatements) insertTransaction( +func (s *transactionStatements) InsertTransaction( ctx context.Context, txn *sql.Tx, transactionID string, sessionID int64, @@ -72,14 +74,13 @@ func (s *transactionStatements) insertTransaction( return } -func (s *transactionStatements) selectTransactionEventID( - ctx context.Context, txn *sql.Tx, +func (s *transactionStatements) SelectTransactionEventID( + ctx context.Context, transactionID string, sessionID int64, userID string, ) (eventID string, err error) { - stmt := internal.TxStmt(txn, s.selectTransactionEventIDStmt) - err = stmt.QueryRowContext( + err = s.selectTransactionEventIDStmt.QueryRowContext( ctx, transactionID, sessionID, userID, ).Scan(&eventID) return diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 90fac605e..026e95b44 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -65,3 +65,8 @@ type Rooms interface { SelectRoomVersionForRoomID(ctx context.Context, txn *sql.Tx, roomID string) (gomatrixserverlib.RoomVersion, error) SelectRoomVersionForRoomNID(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) } + +type Transactions interface { + InsertTransaction(ctx context.Context, txn *sql.Tx, transactionID string, sessionID int64, userID string, eventID string) error + SelectTransactionEventID(ctx context.Context, transactionID string, sessionID int64, userID string) (eventID string, err error) +}