Undo non-roomserver changes for now

This commit is contained in:
Neil Alexander 2020-08-18 09:54:51 +01:00
parent 5c1183fba4
commit c9ee7cc269
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
41 changed files with 135 additions and 164 deletions

View file

@ -75,9 +75,9 @@ type eventsStatements struct {
deleteEventsBeforeAndIncludingIDStmt *sql.Stmt deleteEventsBeforeAndIncludingIDStmt *sql.Stmt
} }
func (s *eventsStatements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter) (err error) { func (s *eventsStatements) prepare(db *sql.DB) (err error) {
s.db = db s.db = db
s.writer = writer s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(appserviceEventsSchema) _, err = db.Exec(appserviceEventsSchema)
if err != nil { if err != nil {
return return

View file

@ -41,8 +41,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
if result.db, err = sqlutil.Open(dbProperties); err != nil { if result.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err return nil, err
} }
writer := sqlutil.NewTransactionWriter() if err = result.prepare(); err != nil {
if err = result.prepare(writer); err != nil {
return nil, err return nil, err
} }
if err = result.PartitionOffsetStatements.Prepare(result.db, "appservice"); err != nil { if err = result.PartitionOffsetStatements.Prepare(result.db, "appservice"); err != nil {
@ -51,12 +50,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
return &result, nil return &result, nil
} }
func (d *Database) prepare(writer *sqlutil.TransactionWriter) error { func (d *Database) prepare() error {
if err := d.events.prepare(d.db, writer); err != nil { if err := d.events.prepare(d.db); err != nil {
return err return err
} }
return d.txnID.prepare(d.db, writer) return d.txnID.prepare(d.db)
} }
// StoreEvent takes in a gomatrixserverlib.HeaderedEvent and stores it in the database // StoreEvent takes in a gomatrixserverlib.HeaderedEvent and stores it in the database

View file

@ -42,9 +42,9 @@ type txnStatements struct {
selectTxnIDStmt *sql.Stmt selectTxnIDStmt *sql.Stmt
} }
func (s *txnStatements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter) (err error) { func (s *txnStatements) prepare(db *sql.DB) (err error) {
s.db = db s.db = db
s.writer = writer s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(txnIDSchema) _, err = db.Exec(txnIDSchema)
if err != nil { if err != nil {
return return

View file

@ -93,10 +93,10 @@ type currentRoomStateStatements struct {
selectKnownUsersStmt *sql.Stmt selectKnownUsersStmt *sql.Stmt
} }
func NewSqliteCurrentRoomStateTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.CurrentRoomState, error) { func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
s := &currentRoomStateStatements{ s := &currentRoomStateStatements{
db: db, db: db,
writer: writer, writer: sqlutil.NewTransactionWriter(),
} }
_, err := db.Exec(currentRoomStateSchema) _, err := db.Exec(currentRoomStateSchema)
if err != nil { if err != nil {

View file

@ -22,11 +22,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
if d.db, err = sqlutil.Open(dbProperties); err != nil { if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err return nil, err
} }
writer := sqlutil.NewTransactionWriter()
if err = d.PartitionOffsetStatements.Prepare(d.db, "currentstate"); err != nil { if err = d.PartitionOffsetStatements.Prepare(d.db, "currentstate"); err != nil {
return nil, err return nil, err
} }
currRoomState, err := NewSqliteCurrentRoomStateTable(d.db, writer) currRoomState, err := NewSqliteCurrentRoomStateTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -48,10 +48,10 @@ type blacklistStatements struct {
deleteBlacklistStmt *sql.Stmt deleteBlacklistStmt *sql.Stmt
} }
func NewSQLiteBlacklistTable(db *sql.DB, writer *sqlutil.TransactionWriter) (s *blacklistStatements, err error) { func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
s = &blacklistStatements{ s = &blacklistStatements{
db: db, db: db,
writer: writer, writer: sqlutil.NewTransactionWriter(),
} }
_, err = db.Exec(blacklistSchema) _, err = db.Exec(blacklistSchema)
if err != nil { if err != nil {

View file

@ -73,10 +73,10 @@ type joinedHostsStatements struct {
// selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic // selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic
} }
func NewSQLiteJoinedHostsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (s *joinedHostsStatements, err error) { func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
s = &joinedHostsStatements{ s = &joinedHostsStatements{
db: db, db: db,
writer: writer, writer: sqlutil.NewTransactionWriter(),
} }
_, err = db.Exec(joinedHostsSchema) _, err = db.Exec(joinedHostsSchema)
if err != nil { if err != nil {

View file

@ -72,10 +72,10 @@ type queueEDUsStatements struct {
selectQueueEDUServerNamesStmt *sql.Stmt selectQueueEDUServerNamesStmt *sql.Stmt
} }
func NewSQLiteQueueEDUsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (s *queueEDUsStatements, err error) { func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
s = &queueEDUsStatements{ s = &queueEDUsStatements{
db: db, db: db,
writer: writer, writer: sqlutil.NewTransactionWriter(),
} }
_, err = db.Exec(queueEDUsSchema) _, err = db.Exec(queueEDUsSchema)
if err != nil { if err != nil {

View file

@ -56,10 +56,10 @@ type queueJSONStatements struct {
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic //selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
} }
func NewSQLiteQueueJSONTable(db *sql.DB, writer *sqlutil.TransactionWriter) (s *queueJSONStatements, err error) { func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) {
s = &queueJSONStatements{ s = &queueJSONStatements{
db: db, db: db,
writer: writer, writer: sqlutil.NewTransactionWriter(),
} }
_, err = db.Exec(queueJSONSchema) _, err = db.Exec(queueJSONSchema)
if err != nil { if err != nil {

View file

@ -81,10 +81,10 @@ type queuePDUsStatements struct {
// deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic // deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic
} }
func NewSQLiteQueuePDUsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (s *queuePDUsStatements, err error) { func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
s = &queuePDUsStatements{ s = &queuePDUsStatements{
db: db, db: db,
writer: writer, writer: sqlutil.NewTransactionWriter(),
} }
_, err = db.Exec(queuePDUsSchema) _, err = db.Exec(queuePDUsSchema)
if err != nil { if err != nil {

View file

@ -50,10 +50,10 @@ type roomStatements struct {
updateRoomStmt *sql.Stmt updateRoomStmt *sql.Stmt
} }
func NewSQLiteRoomsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (s *roomStatements, err error) { func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) {
s = &roomStatements{ s = &roomStatements{
db: db, db: db,
writer: writer, writer: sqlutil.NewTransactionWriter(),
} }
_, err = db.Exec(roomSchema) _, err = db.Exec(roomSchema)
if err != nil { if err != nil {

View file

@ -39,28 +39,27 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
if d.db, err = sqlutil.Open(dbProperties); err != nil { if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err return nil, err
} }
writer := sqlutil.NewTransactionWriter() joinedHosts, err := NewSQLiteJoinedHostsTable(d.db)
joinedHosts, err := NewSQLiteJoinedHostsTable(d.db, writer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
rooms, err := NewSQLiteRoomsTable(d.db, writer) rooms, err := NewSQLiteRoomsTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
queuePDUs, err := NewSQLiteQueuePDUsTable(d.db, writer) queuePDUs, err := NewSQLiteQueuePDUsTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
queueEDUs, err := NewSQLiteQueueEDUsTable(d.db, writer) queueEDUs, err := NewSQLiteQueueEDUsTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
queueJSON, err := NewSQLiteQueueJSONTable(d.db, writer) queueJSON, err := NewSQLiteQueueJSONTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
blacklist, err := NewSQLiteBlacklistTable(d.db, writer) blacklist, err := NewSQLiteBlacklistTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -71,10 +71,10 @@ type deviceKeysStatements struct {
deleteAllDeviceKeysStmt *sql.Stmt deleteAllDeviceKeysStmt *sql.Stmt
} }
func NewSqliteDeviceKeysTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.DeviceKeys, error) { func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
s := &deviceKeysStatements{ s := &deviceKeysStatements{
db: db, db: db,
writer: writer, writer: sqlutil.NewTransactionWriter(),
} }
_, err := db.Exec(deviceKeysSchema) _, err := db.Exec(deviceKeysSchema)
if err != nil { if err != nil {

View file

@ -57,10 +57,10 @@ type keyChangesStatements struct {
selectKeyChangesStmt *sql.Stmt selectKeyChangesStmt *sql.Stmt
} }
func NewSqliteKeyChangesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.KeyChanges, error) { func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
s := &keyChangesStatements{ s := &keyChangesStatements{
db: db, db: db,
writer: writer, writer: sqlutil.NewTransactionWriter(),
} }
_, err := db.Exec(keyChangesSchema) _, err := db.Exec(keyChangesSchema)
if err != nil { if err != nil {

View file

@ -68,10 +68,10 @@ type oneTimeKeysStatements struct {
deleteOneTimeKeyStmt *sql.Stmt deleteOneTimeKeyStmt *sql.Stmt
} }
func NewSqliteOneTimeKeysTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.OneTimeKeys, error) { func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
s := &oneTimeKeysStatements{ s := &oneTimeKeysStatements{
db: db, db: db,
writer: writer, writer: sqlutil.NewTransactionWriter(),
} }
_, err := db.Exec(oneTimeKeysSchema) _, err := db.Exec(oneTimeKeysSchema)
if err != nil { if err != nil {

View file

@ -20,7 +20,6 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -50,18 +49,13 @@ const selectStaleDeviceListsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
type staleDeviceListsStatements struct { type staleDeviceListsStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
upsertStaleDeviceListStmt *sql.Stmt upsertStaleDeviceListStmt *sql.Stmt
selectStaleDeviceListsWithDomainsStmt *sql.Stmt selectStaleDeviceListsWithDomainsStmt *sql.Stmt
selectStaleDeviceListsStmt *sql.Stmt selectStaleDeviceListsStmt *sql.Stmt
} }
func NewSqliteStaleDeviceListsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.StaleDeviceLists, error) { func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
s := &staleDeviceListsStatements{ s := &staleDeviceListsStatements{}
db: db,
writer: writer,
}
_, err := db.Exec(staleDeviceListsSchema) _, err := db.Exec(staleDeviceListsSchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -83,10 +77,8 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context,
if err != nil { if err != nil {
return err return err
} }
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix())
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) return err
return err
})
} }
func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {

View file

@ -25,20 +25,19 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
writer := sqlutil.NewTransactionWriter() otk, err := NewSqliteOneTimeKeysTable(db)
otk, err := NewSqliteOneTimeKeysTable(db, writer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
dk, err := NewSqliteDeviceKeysTable(db, writer) dk, err := NewSqliteDeviceKeysTable(db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
kc, err := NewSqliteKeyChangesTable(db, writer) kc, err := NewSqliteKeyChangesTable(db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
sdl, err := NewSqliteStaleDeviceListsTable(db, writer) sdl, err := NewSqliteStaleDeviceListsTable(db)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -67,9 +67,9 @@ type mediaStatements struct {
selectMediaStmt *sql.Stmt selectMediaStmt *sql.Stmt
} }
func (s *mediaStatements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter) (err error) { func (s *mediaStatements) prepare(db *sql.DB) (err error) {
s.db = db s.db = db
s.writer = writer s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(mediaSchema) _, err = db.Exec(mediaSchema)
if err != nil { if err != nil {

View file

@ -17,8 +17,6 @@ package sqlite3
import ( import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
type statements struct { type statements struct {
@ -26,11 +24,11 @@ type statements struct {
thumbnail thumbnailStatements thumbnail thumbnailStatements
} }
func (s *statements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter) (err error) { func (s *statements) prepare(db *sql.DB) (err error) {
if err = s.media.prepare(db, writer); err != nil { if err = s.media.prepare(db); err != nil {
return return
} }
if err = s.thumbnail.prepare(db, writer); err != nil { if err = s.thumbnail.prepare(db); err != nil {
return return
} }

View file

@ -31,7 +31,6 @@ import (
type Database struct { type Database struct {
statements statements statements statements
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter
} }
// Open opens a postgres database. // Open opens a postgres database.
@ -41,8 +40,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
if d.db, err = sqlutil.Open(dbProperties); err != nil { if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err return nil, err
} }
d.writer = sqlutil.NewTransactionWriter() if err = d.statements.prepare(d.db); err != nil {
if err = d.statements.prepare(d.db, d.writer); err != nil {
return nil, err return nil, err
} }
return &d, nil return &d, nil

View file

@ -21,7 +21,6 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -58,16 +57,12 @@ SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method
` `
type thumbnailStatements struct { type thumbnailStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertThumbnailStmt *sql.Stmt insertThumbnailStmt *sql.Stmt
selectThumbnailStmt *sql.Stmt selectThumbnailStmt *sql.Stmt
selectThumbnailsStmt *sql.Stmt selectThumbnailsStmt *sql.Stmt
} }
func (s *thumbnailStatements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter) (err error) { func (s *thumbnailStatements) prepare(db *sql.DB) (err error) {
s.db = db
s.writer = writer
_, err = db.Exec(thumbnailSchema) _, err = db.Exec(thumbnailSchema)
if err != nil { if err != nil {
return return
@ -84,21 +79,18 @@ func (s *thumbnailStatements) insertThumbnail(
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata, ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
) error { ) error {
thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000) thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { _, err := s.insertThumbnailStmt.ExecContext(
stmt := sqlutil.TxStmt(txn, s.insertThumbnailStmt) ctx,
_, err := stmt.ExecContext( thumbnailMetadata.MediaMetadata.MediaID,
ctx, thumbnailMetadata.MediaMetadata.Origin,
thumbnailMetadata.MediaMetadata.MediaID, thumbnailMetadata.MediaMetadata.ContentType,
thumbnailMetadata.MediaMetadata.Origin, thumbnailMetadata.MediaMetadata.FileSizeBytes,
thumbnailMetadata.MediaMetadata.ContentType, thumbnailMetadata.MediaMetadata.CreationTimestamp,
thumbnailMetadata.MediaMetadata.FileSizeBytes, thumbnailMetadata.ThumbnailSize.Width,
thumbnailMetadata.MediaMetadata.CreationTimestamp, thumbnailMetadata.ThumbnailSize.Height,
thumbnailMetadata.ThumbnailSize.Width, thumbnailMetadata.ThumbnailSize.ResizeMethod,
thumbnailMetadata.ThumbnailSize.Height, )
thumbnailMetadata.ThumbnailSize.ResizeMethod, return err
)
return err
})
} }
func (s *thumbnailStatements) selectThumbnail( func (s *thumbnailStatements) selectThumbnail(

View file

@ -17,7 +17,6 @@ package sqlite3
import ( import (
"context" "context"
"database/sql"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
@ -31,8 +30,6 @@ import (
// A Database implements gomatrixserverlib.KeyDatabase and is used to store // A Database implements gomatrixserverlib.KeyDatabase and is used to store
// the public keys for other matrix servers. // the public keys for other matrix servers.
type Database struct { type Database struct {
db *sql.DB
writer *sqlutil.TransactionWriter
statements serverKeyStatements statements serverKeyStatements
} }
@ -50,11 +47,8 @@ func NewDatabase(
if err != nil { if err != nil {
return nil, err return nil, err
} }
d := &Database{ d := &Database{}
db: db, err = d.statements.prepare(db)
writer: sqlutil.NewTransactionWriter(),
}
err = d.statements.prepare(d.db, d.writer)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -68,9 +68,9 @@ type serverKeyStatements struct {
upsertServerKeysStmt *sql.Stmt upsertServerKeysStmt *sql.Stmt
} }
func (s *serverKeyStatements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter) (err error) { func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
s.db = db s.db = db
s.writer = writer s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(serverKeysSchema) _, err = db.Exec(serverKeysSchema)
if err != nil { if err != nil {
return return

View file

@ -171,7 +171,11 @@ func (d *Database) SyncStreamPosition(ctx context.Context) (types.StreamPosition
func (d *Database) AddInviteEvent( func (d *Database) AddInviteEvent(
ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent,
) (sp types.StreamPosition, err error) { ) (sp types.StreamPosition, err error) {
return d.Invites.InsertInviteEvent(ctx, nil, inviteEvent) err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent)
return err
})
return
} }
// RetireInviteEvent removes an old invite event from the database. // RetireInviteEvent removes an old invite event from the database.

View file

@ -58,10 +58,10 @@ type accountDataStatements struct {
selectAccountDataInRangeStmt *sql.Stmt selectAccountDataInRangeStmt *sql.Stmt
} }
func NewSqliteAccountDataTable(db *sql.DB, writer *sqlutil.TransactionWriter, streamID *streamIDStatements) (tables.AccountData, error) { func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
s := &accountDataStatements{ s := &accountDataStatements{
db: db, db: db,
writer: writer, writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID, streamIDStatements: streamID,
} }
_, err := db.Exec(accountDataSchema) _, err := db.Exec(accountDataSchema)

View file

@ -55,10 +55,10 @@ type backwardExtremitiesStatements struct {
deleteBackwardExtremityStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt
} }
func NewSqliteBackwardsExtremitiesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.BackwardsExtremities, error) { func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
s := &backwardExtremitiesStatements{ s := &backwardExtremitiesStatements{
db: db, db: db,
writer: writer, writer: sqlutil.NewTransactionWriter(),
} }
_, err := db.Exec(backwardExtremitiesSchema) _, err := db.Exec(backwardExtremitiesSchema)
if err != nil { if err != nil {

View file

@ -95,10 +95,10 @@ type currentRoomStateStatements struct {
selectStateEventStmt *sql.Stmt selectStateEventStmt *sql.Stmt
} }
func NewSqliteCurrentRoomStateTable(db *sql.DB, writer *sqlutil.TransactionWriter, streamID *streamIDStatements) (tables.CurrentRoomState, error) { func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
s := &currentRoomStateStatements{ s := &currentRoomStateStatements{
db: db, db: db,
writer: writer, writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID, streamIDStatements: streamID,
} }
_, err := db.Exec(currentRoomStateSchema) _, err := db.Exec(currentRoomStateSchema)

View file

@ -58,14 +58,14 @@ type filterStatements struct {
insertFilterStmt *sql.Stmt insertFilterStmt *sql.Stmt
} }
func NewSqliteFilterTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Filter, error) { func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) {
_, err := db.Exec(filterSchema) _, err := db.Exec(filterSchema)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s := &filterStatements{ s := &filterStatements{
db: db, db: db,
writer: writer, writer: sqlutil.NewTransactionWriter(),
} }
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
return nil, err return nil, err

View file

@ -67,10 +67,10 @@ type inviteEventsStatements struct {
selectMaxInviteIDStmt *sql.Stmt selectMaxInviteIDStmt *sql.Stmt
} }
func NewSqliteInvitesTable(db *sql.DB, writer *sqlutil.TransactionWriter, streamID *streamIDStatements) (tables.Invites, error) { func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
s := &inviteEventsStatements{ s := &inviteEventsStatements{
db: db, db: db,
writer: writer, writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID, streamIDStatements: streamID,
} }
_, err := db.Exec(inviteEventsSchema) _, err := db.Exec(inviteEventsSchema)
@ -95,12 +95,13 @@ func NewSqliteInvitesTable(db *sql.DB, writer *sqlutil.TransactionWriter, stream
func (s *inviteEventsStatements) InsertInviteEvent( func (s *inviteEventsStatements) InsertInviteEvent(
ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent, ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent,
) (streamPos types.StreamPosition, err error) { ) (streamPos types.StreamPosition, err error) {
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
if err != nil {
return
}
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
var err error
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
if err != nil {
return err
}
var headeredJSON []byte var headeredJSON []byte
headeredJSON, err = json.Marshal(inviteEvent) headeredJSON, err = json.Marshal(inviteEvent)
if err != nil { if err != nil {
@ -123,11 +124,13 @@ func (s *inviteEventsStatements) InsertInviteEvent(
func (s *inviteEventsStatements) DeleteInviteEvent( func (s *inviteEventsStatements) DeleteInviteEvent(
ctx context.Context, inviteEventID string, ctx context.Context, inviteEventID string,
) (types.StreamPosition, error) { ) (types.StreamPosition, error) {
streamPos, err := s.streamIDStatements.nextStreamID(ctx, nil) var streamPos types.StreamPosition
if err != nil { err := s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
return streamPos, err var err error
} streamPos, err = s.streamIDStatements.nextStreamID(ctx, nil)
err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error { if err != nil {
return err
}
_, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID) _, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID)
return err return err
}) })

View file

@ -117,10 +117,10 @@ type outputRoomEventsStatements struct {
updateEventJSONStmt *sql.Stmt updateEventJSONStmt *sql.Stmt
} }
func NewSqliteEventsTable(db *sql.DB, writer *sqlutil.TransactionWriter, streamID *streamIDStatements) (tables.Events, error) { func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
s := &outputRoomEventsStatements{ s := &outputRoomEventsStatements{
db: db, db: db,
writer: writer, writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID, streamIDStatements: streamID,
} }
_, err := db.Exec(outputRoomEventsSchema) _, err := db.Exec(outputRoomEventsSchema)
@ -304,11 +304,13 @@ func (s *outputRoomEventsStatements) InsertEvent(
return 0, err return 0, err
} }
streamPos, err := s.streamIDStatements.nextStreamID(ctx, txn) var streamPos types.StreamPosition
if err != nil {
return 0, err
}
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
if err != nil {
return err
}
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
_, ierr := insertStmt.ExecContext( _, ierr := insertStmt.ExecContext(
ctx, ctx,

View file

@ -75,10 +75,10 @@ type outputRoomEventsTopologyStatements struct {
selectMaxPositionInTopologyStmt *sql.Stmt selectMaxPositionInTopologyStmt *sql.Stmt
} }
func NewSqliteTopologyTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Topology, error) { func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
s := &outputRoomEventsTopologyStatements{ s := &outputRoomEventsTopologyStatements{
db: db, db: db,
writer: writer, writer: sqlutil.NewTransactionWriter(),
} }
_, err := db.Exec(outputRoomEventsTopologySchema) _, err := db.Exec(outputRoomEventsTopologySchema)
if err != nil { if err != nil {

View file

@ -79,10 +79,10 @@ type sendToDeviceStatements struct {
countSendToDeviceMessagesStmt *sql.Stmt countSendToDeviceMessagesStmt *sql.Stmt
} }
func NewSqliteSendToDeviceTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.SendToDevice, error) { func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
s := &sendToDeviceStatements{ s := &sendToDeviceStatements{
db: db, db: db,
writer: writer, writer: sqlutil.NewTransactionWriter(),
} }
_, err := db.Exec(sendToDeviceSchema) _, err := db.Exec(sendToDeviceSchema)
if err != nil { if err != nil {

View file

@ -33,9 +33,9 @@ type streamIDStatements struct {
selectStreamIDStmt *sql.Stmt selectStreamIDStmt *sql.Stmt
} }
func (s *streamIDStatements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter) (err error) { func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
s.db = db s.db = db
s.writer = writer s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(streamIDTableSchema) _, err = db.Exec(streamIDTableSchema)
if err != nil { if err != nil {
return return

View file

@ -31,8 +31,7 @@ import (
// both the database for PDUs and caches for EDUs. // both the database for PDUs and caches for EDUs.
type SyncServerDatasource struct { type SyncServerDatasource struct {
shared.Database shared.Database
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter
sqlutil.PartitionOffsetStatements sqlutil.PartitionOffsetStatements
streamID streamIDStatements streamID streamIDStatements
} }
@ -45,7 +44,6 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
if d.db, err = sqlutil.Open(dbProperties); err != nil { if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err return nil, err
} }
d.writer = sqlutil.NewTransactionWriter()
if err = d.prepare(); err != nil { if err = d.prepare(); err != nil {
return nil, err return nil, err
} }
@ -56,38 +54,38 @@ func (d *SyncServerDatasource) prepare() (err error) {
if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil { if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil {
return err return err
} }
if err = d.streamID.prepare(d.db, d.writer); err != nil { if err = d.streamID.prepare(d.db); err != nil {
return err return err
} }
accountData, err := NewSqliteAccountDataTable(d.db, d.writer, &d.streamID) accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID)
if err != nil { if err != nil {
return err return err
} }
events, err := NewSqliteEventsTable(d.db, d.writer, &d.streamID) events, err := NewSqliteEventsTable(d.db, &d.streamID)
if err != nil { if err != nil {
return err return err
} }
roomState, err := NewSqliteCurrentRoomStateTable(d.db, d.writer, &d.streamID) roomState, err := NewSqliteCurrentRoomStateTable(d.db, &d.streamID)
if err != nil { if err != nil {
return err return err
} }
invites, err := NewSqliteInvitesTable(d.db, d.writer, &d.streamID) invites, err := NewSqliteInvitesTable(d.db, &d.streamID)
if err != nil { if err != nil {
return err return err
} }
topology, err := NewSqliteTopologyTable(d.db, d.writer) topology, err := NewSqliteTopologyTable(d.db)
if err != nil { if err != nil {
return err return err
} }
bwExtrem, err := NewSqliteBackwardsExtremitiesTable(d.db, d.writer) bwExtrem, err := NewSqliteBackwardsExtremitiesTable(d.db)
if err != nil { if err != nil {
return err return err
} }
sendToDevice, err := NewSqliteSendToDeviceTable(d.db, d.writer) sendToDevice, err := NewSqliteSendToDeviceTable(d.db)
if err != nil { if err != nil {
return err return err
} }
filter, err := NewSqliteFilterTable(d.db, d.writer) filter, err := NewSqliteFilterTable(d.db)
if err != nil { if err != nil {
return err return err
} }

View file

@ -57,9 +57,9 @@ type accountDataStatements struct {
selectAccountDataByTypeStmt *sql.Stmt selectAccountDataByTypeStmt *sql.Stmt
} }
func (s *accountDataStatements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter) (err error) { func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
s.db = db s.db = db
s.writer = writer s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(accountDataSchema) _, err = db.Exec(accountDataSchema)
if err != nil { if err != nil {
return return

View file

@ -67,9 +67,9 @@ type accountsStatements struct {
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *accountsStatements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter, server gomatrixserverlib.ServerName) (err error) { func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.db = db s.db = db
s.writer = writer s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(accountsSchema) _, err = db.Exec(accountsSchema)
if err != nil { if err != nil {
return return

View file

@ -61,9 +61,9 @@ type profilesStatements struct {
selectProfilesBySearchStmt *sql.Stmt selectProfilesBySearchStmt *sql.Stmt
} }
func (s *profilesStatements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter) (err error) { func (s *profilesStatements) prepare(db *sql.DB) (err error) {
s.db = db s.db = db
s.writer = writer s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(profilesSchema) _, err = db.Exec(profilesSchema)
if err != nil { if err != nil {
return return

View file

@ -57,24 +57,20 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = partitions.Prepare(db, "account"); err != nil { if err = partitions.Prepare(db, "account"); err != nil {
return nil, err return nil, err
} }
writer := sqlutil.NewTransactionWriter() a := accountsStatements{}
a := accountsStatements{ if err = a.prepare(db, serverName); err != nil {
db: db,
writer: writer,
}
if err = a.prepare(db, writer, serverName); err != nil {
return nil, err return nil, err
} }
p := profilesStatements{} p := profilesStatements{}
if err = p.prepare(db, writer); err != nil { if err = p.prepare(db); err != nil {
return nil, err return nil, err
} }
ac := accountDataStatements{} ac := accountDataStatements{}
if err = ac.prepare(db, writer); err != nil { if err = ac.prepare(db); err != nil {
return nil, err return nil, err
} }
t := threepidStatements{} t := threepidStatements{}
if err = t.prepare(db, writer); err != nil { if err = t.prepare(db); err != nil {
return nil, err return nil, err
} }
return &Database{ return &Database{

View file

@ -61,9 +61,9 @@ type threepidStatements struct {
deleteThreePIDStmt *sql.Stmt deleteThreePIDStmt *sql.Stmt
} }
func (s *threepidStatements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter) (err error) { func (s *threepidStatements) prepare(db *sql.DB) (err error) {
s.db = db s.db = db
s.writer = writer s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(threepidSchema) _, err = db.Exec(threepidSchema)
if err != nil { if err != nil {
return return

View file

@ -91,9 +91,9 @@ type devicesStatements struct {
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *devicesStatements) prepare(db *sql.DB, writer *sqlutil.TransactionWriter, server gomatrixserverlib.ServerName) (err error) { func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.db = db s.db = db
s.writer = writer s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(devicesSchema) _, err = db.Exec(devicesSchema)
if err != nil { if err != nil {
return return

View file

@ -34,7 +34,6 @@ var deviceIDByteLength = 6
// Database represents a device database. // Database represents a device database.
type Database struct { type Database struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter
devices devicesStatements devices devicesStatements
} }
@ -44,12 +43,11 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil { if err != nil {
return nil, err return nil, err
} }
writer := sqlutil.NewTransactionWriter()
d := devicesStatements{} d := devicesStatements{}
if err = d.prepare(db, writer, serverName); err != nil { if err = d.prepare(db, serverName); err != nil {
return nil, err return nil, err
} }
return &Database{db, writer, d}, nil return &Database{db, d}, nil
} }
// GetDeviceByAccessToken returns the device matching the given access token. // GetDeviceByAccessToken returns the device matching the given access token.