diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index a5fd5449a..71238b0e4 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -103,6 +103,8 @@ func (s *roomAliasesStatements) selectAliasesFromRoomID( return } + defer rows.Close() // nolint: errcheck + for rows.Next() { var alias string if err = rows.Scan(&alias); err != nil { diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index ac593546a..d75abceec 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -30,7 +30,7 @@ import ( const stateDataSchema = ` CREATE TABLE IF NOT EXISTS roomserver_state_block ( - state_block_nid INTEGER PRIMARY KEY AUTOINCREMENT, + state_block_nid INTEGER NOT NULL, event_type_nid INTEGER NOT NULL, event_state_key_nid INTEGER NOT NULL, event_nid INTEGER NOT NULL, @@ -43,10 +43,7 @@ const insertStateDataSQL = "" + " VALUES ($1, $2, $3, $4)" const selectNextStateBlockNIDSQL = ` - SELECT COALESCE(( - SELECT seq+1 AS state_block_nid FROM sqlite_sequence - WHERE name = 'roomserver_state_block'), 1 - ) AS state_block_nid +SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block ` // Bulk state lookup by numeric state block ID. @@ -98,11 +95,19 @@ func (s *stateBlockStatements) prepare(db *sql.DB) (err error) { func (s *stateBlockStatements) bulkInsertStateData( ctx context.Context, txn *sql.Tx, - stateBlockNID types.StateBlockNID, entries []types.StateEntry, -) error { +) (types.StateBlockNID, error) { + if len(entries) == 0 { + return 0, nil + } + var stateBlockNID types.StateBlockNID + err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID) + if err != nil { + return 0, err + } + for _, entry := range entries { - _, err := common.TxStmt(txn, s.insertStateDataStmt).ExecContext( + _, err := txn.Stmt(s.insertStateDataStmt).ExecContext( ctx, int64(stateBlockNID), int64(entry.EventTypeNID), @@ -110,20 +115,10 @@ func (s *stateBlockStatements) bulkInsertStateData( int64(entry.EventNID), ) if err != nil { - return err + return 0, err } } - return nil -} - -func (s *stateBlockStatements) selectNextStateBlockNID( - ctx context.Context, - txn *sql.Tx, -) (types.StateBlockNID, error) { - var stateBlockNID int64 - selectStmt := common.TxStmt(txn, s.selectNextStateBlockNIDStmt) - err := selectStmt.QueryRowContext(ctx).Scan(&stateBlockNID) - return types.StateBlockNID(stateBlockNID), err + return stateBlockNID, nil } func (s *stateBlockStatements) bulkSelectStateBlockEntries( diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index e20e8aed7..a033e46e0 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -285,14 +285,10 @@ func (d *Database) AddState( ) (stateNID types.StateSnapshotNID, err error) { err = common.WithTransaction(d.db, func(txn *sql.Tx) error { if len(state) > 0 { - var stateBlockNID types.StateBlockNID - stateBlockNID, err = d.statements.selectNextStateBlockNID(ctx, txn) + stateBlockNID, err := d.statements.bulkInsertStateData(ctx, txn, state) if err != nil { return err } - if err = d.statements.bulkInsertStateData(ctx, txn, stateBlockNID, state); err != nil { - return err - } stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) } stateNID, err = d.statements.insertState(ctx, txn, roomNID, stateBlockNIDs)