diff --git a/mediaapi/storage/sqlite3/media_repository_table.go b/mediaapi/storage/sqlite3/media_repository_table.go index ff6ddf3da..b1fae2b66 100644 --- a/mediaapi/storage/sqlite3/media_repository_table.go +++ b/mediaapi/storage/sqlite3/media_repository_table.go @@ -67,9 +67,9 @@ type mediaStatements struct { selectMediaStmt *sql.Stmt } -func (s *mediaStatements) prepare(db *sql.DB) (err error) { +func (s *mediaStatements) prepare(db *sql.DB, writer sqlutil.TransactionWriter) (err error) { s.db = db - s.writer = sqlutil.NewTransactionWriter() + s.writer = writer _, err = db.Exec(mediaSchema) if err != nil { diff --git a/mediaapi/storage/sqlite3/sql.go b/mediaapi/storage/sqlite3/sql.go index 9cd78b8ee..cc795143c 100644 --- a/mediaapi/storage/sqlite3/sql.go +++ b/mediaapi/storage/sqlite3/sql.go @@ -17,6 +17,8 @@ package sqlite3 import ( "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" ) type statements struct { @@ -24,11 +26,11 @@ type statements struct { thumbnail thumbnailStatements } -func (s *statements) prepare(db *sql.DB) (err error) { - if err = s.media.prepare(db); err != nil { +func (s *statements) prepare(db *sql.DB, writer sqlutil.TransactionWriter) (err error) { + if err = s.media.prepare(db, writer); err != nil { return } - if err = s.thumbnail.prepare(db); err != nil { + if err = s.thumbnail.prepare(db, writer); err != nil { return } diff --git a/mediaapi/storage/sqlite3/storage.go b/mediaapi/storage/sqlite3/storage.go index a1e7fec7d..95dce3851 100644 --- a/mediaapi/storage/sqlite3/storage.go +++ b/mediaapi/storage/sqlite3/storage.go @@ -31,16 +31,19 @@ import ( type Database struct { statements statements db *sql.DB + writer sqlutil.TransactionWriter } // Open opens a postgres database. func Open(dbProperties *config.DatabaseOptions) (*Database, error) { - var d Database + d := Database{ + writer: sqlutil.NewTransactionWriter(), + } var err error if d.db, err = sqlutil.Open(dbProperties); err != nil { return nil, err } - if err = d.statements.prepare(d.db); err != nil { + if err = d.statements.prepare(d.db, d.writer); err != nil { return nil, err } return &d, nil diff --git a/mediaapi/storage/sqlite3/thumbnail_table.go b/mediaapi/storage/sqlite3/thumbnail_table.go index 432a1590c..d0aa312c0 100644 --- a/mediaapi/storage/sqlite3/thumbnail_table.go +++ b/mediaapi/storage/sqlite3/thumbnail_table.go @@ -21,6 +21,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -57,16 +58,20 @@ SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method ` type thumbnailStatements struct { + db *sql.DB + writer sqlutil.TransactionWriter insertThumbnailStmt *sql.Stmt selectThumbnailStmt *sql.Stmt selectThumbnailsStmt *sql.Stmt } -func (s *thumbnailStatements) prepare(db *sql.DB) (err error) { +func (s *thumbnailStatements) prepare(db *sql.DB, writer sqlutil.TransactionWriter) (err error) { _, err = db.Exec(thumbnailSchema) if err != nil { return } + s.db = db + s.writer = writer return statementList{ {&s.insertThumbnailStmt, insertThumbnailSQL}, @@ -79,18 +84,21 @@ func (s *thumbnailStatements) insertThumbnail( ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata, ) error { thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000) - _, err := s.insertThumbnailStmt.ExecContext( - ctx, - thumbnailMetadata.MediaMetadata.MediaID, - thumbnailMetadata.MediaMetadata.Origin, - thumbnailMetadata.MediaMetadata.ContentType, - thumbnailMetadata.MediaMetadata.FileSizeBytes, - thumbnailMetadata.MediaMetadata.CreationTimestamp, - thumbnailMetadata.ThumbnailSize.Width, - thumbnailMetadata.ThumbnailSize.Height, - thumbnailMetadata.ThumbnailSize.ResizeMethod, - ) - return err + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.insertThumbnailStmt) + _, err := stmt.ExecContext( + ctx, + thumbnailMetadata.MediaMetadata.MediaID, + thumbnailMetadata.MediaMetadata.Origin, + thumbnailMetadata.MediaMetadata.ContentType, + thumbnailMetadata.MediaMetadata.FileSizeBytes, + thumbnailMetadata.MediaMetadata.CreationTimestamp, + thumbnailMetadata.ThumbnailSize.Width, + thumbnailMetadata.ThumbnailSize.Height, + thumbnailMetadata.ThumbnailSize.ResizeMethod, + ) + return err + }) } func (s *thumbnailStatements) selectThumbnail(