Database-wide TransactionWriter

This commit is contained in:
Neil Alexander 2020-08-17 18:03:53 +01:00
parent e571e196ce
commit f91f309f4b
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
55 changed files with 193 additions and 154 deletions

View file

@ -75,9 +75,9 @@ type eventsStatements struct {
deleteEventsBeforeAndIncludingIDStmt *sql.Stmt
}
func (s *eventsStatements) prepare(db *sql.DB) (err error) {
func (s *eventsStatements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
s.writer = writer
_, err = db.Exec(appserviceEventsSchema)
if err != nil {
return

View file

@ -41,7 +41,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
if result.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
if err = result.prepare(); err != nil {
writer := sqlutil.NewTransactionWriter()
if err = result.prepare(writer); err != nil {
return nil, err
}
if err = result.PartitionOffsetStatements.Prepare(result.db, "appservice"); err != nil {
@ -50,12 +51,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
return &result, nil
}
func (d *Database) prepare() error {
if err := d.events.prepare(d.db); err != nil {
func (d *Database) prepare(writer *sqlutil.TransactionWriter) error {
if err := d.events.prepare(d.db, writer); err != nil {
return err
}
return d.txnID.prepare(d.db)
return d.txnID.prepare(d.db, writer)
}
// StoreEvent takes in a gomatrixserverlib.HeaderedEvent and stores it in the database

View file

@ -42,9 +42,9 @@ type txnStatements struct {
selectTxnIDStmt *sql.Stmt
}
func (s *txnStatements) prepare(db *sql.DB) (err error) {
func (s *txnStatements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
s.writer = writer
_, err = db.Exec(txnIDSchema)
if err != nil {
return

View file

@ -93,10 +93,10 @@ type currentRoomStateStatements struct {
selectKnownUsersStmt *sql.Stmt
}
func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
func NewSqliteCurrentRoomStateTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.CurrentRoomState, error) {
s := &currentRoomStateStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err := db.Exec(currentRoomStateSchema)
if err != nil {

View file

@ -22,10 +22,11 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
writer := sqlutil.NewTransactionWriter()
if err = d.PartitionOffsetStatements.Prepare(d.db, "currentstate"); err != nil {
return nil, err
}
currRoomState, err := NewSqliteCurrentRoomStateTable(d.db)
currRoomState, err := NewSqliteCurrentRoomStateTable(d.db, writer)
if err != nil {
return nil, err
}

View file

@ -48,10 +48,10 @@ type blacklistStatements struct {
deleteBlacklistStmt *sql.Stmt
}
func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
func NewSQLiteBlacklistTable(db *sql.DB, writer *sqlutil.TransactionWriter) (s *blacklistStatements, err error) {
s = &blacklistStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err = db.Exec(blacklistSchema)
if err != nil {

View file

@ -73,10 +73,10 @@ type joinedHostsStatements struct {
// selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic
}
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
func NewSQLiteJoinedHostsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (s *joinedHostsStatements, err error) {
s = &joinedHostsStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err = db.Exec(joinedHostsSchema)
if err != nil {

View file

@ -72,10 +72,10 @@ type queueEDUsStatements struct {
selectQueueEDUServerNamesStmt *sql.Stmt
}
func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
func NewSQLiteQueueEDUsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (s *queueEDUsStatements, err error) {
s = &queueEDUsStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err = db.Exec(queueEDUsSchema)
if err != nil {

View file

@ -56,10 +56,10 @@ type queueJSONStatements struct {
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
}
func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) {
func NewSQLiteQueueJSONTable(db *sql.DB, writer *sqlutil.TransactionWriter) (s *queueJSONStatements, err error) {
s = &queueJSONStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err = db.Exec(queueJSONSchema)
if err != nil {

View file

@ -81,10 +81,10 @@ type queuePDUsStatements struct {
// deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic
}
func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
func NewSQLiteQueuePDUsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (s *queuePDUsStatements, err error) {
s = &queuePDUsStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err = db.Exec(queuePDUsSchema)
if err != nil {

View file

@ -50,10 +50,10 @@ type roomStatements struct {
updateRoomStmt *sql.Stmt
}
func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) {
func NewSQLiteRoomsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (s *roomStatements, err error) {
s = &roomStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err = db.Exec(roomSchema)
if err != nil {

View file

@ -39,27 +39,28 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
joinedHosts, err := NewSQLiteJoinedHostsTable(d.db)
writer := sqlutil.NewTransactionWriter()
joinedHosts, err := NewSQLiteJoinedHostsTable(d.db, writer)
if err != nil {
return nil, err
}
rooms, err := NewSQLiteRoomsTable(d.db)
rooms, err := NewSQLiteRoomsTable(d.db, writer)
if err != nil {
return nil, err
}
queuePDUs, err := NewSQLiteQueuePDUsTable(d.db)
queuePDUs, err := NewSQLiteQueuePDUsTable(d.db, writer)
if err != nil {
return nil, err
}
queueEDUs, err := NewSQLiteQueueEDUsTable(d.db)
queueEDUs, err := NewSQLiteQueueEDUsTable(d.db, writer)
if err != nil {
return nil, err
}
queueJSON, err := NewSQLiteQueueJSONTable(d.db)
queueJSON, err := NewSQLiteQueueJSONTable(d.db, writer)
if err != nil {
return nil, err
}
blacklist, err := NewSQLiteBlacklistTable(d.db)
blacklist, err := NewSQLiteBlacklistTable(d.db, writer)
if err != nil {
return nil, err
}

View file

@ -71,10 +71,10 @@ type deviceKeysStatements struct {
deleteAllDeviceKeysStmt *sql.Stmt
}
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
func NewSqliteDeviceKeysTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.DeviceKeys, error) {
s := &deviceKeysStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err := db.Exec(deviceKeysSchema)
if err != nil {

View file

@ -57,10 +57,10 @@ type keyChangesStatements struct {
selectKeyChangesStmt *sql.Stmt
}
func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
func NewSqliteKeyChangesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.KeyChanges, error) {
s := &keyChangesStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err := db.Exec(keyChangesSchema)
if err != nil {

View file

@ -68,10 +68,10 @@ type oneTimeKeysStatements struct {
deleteOneTimeKeyStmt *sql.Stmt
}
func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
func NewSqliteOneTimeKeysTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.OneTimeKeys, error) {
s := &oneTimeKeysStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err := db.Exec(oneTimeKeysSchema)
if err != nil {

View file

@ -20,6 +20,7 @@ import (
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
@ -49,13 +50,18 @@ const selectStaleDeviceListsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
type staleDeviceListsStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
upsertStaleDeviceListStmt *sql.Stmt
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
selectStaleDeviceListsStmt *sql.Stmt
}
func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
s := &staleDeviceListsStatements{}
func NewSqliteStaleDeviceListsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.StaleDeviceLists, error) {
s := &staleDeviceListsStatements{
db: db,
writer: writer,
}
_, err := db.Exec(staleDeviceListsSchema)
if err != nil {
return nil, err
@ -77,8 +83,10 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context,
if err != nil {
return err
}
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix())
return err
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix())
return err
})
}
func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {

View file

@ -25,19 +25,20 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
if err != nil {
return nil, err
}
otk, err := NewSqliteOneTimeKeysTable(db)
writer := sqlutil.NewTransactionWriter()
otk, err := NewSqliteOneTimeKeysTable(db, writer)
if err != nil {
return nil, err
}
dk, err := NewSqliteDeviceKeysTable(db)
dk, err := NewSqliteDeviceKeysTable(db, writer)
if err != nil {
return nil, err
}
kc, err := NewSqliteKeyChangesTable(db)
kc, err := NewSqliteKeyChangesTable(db, writer)
if err != nil {
return nil, err
}
sdl, err := NewSqliteStaleDeviceListsTable(db)
sdl, err := NewSqliteStaleDeviceListsTable(db, writer)
if err != nil {
return nil, err
}

View file

@ -67,9 +67,9 @@ type mediaStatements struct {
selectMediaStmt *sql.Stmt
}
func (s *mediaStatements) prepare(db *sql.DB) (err error) {
func (s *mediaStatements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
s.writer = writer
_, err = db.Exec(mediaSchema)
if err != nil {

View file

@ -17,6 +17,8 @@ package sqlite3
import (
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
type statements struct {
@ -24,11 +26,11 @@ type statements struct {
thumbnail thumbnailStatements
}
func (s *statements) prepare(db *sql.DB) (err error) {
if err = s.media.prepare(db); err != nil {
func (s *statements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter) (err error) {
if err = s.media.prepare(db, writer); err != nil {
return
}
if err = s.thumbnail.prepare(db); err != nil {
if err = s.thumbnail.prepare(db, writer); err != nil {
return
}

View file

@ -31,6 +31,7 @@ import (
type Database struct {
statements statements
db *sql.DB
writer *sqlutil.TransactionWriter
}
// Open opens a postgres database.
@ -40,7 +41,8 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
if err = d.statements.prepare(d.db); err != nil {
d.writer = sqlutil.NewTransactionWriter()
if err = d.statements.prepare(d.db, d.writer); err != nil {
return nil, err
}
return &d, nil

View file

@ -21,6 +21,7 @@ import (
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
@ -57,12 +58,16 @@ SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method
`
type thumbnailStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertThumbnailStmt *sql.Stmt
selectThumbnailStmt *sql.Stmt
selectThumbnailsStmt *sql.Stmt
}
func (s *thumbnailStatements) prepare(db *sql.DB) (err error) {
func (s *thumbnailStatements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter) (err error) {
s.db = db
s.writer = writer
_, err = db.Exec(thumbnailSchema)
if err != nil {
return
@ -79,18 +84,21 @@ func (s *thumbnailStatements) insertThumbnail(
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
) error {
thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
_, err := s.insertThumbnailStmt.ExecContext(
ctx,
thumbnailMetadata.MediaMetadata.MediaID,
thumbnailMetadata.MediaMetadata.Origin,
thumbnailMetadata.MediaMetadata.ContentType,
thumbnailMetadata.MediaMetadata.FileSizeBytes,
thumbnailMetadata.MediaMetadata.CreationTimestamp,
thumbnailMetadata.ThumbnailSize.Width,
thumbnailMetadata.ThumbnailSize.Height,
thumbnailMetadata.ThumbnailSize.ResizeMethod,
)
return err
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertThumbnailStmt)
_, err := stmt.ExecContext(
ctx,
thumbnailMetadata.MediaMetadata.MediaID,
thumbnailMetadata.MediaMetadata.Origin,
thumbnailMetadata.MediaMetadata.ContentType,
thumbnailMetadata.MediaMetadata.FileSizeBytes,
thumbnailMetadata.MediaMetadata.CreationTimestamp,
thumbnailMetadata.ThumbnailSize.Width,
thumbnailMetadata.ThumbnailSize.Height,
thumbnailMetadata.ThumbnailSize.ResizeMethod,
)
return err
})
}
func (s *thumbnailStatements) selectThumbnail(

View file

@ -54,10 +54,10 @@ type eventJSONStatements struct {
bulkSelectEventJSONStmt *sql.Stmt
}
func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) {
func NewSqliteEventJSONTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.EventJSON, error) {
s := &eventJSONStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err := db.Exec(eventJSONSchema)
if err != nil {

View file

@ -71,10 +71,10 @@ type eventStateKeyStatements struct {
bulkSelectEventStateKeyStmt *sql.Stmt
}
func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) {
func NewSqliteEventStateKeysTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.EventStateKeys, error) {
s := &eventStateKeyStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err := db.Exec(eventStateKeysSchema)
if err != nil {

View file

@ -85,10 +85,10 @@ type eventTypeStatements struct {
bulkSelectEventTypeNIDStmt *sql.Stmt
}
func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) {
func NewSqliteEventTypesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.EventTypes, error) {
s := &eventTypeStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err := db.Exec(eventTypesSchema)
if err != nil {

View file

@ -115,10 +115,10 @@ type eventStatements struct {
selectRoomNIDForEventNIDStmt *sql.Stmt
}
func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) {
func NewSqliteEventsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Events, error) {
s := &eventStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err := db.Exec(eventsSchema)
if err != nil {

View file

@ -71,10 +71,10 @@ type inviteStatements struct {
selectInvitesAboutToRetireStmt *sql.Stmt
}
func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) {
func NewSqliteInvitesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Invites, error) {
s := &inviteStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err := db.Exec(inviteSchema)
if err != nil {
@ -124,7 +124,7 @@ func (s *inviteStatements) UpdateInviteRetired(
if err != nil {
return err
}
defer (func() { err = rows.Close() })()
defer internal.CloseAndLogIfError(ctx, rows, "UpdateInviteRetired: rows.close() failed")
for rows.Next() {
var inviteEventID string
if err = rows.Scan(&inviteEventID); err != nil {

View file

@ -88,10 +88,10 @@ type membershipStatements struct {
updateMembershipStmt *sql.Stmt
}
func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
func NewSqliteMembershipTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Membership, error) {
s := &membershipStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err := db.Exec(membershipSchema)
if err != nil {

View file

@ -59,10 +59,10 @@ type previousEventStatements struct {
selectPreviousEventExistsStmt *sql.Stmt
}
func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
func NewSqlitePrevEventsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.PreviousEvents, error) {
s := &previousEventStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err := db.Exec(previousEventSchema)
if err != nil {

View file

@ -51,10 +51,10 @@ type publishedStatements struct {
selectPublishedStmt *sql.Stmt
}
func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) {
func NewSqlitePublishedTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Published, error) {
s := &publishedStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err := db.Exec(publishedSchema)
if err != nil {

View file

@ -60,10 +60,10 @@ type redactionStatements struct {
markRedactionValidatedStmt *sql.Stmt
}
func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) {
func NewSqliteRedactionsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Redactions, error) {
s := &redactionStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err := db.Exec(redactionsSchema)
if err != nil {

View file

@ -65,10 +65,10 @@ type roomAliasesStatements struct {
deleteRoomAliasStmt *sql.Stmt
}
func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
func NewSqliteRoomAliasesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.RoomAliases, error) {
s := &roomAliasesStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err := db.Exec(roomAliasesSchema)
if err != nil {

View file

@ -76,10 +76,10 @@ type roomStatements struct {
selectRoomVersionForRoomNIDStmt *sql.Stmt
}
func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
func NewSqliteRoomsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Rooms, error) {
s := &roomStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err := db.Exec(roomsSchema)
if err != nil {

View file

@ -81,10 +81,10 @@ type stateBlockStatements struct {
bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt
}
func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
func NewSqliteStateBlockTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.StateBlock, error) {
s := &stateBlockStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err := db.Exec(stateDataSchema)
if err != nil {

View file

@ -55,10 +55,10 @@ type stateSnapshotStatements struct {
bulkSelectStateBlockNIDsStmt *sql.Stmt
}
func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
func NewSqliteStateSnapshotTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.StateSnapshot, error) {
s := &stateSnapshotStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err := db.Exec(stateSnapshotSchema)
if err != nil {

View file

@ -51,6 +51,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
writer := sqlutil.NewTransactionWriter()
//d.db.Exec("PRAGMA journal_mode=WAL;")
//d.db.Exec("PRAGMA read_uncommitted = true;")
@ -60,59 +61,59 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
// which it will never obtain.
d.db.SetMaxOpenConns(20)
d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db)
d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db, writer)
if err != nil {
return nil, err
}
d.eventTypes, err = NewSqliteEventTypesTable(d.db)
d.eventTypes, err = NewSqliteEventTypesTable(d.db, writer)
if err != nil {
return nil, err
}
d.eventJSON, err = NewSqliteEventJSONTable(d.db)
d.eventJSON, err = NewSqliteEventJSONTable(d.db, writer)
if err != nil {
return nil, err
}
d.events, err = NewSqliteEventsTable(d.db)
d.events, err = NewSqliteEventsTable(d.db, writer)
if err != nil {
return nil, err
}
d.rooms, err = NewSqliteRoomsTable(d.db)
d.rooms, err = NewSqliteRoomsTable(d.db, writer)
if err != nil {
return nil, err
}
d.transactions, err = NewSqliteTransactionsTable(d.db)
d.transactions, err = NewSqliteTransactionsTable(d.db, writer)
if err != nil {
return nil, err
}
stateBlock, err := NewSqliteStateBlockTable(d.db)
stateBlock, err := NewSqliteStateBlockTable(d.db, writer)
if err != nil {
return nil, err
}
stateSnapshot, err := NewSqliteStateSnapshotTable(d.db)
stateSnapshot, err := NewSqliteStateSnapshotTable(d.db, writer)
if err != nil {
return nil, err
}
d.prevEvents, err = NewSqlitePrevEventsTable(d.db)
d.prevEvents, err = NewSqlitePrevEventsTable(d.db, writer)
if err != nil {
return nil, err
}
roomAliases, err := NewSqliteRoomAliasesTable(d.db)
roomAliases, err := NewSqliteRoomAliasesTable(d.db, writer)
if err != nil {
return nil, err
}
d.invites, err = NewSqliteInvitesTable(d.db)
d.invites, err = NewSqliteInvitesTable(d.db, writer)
if err != nil {
return nil, err
}
d.membership, err = NewSqliteMembershipTable(d.db)
d.membership, err = NewSqliteMembershipTable(d.db, writer)
if err != nil {
return nil, err
}
published, err := NewSqlitePublishedTable(d.db)
published, err := NewSqlitePublishedTable(d.db, writer)
if err != nil {
return nil, err
}
redactions, err := NewSqliteRedactionsTable(d.db)
redactions, err := NewSqliteRedactionsTable(d.db, writer)
if err != nil {
return nil, err
}

View file

@ -50,10 +50,10 @@ type transactionStatements struct {
selectTransactionEventIDStmt *sql.Stmt
}
func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) {
func NewSqliteTransactionsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Transactions, error) {
s := &transactionStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err := db.Exec(transactionsSchema)
if err != nil {

View file

@ -17,6 +17,7 @@ package sqlite3
import (
"context"
"database/sql"
"golang.org/x/crypto/ed25519"
@ -30,6 +31,8 @@ import (
// A Database implements gomatrixserverlib.KeyDatabase and is used to store
// the public keys for other matrix servers.
type Database struct {
db *sql.DB
writer *sqlutil.TransactionWriter
statements serverKeyStatements
}
@ -47,8 +50,11 @@ func NewDatabase(
if err != nil {
return nil, err
}
d := &Database{}
err = d.statements.prepare(db)
d := &Database{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
err = d.statements.prepare(d.db, d.writer)
if err != nil {
return nil, err
}

View file

@ -68,9 +68,9 @@ type serverKeyStatements struct {
upsertServerKeysStmt *sql.Stmt
}
func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
func (s *serverKeyStatements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
s.writer = writer
_, err = db.Exec(serverKeysSchema)
if err != nil {
return

View file

@ -58,10 +58,10 @@ type accountDataStatements struct {
selectAccountDataInRangeStmt *sql.Stmt
}
func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
func NewSqliteAccountDataTable(db *sql.DB, writer *sqlutil.TransactionWriter, streamID *streamIDStatements) (tables.AccountData, error) {
s := &accountDataStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
streamIDStatements: streamID,
}
_, err := db.Exec(accountDataSchema)

View file

@ -55,10 +55,10 @@ type backwardExtremitiesStatements struct {
deleteBackwardExtremityStmt *sql.Stmt
}
func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
func NewSqliteBackwardsExtremitiesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.BackwardsExtremities, error) {
s := &backwardExtremitiesStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err := db.Exec(backwardExtremitiesSchema)
if err != nil {

View file

@ -95,10 +95,10 @@ type currentRoomStateStatements struct {
selectStateEventStmt *sql.Stmt
}
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
func NewSqliteCurrentRoomStateTable(db *sql.DB, writer *sqlutil.TransactionWriter, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
s := &currentRoomStateStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
streamIDStatements: streamID,
}
_, err := db.Exec(currentRoomStateSchema)

View file

@ -58,14 +58,14 @@ type filterStatements struct {
insertFilterStmt *sql.Stmt
}
func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) {
func NewSqliteFilterTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Filter, error) {
_, err := db.Exec(filterSchema)
if err != nil {
return nil, err
}
s := &filterStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
return nil, err

View file

@ -67,10 +67,10 @@ type inviteEventsStatements struct {
selectMaxInviteIDStmt *sql.Stmt
}
func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
func NewSqliteInvitesTable(db *sql.DB, writer *sqlutil.TransactionWriter, streamID *streamIDStatements) (tables.Invites, error) {
s := &inviteEventsStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
streamIDStatements: streamID,
}
_, err := db.Exec(inviteEventsSchema)

View file

@ -117,10 +117,10 @@ type outputRoomEventsStatements struct {
updateEventJSONStmt *sql.Stmt
}
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
func NewSqliteEventsTable(db *sql.DB, writer *sqlutil.TransactionWriter, streamID *streamIDStatements) (tables.Events, error) {
s := &outputRoomEventsStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
streamIDStatements: streamID,
}
_, err := db.Exec(outputRoomEventsSchema)

View file

@ -75,10 +75,10 @@ type outputRoomEventsTopologyStatements struct {
selectMaxPositionInTopologyStmt *sql.Stmt
}
func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
func NewSqliteTopologyTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Topology, error) {
s := &outputRoomEventsTopologyStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err := db.Exec(outputRoomEventsTopologySchema)
if err != nil {

View file

@ -79,10 +79,10 @@ type sendToDeviceStatements struct {
countSendToDeviceMessagesStmt *sql.Stmt
}
func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
func NewSqliteSendToDeviceTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.SendToDevice, error) {
s := &sendToDeviceStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
writer: writer,
}
_, err := db.Exec(sendToDeviceSchema)
if err != nil {

View file

@ -33,9 +33,9 @@ type streamIDStatements struct {
selectStreamIDStmt *sql.Stmt
}
func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
func (s *streamIDStatements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
s.writer = writer
_, err = db.Exec(streamIDTableSchema)
if err != nil {
return

View file

@ -31,7 +31,8 @@ import (
// both the database for PDUs and caches for EDUs.
type SyncServerDatasource struct {
shared.Database
db *sql.DB
db *sql.DB
writer *sqlutil.TransactionWriter
sqlutil.PartitionOffsetStatements
streamID streamIDStatements
}
@ -44,6 +45,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
d.writer = sqlutil.NewTransactionWriter()
if err = d.prepare(); err != nil {
return nil, err
}
@ -54,38 +56,38 @@ func (d *SyncServerDatasource) prepare() (err error) {
if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil {
return err
}
if err = d.streamID.prepare(d.db); err != nil {
if err = d.streamID.prepare(d.db, d.writer); err != nil {
return err
}
accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID)
accountData, err := NewSqliteAccountDataTable(d.db, d.writer, &d.streamID)
if err != nil {
return err
}
events, err := NewSqliteEventsTable(d.db, &d.streamID)
events, err := NewSqliteEventsTable(d.db, d.writer, &d.streamID)
if err != nil {
return err
}
roomState, err := NewSqliteCurrentRoomStateTable(d.db, &d.streamID)
roomState, err := NewSqliteCurrentRoomStateTable(d.db, d.writer, &d.streamID)
if err != nil {
return err
}
invites, err := NewSqliteInvitesTable(d.db, &d.streamID)
invites, err := NewSqliteInvitesTable(d.db, d.writer, &d.streamID)
if err != nil {
return err
}
topology, err := NewSqliteTopologyTable(d.db)
topology, err := NewSqliteTopologyTable(d.db, d.writer)
if err != nil {
return err
}
bwExtrem, err := NewSqliteBackwardsExtremitiesTable(d.db)
bwExtrem, err := NewSqliteBackwardsExtremitiesTable(d.db, d.writer)
if err != nil {
return err
}
sendToDevice, err := NewSqliteSendToDeviceTable(d.db)
sendToDevice, err := NewSqliteSendToDeviceTable(d.db, d.writer)
if err != nil {
return err
}
filter, err := NewSqliteFilterTable(d.db)
filter, err := NewSqliteFilterTable(d.db, d.writer)
if err != nil {
return err
}

View file

@ -57,9 +57,9 @@ type accountDataStatements struct {
selectAccountDataByTypeStmt *sql.Stmt
}
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
func (s *accountDataStatements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
s.writer = writer
_, err = db.Exec(accountDataSchema)
if err != nil {
return

View file

@ -67,9 +67,9 @@ type accountsStatements struct {
serverName gomatrixserverlib.ServerName
}
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
func (s *accountsStatements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter, server gomatrixserverlib.ServerName) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
s.writer = writer
_, err = db.Exec(accountsSchema)
if err != nil {
return

View file

@ -61,9 +61,9 @@ type profilesStatements struct {
selectProfilesBySearchStmt *sql.Stmt
}
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
func (s *profilesStatements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
s.writer = writer
_, err = db.Exec(profilesSchema)
if err != nil {
return

View file

@ -57,20 +57,24 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = partitions.Prepare(db, "account"); err != nil {
return nil, err
}
a := accountsStatements{}
if err = a.prepare(db, serverName); err != nil {
writer := sqlutil.NewTransactionWriter()
a := accountsStatements{
db: db,
writer: writer,
}
if err = a.prepare(db, writer, serverName); err != nil {
return nil, err
}
p := profilesStatements{}
if err = p.prepare(db); err != nil {
if err = p.prepare(db, writer); err != nil {
return nil, err
}
ac := accountDataStatements{}
if err = ac.prepare(db); err != nil {
if err = ac.prepare(db, writer); err != nil {
return nil, err
}
t := threepidStatements{}
if err = t.prepare(db); err != nil {
if err = t.prepare(db, writer); err != nil {
return nil, err
}
return &Database{

View file

@ -61,9 +61,9 @@ type threepidStatements struct {
deleteThreePIDStmt *sql.Stmt
}
func (s *threepidStatements) prepare(db *sql.DB) (err error) {
func (s *threepidStatements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
s.writer = writer
_, err = db.Exec(threepidSchema)
if err != nil {
return

View file

@ -91,9 +91,9 @@ type devicesStatements struct {
serverName gomatrixserverlib.ServerName
}
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
func (s *devicesStatements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter, server gomatrixserverlib.ServerName) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
s.writer = writer
_, err = db.Exec(devicesSchema)
if err != nil {
return

View file

@ -34,6 +34,7 @@ var deviceIDByteLength = 6
// Database represents a device database.
type Database struct {
db *sql.DB
writer *sqlutil.TransactionWriter
devices devicesStatements
}
@ -43,11 +44,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil {
return nil, err
}
writer := sqlutil.NewTransactionWriter()
d := devicesStatements{}
if err = d.prepare(db, serverName); err != nil {
if err = d.prepare(db, writer, serverName); err != nil {
return nil, err
}
return &Database{db, d}, nil
return &Database{db, writer, d}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.