From 22a034dcba6e1437435011c5e0648972ad025140 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 4 Mar 2022 15:05:42 +0000 Subject: [PATCH] Fix memory leaks with SQLite prepared statements (#2253) --- .../storage/sqlite3/event_state_keys_table.go | 1 + .../storage/sqlite3/event_types_table.go | 1 + roomserver/storage/sqlite3/events_table.go | 37 ++++++++++++------- roomserver/storage/sqlite3/rooms_table.go | 5 ++- .../storage/sqlite3/state_block_table.go | 5 ++- .../storage/sqlite3/state_snapshot_table.go | 5 ++- 6 files changed, 34 insertions(+), 20 deletions(-) diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index bf12d5b83..8af40024a 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -154,6 +154,7 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKey( if err != nil { return nil, err } + defer selectPrep.Close() stmt := sqlutil.TxStmt(txn, selectPrep) rows, err := stmt.QueryContext(ctx, iEventStateKeyNIDs...) if err != nil { diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index f2c9c42fe..f794a3d0e 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -140,6 +140,7 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID( if err != nil { return nil, err } + defer selectPrep.Close() stmt := sqlutil.TxStmt(txn, selectPrep) /////////////// diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 969a10ce5..2ab1151d5 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -198,11 +198,12 @@ func (s *eventStatements) BulkSelectStateEventByID( iEventIDs[k] = v } selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) - selectStmt, err := s.db.Prepare(selectOrig) + selectPrep, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } - selectStmt = sqlutil.TxStmt(txn, selectStmt) + defer selectPrep.Close() // nolint:errcheck + selectStmt := sqlutil.TxStmt(txn, selectPrep) /////////////// rows, err := selectStmt.QueryContext(ctx, iEventIDs...) @@ -266,11 +267,12 @@ func (s *eventStatements) BulkSelectStateEventByNID( } } selectOrig += " ORDER BY event_type_nid, event_state_key_nid ASC" - selectStmt, err := s.db.Prepare(selectOrig) + selectPrep, err := s.db.Prepare(selectOrig) if err != nil { return nil, fmt.Errorf("s.db.Prepare: %w", err) } - selectStmt = sqlutil.TxStmt(txn, selectStmt) + defer selectPrep.Close() // nolint:errcheck + selectStmt := sqlutil.TxStmt(txn, selectPrep) rows, err := selectStmt.QueryContext(ctx, params...) if err != nil { return nil, fmt.Errorf("selectStmt.QueryContext: %w", err) @@ -307,11 +309,12 @@ func (s *eventStatements) BulkSelectStateAtEventByID( iEventIDs[k] = v } selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) - selectStmt, err := s.db.Prepare(selectOrig) + selectPrep, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } - selectStmt = sqlutil.TxStmt(txn, selectStmt) + defer selectPrep.Close() // nolint:errcheck + selectStmt := sqlutil.TxStmt(txn, selectPrep) /////////////// rows, err := selectStmt.QueryContext(ctx, iEventIDs...) if err != nil { @@ -390,10 +393,11 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( if err != nil { return nil, err } - selectPrep = sqlutil.TxStmt(txn, selectPrep) + defer selectPrep.Close() // nolint:errcheck + selectStmt := sqlutil.TxStmt(txn, selectPrep) ////////////// - rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...) + rows, err := sqlutil.TxStmt(txn, selectStmt).QueryContext(ctx, iEventNIDs...) if err != nil { return nil, fmt.Errorf("sqlutil.TxStmt.QueryContext: %w", err) } @@ -441,6 +445,7 @@ func (s *eventStatements) BulkSelectEventReference( if err != nil { return nil, err } + defer selectPrep.Close() // nolint:errcheck /////////////// selectStmt := sqlutil.TxStmt(txn, selectPrep) @@ -471,11 +476,12 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev iEventNIDs[k] = v } selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) - selectStmt, err := s.db.Prepare(selectOrig) + selectPrep, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } - selectStmt = sqlutil.TxStmt(txn, selectStmt) + defer selectPrep.Close() // nolint:errcheck + selectStmt := sqlutil.TxStmt(txn, selectPrep) /////////////// rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) @@ -526,11 +532,12 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e } else { selectOrig = strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) } - selectStmt, err := s.db.Prepare(selectOrig) + selectPrep, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } - selectStmt = sqlutil.TxStmt(txn, selectStmt) + defer selectPrep.Close() // nolint:errcheck + selectStmt := sqlutil.TxStmt(txn, selectPrep) /////////////// rows, err := selectStmt.QueryContext(ctx, iEventIDs...) if err != nil { @@ -560,6 +567,7 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, if err != nil { return 0, err } + defer sqlPrep.Close() err = sqlutil.TxStmt(txn, sqlPrep).QueryRowContext(ctx, iEventIDs...).Scan(&result) if err != nil { return 0, fmt.Errorf("sqlutil.TxStmt.QueryRowContext: %w", err) @@ -575,12 +583,13 @@ func (s *eventStatements) SelectRoomNIDsForEventNIDs( if err != nil { return nil, err } - sqlPrep = sqlutil.TxStmt(txn, sqlPrep) + defer sqlPrep.Close() + sqlStmt := sqlutil.TxStmt(txn, sqlPrep) iEventNIDs := make([]interface{}, len(eventNIDs)) for i, v := range eventNIDs { iEventNIDs[i] = v } - rows, err := sqlPrep.QueryContext(ctx, iEventNIDs...) + rows, err := sqlStmt.QueryContext(ctx, iEventNIDs...) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 5413475e2..a81b78148 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -233,12 +233,13 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs( if err != nil { return nil, err } - sqlPrep = sqlutil.TxStmt(txn, sqlPrep) + defer sqlPrep.Close() // nolint:errcheck + sqlStmt := sqlutil.TxStmt(txn, sqlPrep) iRoomNIDs := make([]interface{}, len(roomNIDs)) for i, v := range roomNIDs { iRoomNIDs[i] = v } - rows, err := sqlPrep.QueryContext(ctx, iRoomNIDs...) + rows, err := sqlStmt.QueryContext(ctx, iRoomNIDs...) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index d51fc492d..3c829cdcd 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -108,11 +108,12 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries( intfs[i] = int64(stateBlockNIDs[i]) } selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(intfs)), 1) - selectStmt, err := s.db.Prepare(selectOrig) + selectPrep, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } - selectStmt = sqlutil.TxStmt(txn, selectStmt) + defer selectPrep.Close() // nolint:errcheck + selectStmt := sqlutil.TxStmt(txn, selectPrep) rows, err := selectStmt.QueryContext(ctx, intfs...) if err != nil { return nil, err diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 01df31e90..1f5e9ee3b 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -113,11 +113,12 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( nids[k] = v } selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) - selectStmt, err := s.db.Prepare(selectOrig) + selectPrep, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } - selectStmt = sqlutil.TxStmt(txn, selectStmt) + defer selectPrep.Close() // nolint:errcheck + selectStmt := sqlutil.TxStmt(txn, selectPrep) rows, err := selectStmt.QueryContext(ctx, nids...) if err != nil {