From fb9243b4edd22f156705d9f5c10ccea4035b764e Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Wed, 12 Feb 2020 11:55:31 +0000 Subject: [PATCH] sqlite work --- common/sql.go | 15 +- docker/docker-compose.yml | 10 +- .../storage/sqlite3/event_json_table.go | 11 +- .../storage/sqlite3/event_state_keys_table.go | 62 +-- roomserver/storage/sqlite3/events_table.go | 40 +- .../storage/sqlite3/state_snapshot_table.go | 29 +- roomserver/storage/sqlite3/storage.go | 366 ++++++++++++------ 7 files changed, 299 insertions(+), 234 deletions(-) diff --git a/common/sql.go b/common/sql.go index 7ac9ac140..043de8cd0 100644 --- a/common/sql.go +++ b/common/sql.go @@ -30,11 +30,13 @@ type Transaction interface { // EndTransaction ends a transaction. // If the transaction succeeded then it is committed, otherwise it is rolledback. -func EndTransaction(txn Transaction, succeeded *bool) { +// You MUST check the error returned from this function to be sure that the transaction +// was applied correctly. For example, 'database is locked' errors in sqlite will happen here. +func EndTransaction(txn Transaction, succeeded *bool) error { if *succeeded { - txn.Commit() // nolint: errcheck + return txn.Commit() // nolint: errcheck } else { - txn.Rollback() // nolint: errcheck + return txn.Rollback() // nolint: errcheck } } @@ -47,7 +49,12 @@ func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) { return } succeeded := false - defer EndTransaction(txn, &succeeded) + defer func() { + err2 := EndTransaction(txn, &succeeded) + if err == nil && err2 != nil { // failed to commit/rollback + err = err2 + } + }() err = fn(txn) if err != nil { diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 9cf67457c..d738ed3f0 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -1,13 +1,21 @@ version: "3.4" services: + riot: + image: vectorim/riot-web + networks: + - internal + ports: + - "8500:80" + monolith: container_name: dendrite_monolith hostname: monolith - entrypoint: ["bash", "./docker/services/monolith.sh"] + entrypoint: ["bash", "./docker/services/monolith.sh", "--config", "/etc/dendrite/dendrite.yaml"] build: ./ volumes: - ..:/build - ./build/bin:/build/bin + - ../cfg:/etc/dendrite networks: - internal depends_on: diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index f6738159d..4ccf16d3c 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -18,7 +18,6 @@ package sqlite3 import ( "context" "database/sql" - "fmt" "strings" "github.com/matrix-org/dendrite/common" @@ -79,21 +78,14 @@ type eventJSONPair struct { func (s *eventJSONStatements) bulkSelectEventJSON( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]eventJSONPair, error) { - /////////////// iEventNIDs := make([]interface{}, len(eventNIDs)) for k, v := range eventNIDs { iEventNIDs[k] = v } selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", queryVariadic(len(iEventNIDs)), 1) - selectPrep, err := s.db.Prepare(selectOrig) - if err != nil { - return nil, err - } - /////////////// - rows, err := common.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...) + rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...) if err != nil { - fmt.Println("bulkSelectEventJSON s.bulkSelectEventJSONStmt.QueryContext:", err) return nil, err } defer rows.Close() // nolint: errcheck @@ -108,7 +100,6 @@ func (s *eventJSONStatements) bulkSelectEventJSON( result := &results[i] var eventNID int64 if err := rows.Scan(&eventNID, &result.EventJSON); err != nil { - fmt.Println("bulkSelectEventJSON rows.Scan:", err) return nil, err } result.EventNID = types.EventNID(eventNID) diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index 5ce72809a..899845baf 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -21,7 +21,6 @@ import ( "fmt" "strings" - "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -41,11 +40,6 @@ const insertEventStateKeyNIDSQL = ` ON CONFLICT DO NOTHING; ` -const insertEventStateKeyNIDResultSQL = ` - SELECT event_state_key_nid FROM roomserver_event_state_keys - WHERE rowid = last_insert_rowid(); -` - const selectEventStateKeyNIDSQL = ` SELECT event_state_key_nid FROM roomserver_event_state_keys WHERE event_state_key = $1 @@ -66,12 +60,11 @@ const bulkSelectEventStateKeySQL = ` ` type eventStateKeyStatements struct { - db *sql.DB - insertEventStateKeyNIDStmt *sql.Stmt - insertEventStateKeyNIDResultStmt *sql.Stmt - selectEventStateKeyNIDStmt *sql.Stmt - bulkSelectEventStateKeyNIDStmt *sql.Stmt - bulkSelectEventStateKeyStmt *sql.Stmt + db *sql.DB + insertEventStateKeyNIDStmt *sql.Stmt + selectEventStateKeyNIDStmt *sql.Stmt + bulkSelectEventStateKeyNIDStmt *sql.Stmt + bulkSelectEventStateKeyStmt *sql.Stmt } func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) { @@ -82,7 +75,6 @@ func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) { } return statementList{ {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, - {&s.insertEventStateKeyNIDResultStmt, insertEventStateKeyNIDResultSQL}, {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, {&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL}, {&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL}, @@ -94,13 +86,10 @@ func (s *eventStateKeyStatements) insertEventStateKeyNID( ) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 var err error - insertStmt := common.TxStmt(txn, s.insertEventStateKeyNIDStmt) - selectStmt := common.TxStmt(txn, s.insertEventStateKeyNIDResultStmt) - if _, err = insertStmt.ExecContext(ctx, eventStateKey); err == nil { - err = selectStmt.QueryRowContext(ctx).Scan(&eventStateKeyNID) - if err != nil { - fmt.Println("insertEventStateKeyNID selectStmt.QueryRowContext:", err) - } + var res sql.Result + insertStmt := txn.Stmt(s.insertEventStateKeyNIDStmt) + if res, err = insertStmt.ExecContext(ctx, eventStateKey); err == nil { + eventStateKeyNID, err = res.LastInsertId() } else { fmt.Println("insertEventStateKeyNID insertStmt.ExecContext:", err) } @@ -111,36 +100,22 @@ func (s *eventStateKeyStatements) selectEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 - stmt := common.TxStmt(txn, s.selectEventStateKeyNIDStmt) - fmt.Println("selectEventStateKeyNID for", eventStateKey) + stmt := txn.Stmt(s.selectEventStateKeyNIDStmt) err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID) - if err != nil { - fmt.Println("selectEventStateKeyNID stmt.QueryRowContext:", err) - } - fmt.Println("selectEventStateKeyNID returns", eventStateKeyNID) return types.EventStateKeyNID(eventStateKeyNID), err } func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { - /////////////// iEventStateKeys := make([]interface{}, len(eventStateKeys)) for k, v := range eventStateKeys { iEventStateKeys[k] = v } selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeys)), 1) - selectPrep, err := s.db.Prepare(selectOrig) - if err != nil { - return nil, err - } - /////////////// - rows, err := common.TxStmt(txn, selectPrep).QueryContext( - ctx, iEventStateKeys..., - ) + rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeys...) if err != nil { - fmt.Println("bulkSelectEventStateKeyNID s.bulkSelectEventStateKeyNIDStmt.QueryContext:", err) return nil, err } defer rows.Close() // nolint: errcheck @@ -149,7 +124,6 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( var stateKey string var stateKeyNID int64 if err := rows.Scan(&stateKey, &stateKeyNID); err != nil { - fmt.Println("bulkSelectEventStateKeyNID rows.Scan:", err) return nil, err } result[stateKey] = types.EventStateKeyNID(stateKeyNID) @@ -160,25 +134,14 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( func (s *eventStateKeyStatements) bulkSelectEventStateKey( ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { - /////////////// iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs)) for k, v := range eventStateKeyNIDs { iEventStateKeyNIDs[k] = v } selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeyNIDs)), 1) - selectPrep, err := s.db.Prepare(selectOrig) - if err != nil { - return nil, err - } - /////////////// - /*nIDs := make(pq.Int64Array, len(eventStateKeyNIDs)) - for i := range eventStateKeyNIDs { - nIDs[i] = int64(eventStateKeyNIDs[i]) - }*/ - rows, err := common.TxStmt(txn, selectPrep).QueryContext(ctx, iEventStateKeyNIDs...) + rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...) if err != nil { - fmt.Println("bulkSelectEventStateKey s.bulkSelectEventStateKeyStmt.QueryContext:", err) return nil, err } defer rows.Close() // nolint: errcheck @@ -187,7 +150,6 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKey( var stateKey string var stateKeyNID int64 if err := rows.Scan(&stateKey, &stateKeyNID); err != nil { - fmt.Println("bulkSelectEventStateKey rows.Scan:", err) return nil, err } result[types.EventStateKeyNID(stateKeyNID)] = stateKey diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 6fa5fdb08..2d0fe2073 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -69,7 +69,7 @@ const bulkSelectStateAtEventByIDSQL = "" + " WHERE event_id IN ($1)" const updateEventStateSQL = "" + - "UPDATE roomserver_events SET state_snapshot_nid = $2 WHERE event_nid = $1" + "UPDATE roomserver_events SET state_snapshot_nid = $1 WHERE event_nid = $2" const selectEventSentToOutputSQL = "" + "SELECT sent_to_output FROM roomserver_events WHERE event_nid = $1" @@ -153,21 +153,18 @@ func (s *eventStatements) insertEvent( var eventNID int64 var stateNID int64 var err error + var res sql.Result insertStmt := common.TxStmt(txn, s.insertEventStmt) resultStmt := common.TxStmt(txn, s.insertEventResultStmt) - if _, err = insertStmt.ExecContext( + if res, err = insertStmt.ExecContext( ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, ); err == nil { + a, b := res.LastInsertId() + fmt.Println("LastInsertId", a, b) err = resultStmt.QueryRowContext(ctx).Scan(&eventNID, &stateNID) - if err != nil { - fmt.Println("insertEvent HAS FAILED!", err) - } - } else { - fmt.Println("insertEvent HAS GONE WRONG!", err) } - fmt.Println("Event NID:", eventNID) - fmt.Println("State snapshot NID:", stateNID) + fmt.Println("INSERT event ID", eventID, "state snapshot NID:", stateNID, "event NID:", eventNID) return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } @@ -192,7 +189,7 @@ func (s *eventStatements) bulkSelectStateEventByID( iEventIDs[k] = v } selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1) - selectPrep, err := s.db.Prepare(selectOrig) + selectPrep, err := txn.Prepare(selectOrig) if err != nil { return nil, err } @@ -245,7 +242,7 @@ func (s *eventStatements) bulkSelectStateAtEventByID( iEventIDs[k] = v } selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1) - selectPrep, err := s.db.Prepare(selectOrig) + selectPrep, err := txn.Prepare(selectOrig) if err != nil { return nil, err } @@ -287,10 +284,14 @@ func (s *eventStatements) updateEventState( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { updateStmt := common.TxStmt(txn, s.updateEventStateStmt) - _, err := updateStmt.ExecContext(ctx, int64(eventNID), int64(stateNID)) + fmt.Println("=====================================") + fmt.Println(updateEventStateSQL, stateNID, eventNID) + res, err := updateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) if err != nil { fmt.Println("updateEventState s.updateEventStateStmt.ExecContext:", err) } + a, b := res.RowsAffected() + fmt.Println("Rows affected:", a, b) return err } @@ -336,14 +337,9 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference( iEventNIDs[k] = v } selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1) - selectPrep, err := s.db.Prepare(selectOrig) - if err != nil { - return nil, err - } - /////////////// + ////////////// - selectStmt := common.TxStmt(txn, selectPrep) - rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) + rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...) if err != nil { fmt.Println("bulkSelectStateAtEventAndREference stmt.QueryContext:", err) return nil, err @@ -389,7 +385,7 @@ func (s *eventStatements) bulkSelectEventReference( iEventNIDs[k] = v } selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1) - selectPrep, err := s.db.Prepare(selectOrig) + selectPrep, err := txn.Prepare(selectOrig) if err != nil { return nil, err } @@ -425,7 +421,7 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, ev iEventNIDs[k] = v } selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", queryVariadic(len(iEventNIDs)), 1) - selectPrep, err := s.db.Prepare(selectOrig) + selectPrep, err := txn.Prepare(selectOrig) if err != nil { return nil, err } @@ -464,7 +460,7 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e iEventIDs[k] = v } selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1) - selectPrep, err := s.db.Prepare(selectOrig) + selectPrep, err := txn.Prepare(selectOrig) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 4b9ebf7bf..6d9727d5c 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -22,7 +22,6 @@ import ( "strings" "github.com/lib/pq" - "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -36,13 +35,7 @@ const stateSnapshotSchema = ` const insertStateSQL = ` INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids) - VALUES ($1, $2); -` - -const insertStateResultSQL = ` - SELECT state_snapshot_nid FROM roomserver_state_snapshots - WHERE rowid = last_insert_rowid(); -` + VALUES ($1, $2);` // Bulk state data NID lookup. // Sorting by state_snapshot_nid means we can use binary search over the result @@ -67,7 +60,6 @@ func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) { return statementList{ {&s.insertStateStmt, insertStateSQL}, - {&s.insertStateResultStmt, insertStateResultSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, }.prepare(db) } @@ -79,15 +71,17 @@ func (s *stateSnapshotStatements) insertState( for i := range stateBlockNIDs { nids[i] = int64(stateBlockNIDs[i]) } - insertStmt := common.TxStmt(txn, s.insertStateStmt) - resultStmt := common.TxStmt(txn, s.insertStateResultStmt) - if _, err = insertStmt.ExecContext(ctx, int64(roomNID), pq.Int64Array(nids)); err == nil { - err = resultStmt.QueryRowContext(ctx).Scan(&stateNID) - if err != nil { - fmt.Println("insertState s.insertStateResultStmt.QueryRowContext:", err) + insertStmt := txn.Stmt(s.insertStateStmt) + //resultStmt := txn.Stmt(s.insertStateResultStmt) + fmt.Println(insertStateSQL, roomNID, nids) + if res, err2 := insertStmt.ExecContext(ctx, int64(roomNID), pq.Int64Array(nids)); err2 == nil { + lastRowID, err3 := res.LastInsertId() + if err3 != nil { + err = err3 } + stateNID = types.StateSnapshotNID(lastRowID) } else { - fmt.Println("insertState s.insertStateStmt.ExecContext:", err) + fmt.Println("insertState s.insertStateStmt.ExecContext:", err2) } return } @@ -101,7 +95,7 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs( nids[k] = v } selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", queryVariadic(len(nids)), 1) - selectPrep, err := s.db.Prepare(selectOrig) + selectStmt, err := txn.Prepare(selectOrig) if err != nil { return nil, err } @@ -112,7 +106,6 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs( nids[i] = int64(stateNIDs[i]) } */ - selectStmt := common.TxStmt(txn, selectPrep) rows, err := selectStmt.QueryContext(ctx, nids...) if err != nil { fmt.Println("bulkSelectStateBlockNIDs s.bulkSelectStateBlockNIDsStmt.QueryContext:", err) diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 0b43defde..132fa2513 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "net/url" + "time" "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/api" @@ -54,8 +55,8 @@ func Open(dataSourceName string) (*Database, error) { return nil, err } //d.db.Exec("PRAGMA journal_mode=WAL;") - //d.db.Exec("PRAGMA parser_trace = true;") - d.db.SetMaxOpenConns(1) + //d.db.Exec("PRAGMA read_uncommitted = true;") + d.db.SetMaxOpenConns(2) if err = d.statements.prepare(d.db); err != nil { return nil, err } @@ -196,36 +197,56 @@ func (d *Database) assignStateKeyNID( // StateEntriesForEventIDs implements input.EventDatabase func (d *Database) StateEntriesForEventIDs( ctx context.Context, eventIDs []string, -) ([]types.StateEntry, error) { - return d.statements.bulkSelectStateEventByID(ctx, nil, eventIDs) +) (se []types.StateEntry, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + se, err = d.statements.bulkSelectStateEventByID(ctx, txn, eventIDs) + return err + }) + return } // EventTypeNIDs implements state.RoomStateDatabase func (d *Database) EventTypeNIDs( ctx context.Context, eventTypes []string, -) (map[string]types.EventTypeNID, error) { - return d.statements.bulkSelectEventTypeNID(ctx, nil, eventTypes) +) (etnids map[string]types.EventTypeNID, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + etnids, err = d.statements.bulkSelectEventTypeNID(ctx, txn, eventTypes) + return err + }) + return } // EventStateKeyNIDs implements state.RoomStateDatabase func (d *Database) EventStateKeyNIDs( ctx context.Context, eventStateKeys []string, -) (map[string]types.EventStateKeyNID, error) { - return d.statements.bulkSelectEventStateKeyNID(ctx, nil, eventStateKeys) +) (esknids map[string]types.EventStateKeyNID, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + esknids, err = d.statements.bulkSelectEventStateKeyNID(ctx, txn, eventStateKeys) + return err + }) + return } // EventStateKeys implements query.RoomserverQueryAPIDatabase func (d *Database) EventStateKeys( ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, -) (map[types.EventStateKeyNID]string, error) { - return d.statements.bulkSelectEventStateKey(ctx, nil, eventStateKeyNIDs) +) (out map[types.EventStateKeyNID]string, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + out, err = d.statements.bulkSelectEventStateKey(ctx, txn, eventStateKeyNIDs) + return err + }) + return } // EventNIDs implements query.RoomserverQueryAPIDatabase func (d *Database) EventNIDs( ctx context.Context, eventIDs []string, -) (map[string]types.EventNID, error) { - return d.statements.bulkSelectEventNID(ctx, nil, eventIDs) +) (out map[string]types.EventNID, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + out, err = d.statements.bulkSelectEventNID(ctx, txn, eventIDs) + return err + }) + return } // Events implements input.EventDatabase @@ -235,12 +256,15 @@ func (d *Database) Events( var eventJSONs []eventJSONPair var err error results := make([]types.Event, len(eventNIDs)) + fmt.Println("pre txn") common.WithTransaction(d.db, func(txn *sql.Tx) error { + fmt.Println("in txn", txn) eventJSONs, err = d.statements.bulkSelectEventJSON(ctx, txn, eventNIDs) if err != nil || len(eventJSONs) == 0 { fmt.Println("d.statements.bulkSelectEventJSON:", err) return nil } + fmt.Println("selected txn") for i, eventJSON := range eventJSONs { result := &results[i] result.EventNID = eventJSON.EventNID @@ -252,6 +276,7 @@ func (d *Database) Events( } return nil }) + fmt.Println("post txn") if err != nil { return []types.Event{}, err } @@ -265,7 +290,9 @@ func (d *Database) AddState( stateBlockNIDs []types.StateBlockNID, state []types.StateEntry, ) (stateNID types.StateSnapshotNID, err error) { - common.WithTransaction(d.db, func(txn *sql.Tx) error { + fmt.Println("AddState INSERT STATE START", stateBlockNIDs) + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + fmt.Println("insert state txn created") if len(state) > 0 { stateBlockNID, err := d.statements.selectNextStateBlockNID(ctx, txn) if err != nil { @@ -277,8 +304,10 @@ func (d *Database) AddState( stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) } stateNID, err = d.statements.insertState(ctx, txn, roomNID, stateBlockNIDs) - return nil + fmt.Println("AddState: completing txn", time.Now(), "err=", err) + return err }) + fmt.Println("AddState INSERT STATE END pkey=", stateNID, time.Now(), "err=", err) if err != nil { return 0, err } @@ -289,49 +318,77 @@ func (d *Database) AddState( func (d *Database) SetState( ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { - return d.statements.updateEventState(ctx, nil, eventNID, stateNID) + fmt.Println("SetState event NID:", eventNID, "state NID:", stateNID) + e := common.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.statements.updateEventState(ctx, txn, eventNID, stateNID) + }) + fmt.Println("SetState finish", e) + return e } // StateAtEventIDs implements input.EventDatabase func (d *Database) StateAtEventIDs( ctx context.Context, eventIDs []string, -) ([]types.StateAtEvent, error) { - return d.statements.bulkSelectStateAtEventByID(ctx, nil, eventIDs) +) (se []types.StateAtEvent, err error) { + common.WithTransaction(d.db, func(txn *sql.Tx) error { + se, err = d.statements.bulkSelectStateAtEventByID(ctx, txn, eventIDs) + return err + }) + return } // StateBlockNIDs implements state.RoomStateDatabase func (d *Database) StateBlockNIDs( ctx context.Context, stateNIDs []types.StateSnapshotNID, -) ([]types.StateBlockNIDList, error) { - return d.statements.bulkSelectStateBlockNIDs(ctx, nil, stateNIDs) +) (sl []types.StateBlockNIDList, err error) { + fmt.Println("StateBlockNIDs SELECT STATE START", stateNIDs) + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + fmt.Println(" in txn") + sl, err = d.statements.bulkSelectStateBlockNIDs(ctx, txn, stateNIDs) + return err + }) + fmt.Println("StateBlockNIDs SELECT STATE END", sl) + return } // StateEntries implements state.RoomStateDatabase func (d *Database) StateEntries( ctx context.Context, stateBlockNIDs []types.StateBlockNID, -) ([]types.StateEntryList, error) { - return d.statements.bulkSelectStateBlockEntries(ctx, nil, stateBlockNIDs) +) (sel []types.StateEntryList, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + sel, err = d.statements.bulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs) + return err + }) + return } // SnapshotNIDFromEventID implements state.RoomStateDatabase func (d *Database) SnapshotNIDFromEventID( ctx context.Context, eventID string, ) (stateNID types.StateSnapshotNID, err error) { - _, stateNID, err = d.statements.selectEvent(ctx, nil, eventID) + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + _, stateNID, err = d.statements.selectEvent(ctx, txn, eventID) + return err + }) return } // EventIDs implements input.RoomEventDatabase func (d *Database) EventIDs( ctx context.Context, eventNIDs []types.EventNID, -) (map[types.EventNID]string, error) { - return d.statements.bulkSelectEventID(ctx, nil, eventNIDs) +) (out map[types.EventNID]string, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + out, err = d.statements.bulkSelectEventID(ctx, txn, eventNIDs) + return err + }) + return } // GetLatestEventsForUpdate implements input.EventDatabase func (d *Database) GetLatestEventsForUpdate( ctx context.Context, roomNID types.RoomNID, ) (types.RoomRecentEventsUpdater, error) { + fmt.Println("=============== GetLatestEventsForUpdate BEGIN TXN") txn, err := d.db.Begin() if err != nil { return nil, err @@ -355,8 +412,15 @@ func (d *Database) GetLatestEventsForUpdate( return nil, err } } + fmt.Println("GetLatestEventsForUpdate returning updater") + + // FIXME: we probably want to support long-lived txns in sqlite somehow, but we don't because we get + // 'database is locked' errors caused by multiple write txns (one being the long-lived txn created here) + // so for now let's not use a long-lived txn at all, and just commit it here and set the txn to nil so + // we fail fast if someone tries to use the underlying txn object. + txn.Commit() return &roomRecentEventsUpdater{ - transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, + transaction{ctx, nil}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, }, nil } @@ -398,24 +462,33 @@ func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotN // StorePreviousEvents implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { - for _, ref := range previousEventReferences { - if err := u.d.statements.insertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { - return err + err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + for _, ref := range previousEventReferences { + if err := u.d.statements.insertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + return err + } } - } - return nil + return nil + }) + return err } // IsReferenced implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { - err := u.d.statements.selectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) - if err == nil { - return true, nil - } - if err == sql.ErrNoRows { - return false, nil - } - return false, err +func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (res bool, err error) { + fmt.Println("[[TXN]] IsReferenced") + err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + err := u.d.statements.selectPreviousEventExists(u.ctx, txn, eventReference.EventID, eventReference.EventSHA256) + if err == nil { + res = true + err = nil + } + if err == sql.ErrNoRows { + res = false + err = nil + } + return err + }) + return } // SetLatestEvents implements types.RoomRecentEventsUpdater @@ -423,38 +496,58 @@ func (u *roomRecentEventsUpdater) SetLatestEvents( roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, currentStateSnapshotNID types.StateSnapshotNID, ) error { - eventNIDs := make([]types.EventNID, len(latest)) - for i := range latest { - eventNIDs[i] = latest[i].EventNID - } - // TODO: transaction was removed here - is this wise? - return u.d.statements.updateLatestEventNIDs(u.ctx, nil, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) + err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + eventNIDs := make([]types.EventNID, len(latest)) + for i := range latest { + eventNIDs[i] = latest[i].EventNID + } + return u.d.statements.updateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) + }) + return err } // HasEventBeenSent implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { +func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (res bool, err error) { // TODO: transaction was removed here - is this wise? - return u.d.statements.selectEventSentToOutput(u.ctx, nil, eventNID) + fmt.Println("[[TXN]] HasEventBeenSent") + err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + res, err = u.d.statements.selectEventSentToOutput(u.ctx, txn, eventNID) + return err + }) + return } // MarkEventAsSent implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { // TODO: transaction was removed here - is this wise? - return u.d.statements.updateEventSentToOutput(u.ctx, nil, eventNID) + fmt.Println("[[TXN]] updateEventSentToOutput") + err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + return u.d.statements.updateEventSentToOutput(u.ctx, txn, eventNID) + }) + return err } -func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) { +func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (mu types.MembershipUpdater, err error) { // TODO: transaction was removed here - is this wise? - return u.d.membershipUpdaterTxn(u.ctx, nil, u.roomNID, targetUserNID) + fmt.Println("[[TXN]] membershipUpdaterTxn") + err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID) + return err + }) + return } // RoomNID implements query.RoomserverQueryAPIDB -func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) { - roomNID, err := d.statements.selectRoomNID(ctx, nil, roomID) - if err == sql.ErrNoRows { - return 0, nil - } - return roomNID, err +func (d *Database) RoomNID(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) + if err == sql.ErrNoRows { + roomNID = 0 + err = nil + } + return err + }) + return } // LatestEventIDs implements query.RoomserverQueryAPIDatabase @@ -532,12 +625,14 @@ func (d *Database) MembershipUpdater( ctx context.Context, roomID, targetUserID string, ) (types.MembershipUpdater, error) { txn, err := d.db.Begin() + fmt.Println("=== UPDATER TXN START ====") if err != nil { return nil, err } succeeded := false defer func() { if !succeeded { + fmt.Println("=== UPDATER TXN ROLLBACK ====") txn.Rollback() // nolint: errcheck } }() @@ -606,92 +701,99 @@ func (u *membershipUpdater) IsLeave() bool { } // SetToInvite implements types.MembershipUpdater -func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) - if err != nil { - return false, err - } - inserted, err := u.d.statements.insertInviteEvent( - u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), - ) - if err != nil { - return false, err - } - if u.membership != membershipStateInvite { - if err = u.d.statements.updateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0, - ); err != nil { - return false, err +func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (inserted bool, err error) { + err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, event.Sender()) + if err != nil { + return err } - } - return inserted, nil + inserted, err = u.d.statements.insertInviteEvent( + u.ctx, txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), + ) + if err != nil { + return err + } + if u.membership != membershipStateInvite { + if err = u.d.statements.updateMembership( + u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0, + ); err != nil { + return err + } + } + return nil + }) + return } // SetToJoin implements types.MembershipUpdater -func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) { - var inviteEventIDs []string - - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) - if err != nil { - return nil, err - } - - // If this is a join event update, there is no invite to update - if !isUpdate { - inviteEventIDs, err = u.d.statements.updateInviteRetired( - u.ctx, u.txn, u.roomNID, u.targetUserNID, - ) +func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) (inviteEventIDs []string, err error) { + err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, senderUserID) if err != nil { - return nil, err + return err } - } - // Look up the NID of the new join event - nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) - if err != nil { - return nil, err - } - - if u.membership != membershipStateJoin || isUpdate { - if err = u.d.statements.updateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, - membershipStateJoin, nIDs[eventID], - ); err != nil { - return nil, err + // If this is a join event update, there is no invite to update + if !isUpdate { + inviteEventIDs, err = u.d.statements.updateInviteRetired( + u.ctx, txn, u.roomNID, u.targetUserNID, + ) + if err != nil { + return err + } } - } - return inviteEventIDs, nil + // Look up the NID of the new join event + nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + if err != nil { + return err + } + + if u.membership != membershipStateJoin || isUpdate { + if err = u.d.statements.updateMembership( + u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, + membershipStateJoin, nIDs[eventID], + ); err != nil { + return err + } + } + return nil + }) + + return } // SetToLeave implements types.MembershipUpdater -func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) - if err != nil { - return nil, err - } - inviteEventIDs, err := u.d.statements.updateInviteRetired( - u.ctx, u.txn, u.roomNID, u.targetUserNID, - ) - if err != nil { - return nil, err - } - - // Look up the NID of the new leave event - nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) - if err != nil { - return nil, err - } - - if u.membership != membershipStateLeaveOrBan { - if err = u.d.statements.updateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, - membershipStateLeaveOrBan, nIDs[eventID], - ); err != nil { - return nil, err +func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) (inviteEventIDs []string, err error) { + err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, senderUserID) + if err != nil { + return err } - } - return inviteEventIDs, nil + inviteEventIDs, err = u.d.statements.updateInviteRetired( + u.ctx, txn, u.roomNID, u.targetUserNID, + ) + if err != nil { + return err + } + + // Look up the NID of the new leave event + nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + if err != nil { + return err + } + + if u.membership != membershipStateLeaveOrBan { + if err = u.d.statements.updateMembership( + u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, + membershipStateLeaveOrBan, nIDs[eventID], + ); err != nil { + return err + } + } + return nil + }) + return } // GetMembership implements query.RoomserverQueryAPIDB @@ -762,10 +864,16 @@ type transaction struct { // Commit implements types.Transaction func (t *transaction) Commit() error { + if t.txn == nil { + return nil + } return t.txn.Commit() } // Rollback implements types.Transaction func (t *transaction) Rollback() error { + if t.txn == nil { + return nil + } return t.txn.Rollback() }