Move writers up a layer in sync API

This commit is contained in:
Neil Alexander 2020-08-21 09:55:46 +01:00
parent bc1023ea19
commit e708cb73aa
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
13 changed files with 144 additions and 195 deletions

View file

@ -68,7 +68,7 @@ func (w *ExclusiveWriter) run() {
return task.f(txn) return task.f(txn)
}) })
} else { } else {
panic("expected database or transaction but got neither") task.wait <- task.f(nil)
} }
close(task.wait) close(task.wait)
} }

View file

@ -80,6 +80,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Writer: sqlutil.NewDummyWriter(),
Invites: invites, Invites: invites,
AccountData: accountData, AccountData: accountData,
OutputEvents: events, OutputEvents: events,
@ -88,7 +89,6 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
BackwardExtremities: backwardExtremities, BackwardExtremities: backwardExtremities,
Filter: filter, Filter: filter,
SendToDevice: sendToDevice, SendToDevice: sendToDevice,
SendToDeviceWriter: sqlutil.NewExclusiveWriter(),
EDUCache: cache.New(), EDUCache: cache.New(),
} }
return &d, nil return &d, nil

View file

@ -37,6 +37,7 @@ import (
// For now this contains the shared functions // For now this contains the shared functions
type Database struct { type Database struct {
DB *sql.DB DB *sql.DB
Writer sqlutil.Writer
Invites tables.Invites Invites tables.Invites
AccountData tables.AccountData AccountData tables.AccountData
OutputEvents tables.Events OutputEvents tables.Events
@ -45,7 +46,6 @@ type Database struct {
BackwardExtremities tables.BackwardsExtremities BackwardExtremities tables.BackwardsExtremities
SendToDevice tables.SendToDevice SendToDevice tables.SendToDevice
Filter tables.Filter Filter tables.Filter
SendToDeviceWriter sqlutil.Writer
EDUCache *cache.EDUCache EDUCache *cache.EDUCache
} }
@ -129,10 +129,7 @@ func (d *Database) GetStateEvent(
func (d *Database) GetStateEventsForRoom( func (d *Database) GetStateEventsForRoom(
ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter, ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter,
) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) { ) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) {
err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilter)
stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, stateFilter)
return err
})
return return
} }
@ -171,15 +168,23 @@ func (d *Database) SyncStreamPosition(ctx context.Context) (types.StreamPosition
func (d *Database) AddInviteEvent( func (d *Database) AddInviteEvent(
ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent,
) (sp types.StreamPosition, err error) { ) (sp types.StreamPosition, err error) {
return d.Invites.InsertInviteEvent(ctx, nil, inviteEvent) _ = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
sp, err = d.Invites.InsertInviteEvent(ctx, nil, inviteEvent)
return nil
})
return
} }
// RetireInviteEvent removes an old invite event from the database. // RetireInviteEvent removes an old invite event from the database.
// Returns an error if there was a problem communicating with the database. // Returns an error if there was a problem communicating with the database.
func (d *Database) RetireInviteEvent( func (d *Database) RetireInviteEvent(
ctx context.Context, inviteEventID string, ctx context.Context, inviteEventID string,
) (types.StreamPosition, error) { ) (sp types.StreamPosition, err error) {
return d.Invites.DeleteInviteEvent(ctx, inviteEventID) _ = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
sp, err = d.Invites.DeleteInviteEvent(ctx, inviteEventID)
return nil
})
return
} }
// GetAccountDataInRange returns all account data for a given user inserted or // GetAccountDataInRange returns all account data for a given user inserted or
@ -203,7 +208,7 @@ func (d *Database) GetAccountDataInRange(
func (d *Database) UpsertAccountData( func (d *Database) UpsertAccountData(
ctx context.Context, userID, roomID, dataType string, ctx context.Context, userID, roomID, dataType string,
) (sp types.StreamPosition, err error) { ) (sp types.StreamPosition, err error) {
err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType) sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType)
return err return err
}) })
@ -233,6 +238,7 @@ func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.Strea
// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of // handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of
// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table // the events listed in the event's 'prev_events'. This function also updates the backwards extremities table
// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. // to account for the fact that the given event is no longer a backwards extremity, but may be marked as such.
// This function should always be called within a sqlutil.Writer for safety in SQLite.
func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error { func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error {
if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil { if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil {
return err return err
@ -271,7 +277,7 @@ func (d *Database) WriteEvent(
addStateEventIDs, removeStateEventIDs []string, addStateEventIDs, removeStateEventIDs []string,
transactionID *api.TransactionID, excludeFromSync bool, transactionID *api.TransactionID, excludeFromSync bool,
) (pduPosition types.StreamPosition, returnErr error) { ) (pduPosition types.StreamPosition, returnErr error) {
returnErr = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error var err error
pos, err := d.OutputEvents.InsertEvent( pos, err := d.OutputEvents.InsertEvent(
ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync,
@ -300,6 +306,7 @@ func (d *Database) WriteEvent(
return pduPosition, returnErr return pduPosition, returnErr
} }
// This function should always be called within a sqlutil.Writer for safety in SQLite.
func (d *Database) updateRoomState( func (d *Database) updateRoomState(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
removedEventIDs []string, removedEventIDs []string,
@ -1110,7 +1117,7 @@ func (d *Database) StoreNewSendForDeviceMessage(
} }
// Delegate the database write task to the SendToDeviceWriter. It'll guarantee // Delegate the database write task to the SendToDeviceWriter. It'll guarantee
// that we don't lock the table for writes in more than one place. // that we don't lock the table for writes in more than one place.
err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.AddSendToDeviceEvent( return d.AddSendToDeviceEvent(
ctx, txn, userID, deviceID, string(j), ctx, txn, userID, deviceID, string(j),
) )
@ -1175,7 +1182,7 @@ func (d *Database) CleanSendToDeviceUpdates(
// If we need to write to the database then we'll ask the SendToDeviceWriter to // If we need to write to the database then we'll ask the SendToDeviceWriter to
// do that for us. It'll guarantee that we don't lock the table for writes in // do that for us. It'll guarantee that we don't lock the table for writes in
// more than one place. // more than one place.
err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// Delete any send-to-device messages marked for deletion. // Delete any send-to-device messages marked for deletion.
if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil { if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil {
return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e) return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)

View file

@ -20,7 +20,6 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -51,17 +50,15 @@ const selectMaxAccountDataIDSQL = "" +
type accountDataStatements struct { type accountDataStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.Writer
streamIDStatements *streamIDStatements streamIDStatements *streamIDStatements
insertAccountDataStmt *sql.Stmt insertAccountDataStmt *sql.Stmt
selectMaxAccountDataIDStmt *sql.Stmt selectMaxAccountDataIDStmt *sql.Stmt
selectAccountDataInRangeStmt *sql.Stmt selectAccountDataInRangeStmt *sql.Stmt
} }
func NewSqliteAccountDataTable(db *sql.DB, writer sqlutil.Writer, streamID *streamIDStatements) (tables.AccountData, error) { func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
s := &accountDataStatements{ s := &accountDataStatements{
db: db, db: db,
writer: writer,
streamIDStatements: streamID, streamIDStatements: streamID,
} }
_, err := db.Exec(accountDataSchema) _, err := db.Exec(accountDataSchema)
@ -84,15 +81,12 @@ func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
userID, roomID, dataType string, userID, roomID, dataType string,
) (pos types.StreamPosition, err error) { ) (pos types.StreamPosition, err error) {
return pos, s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
var err error
pos, err = s.streamIDStatements.nextStreamID(ctx, txn) pos, err = s.streamIDStatements.nextStreamID(ctx, txn)
if err != nil { if err != nil {
return err return
} }
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType) _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
return err return
})
} }
func (s *accountDataStatements) SelectAccountDataInRange( func (s *accountDataStatements) SelectAccountDataInRange(

View file

@ -19,7 +19,6 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
) )
@ -49,16 +48,14 @@ const deleteBackwardExtremitySQL = "" +
type backwardExtremitiesStatements struct { type backwardExtremitiesStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.Writer
insertBackwardExtremityStmt *sql.Stmt insertBackwardExtremityStmt *sql.Stmt
selectBackwardExtremitiesForRoomStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt
deleteBackwardExtremityStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt
} }
func NewSqliteBackwardsExtremitiesTable(db *sql.DB, writer sqlutil.Writer) (tables.BackwardsExtremities, error) { func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
s := &backwardExtremitiesStatements{ s := &backwardExtremitiesStatements{
db: db, db: db,
writer: writer,
} }
_, err := db.Exec(backwardExtremitiesSchema) _, err := db.Exec(backwardExtremitiesSchema)
if err != nil { if err != nil {
@ -79,10 +76,8 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB, writer sqlutil.Writer) (tabl
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string, ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
) (err error) { ) (err error) {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { _, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
_, err := txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
return err return err
})
} }
func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
@ -110,8 +105,6 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
ctx context.Context, txn *sql.Tx, roomID, knownEventID string, ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
) (err error) { ) (err error) {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { _, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
_, err := txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
return err return err
})
} }

View file

@ -85,7 +85,6 @@ const selectEventsWithEventIDsSQL = "" +
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.Writer
streamIDStatements *streamIDStatements streamIDStatements *streamIDStatements
upsertRoomStateStmt *sql.Stmt upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt
@ -95,10 +94,9 @@ type currentRoomStateStatements struct {
selectStateEventStmt *sql.Stmt selectStateEventStmt *sql.Stmt
} }
func NewSqliteCurrentRoomStateTable(db *sql.DB, writer sqlutil.Writer, streamID *streamIDStatements) (tables.CurrentRoomState, error) { func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
s := &currentRoomStateStatements{ s := &currentRoomStateStatements{
db: db, db: db,
writer: writer,
streamIDStatements: streamID, streamIDStatements: streamID,
} }
_, err := db.Exec(currentRoomStateSchema) _, err := db.Exec(currentRoomStateSchema)
@ -200,11 +198,9 @@ func (s *currentRoomStateStatements) SelectCurrentState(
func (s *currentRoomStateStatements) DeleteRoomStateByEventID( func (s *currentRoomStateStatements) DeleteRoomStateByEventID(
ctx context.Context, txn *sql.Tx, eventID string, ctx context.Context, txn *sql.Tx, eventID string,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
_, err := stmt.ExecContext(ctx, eventID) _, err := stmt.ExecContext(ctx, eventID)
return err return err
})
} }
func (s *currentRoomStateStatements) UpsertRoomState( func (s *currentRoomStateStatements) UpsertRoomState(
@ -225,9 +221,8 @@ func (s *currentRoomStateStatements) UpsertRoomState(
} }
// upsert state event // upsert state event
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
_, err := stmt.ExecContext( _, err = stmt.ExecContext(
ctx, ctx,
event.RoomID(), event.RoomID(),
event.EventID(), event.EventID(),
@ -240,7 +235,6 @@ func (s *currentRoomStateStatements) UpsertRoomState(
addedAt, addedAt,
) )
return err return err
})
} }
func minOfInts(a, b int) int { func minOfInts(a, b int) int {

View file

@ -20,7 +20,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -52,20 +51,18 @@ const insertFilterSQL = "" +
type filterStatements struct { type filterStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.Writer
selectFilterStmt *sql.Stmt selectFilterStmt *sql.Stmt
selectFilterIDByContentStmt *sql.Stmt selectFilterIDByContentStmt *sql.Stmt
insertFilterStmt *sql.Stmt insertFilterStmt *sql.Stmt
} }
func NewSqliteFilterTable(db *sql.DB, writer sqlutil.Writer) (tables.Filter, error) { func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) {
_, err := db.Exec(filterSchema) _, err := db.Exec(filterSchema)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s := &filterStatements{ s := &filterStatements{
db: db, db: db,
writer: writer,
} }
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
return nil, err return nil, err
@ -114,7 +111,6 @@ func (s *filterStatements) InsertFilter(
return "", err return "", err
} }
err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
// Check if filter already exists in the database using its localpart and content // Check if filter already exists in the database using its localpart and content
// //
// This can result in a race condition when two clients try to insert the // This can result in a race condition when two clients try to insert the
@ -123,24 +119,22 @@ func (s *filterStatements) InsertFilter(
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
localpart, filterJSON).Scan(&existingFilterID) localpart, filterJSON).Scan(&existingFilterID)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return err return "", err
} }
// If it does, return the existing ID // If it does, return the existing ID
if existingFilterID != "" { if existingFilterID != "" {
return nil return existingFilterID, nil
} }
// Otherwise insert the filter and return the new ID // Otherwise insert the filter and return the new ID
res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart) res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
if err != nil { if err != nil {
return err return "", err
} }
rowid, err := res.LastInsertId() rowid, err := res.LastInsertId()
if err != nil { if err != nil {
return err return "", err
} }
filterID = fmt.Sprintf("%d", rowid) filterID = fmt.Sprintf("%d", rowid)
return nil
})
return return
} }

View file

@ -59,7 +59,6 @@ const selectMaxInviteIDSQL = "" +
type inviteEventsStatements struct { type inviteEventsStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.Writer
streamIDStatements *streamIDStatements streamIDStatements *streamIDStatements
insertInviteEventStmt *sql.Stmt insertInviteEventStmt *sql.Stmt
selectInviteEventsInRangeStmt *sql.Stmt selectInviteEventsInRangeStmt *sql.Stmt
@ -67,10 +66,9 @@ type inviteEventsStatements struct {
selectMaxInviteIDStmt *sql.Stmt selectMaxInviteIDStmt *sql.Stmt
} }
func NewSqliteInvitesTable(db *sql.DB, writer sqlutil.Writer, streamID *streamIDStatements) (tables.Invites, error) { func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
s := &inviteEventsStatements{ s := &inviteEventsStatements{
db: db, db: db,
writer: writer,
streamIDStatements: streamID, streamIDStatements: streamID,
} }
_, err := db.Exec(inviteEventsSchema) _, err := db.Exec(inviteEventsSchema)
@ -100,14 +98,14 @@ func (s *inviteEventsStatements) InsertInviteEvent(
return return
} }
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
var headeredJSON []byte var headeredJSON []byte
headeredJSON, err = json.Marshal(inviteEvent) headeredJSON, err = json.Marshal(inviteEvent)
if err != nil { if err != nil {
return err return
} }
_, err = txn.Stmt(s.insertInviteEventStmt).ExecContext( stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
_, err = stmt.ExecContext(
ctx, ctx,
streamPos, streamPos,
inviteEvent.RoomID(), inviteEvent.RoomID(),
@ -115,8 +113,6 @@ func (s *inviteEventsStatements) InsertInviteEvent(
*inviteEvent.StateKey(), *inviteEvent.StateKey(),
headeredJSON, headeredJSON,
) )
return err
})
return return
} }
@ -127,10 +123,7 @@ func (s *inviteEventsStatements) DeleteInviteEvent(
if err != nil { if err != nil {
return streamPos, err return streamPos, err
} }
err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
_, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID) _, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID)
return err
})
return streamPos, err return streamPos, err
} }

View file

@ -105,7 +105,6 @@ const selectStateInRangeSQL = "" +
type outputRoomEventsStatements struct { type outputRoomEventsStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.Writer
streamIDStatements *streamIDStatements streamIDStatements *streamIDStatements
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt selectEventsStmt *sql.Stmt
@ -117,10 +116,9 @@ type outputRoomEventsStatements struct {
updateEventJSONStmt *sql.Stmt updateEventJSONStmt *sql.Stmt
} }
func NewSqliteEventsTable(db *sql.DB, writer sqlutil.Writer, streamID *streamIDStatements) (tables.Events, error) { func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
s := &outputRoomEventsStatements{ s := &outputRoomEventsStatements{
db: db, db: db,
writer: writer,
streamIDStatements: streamID, streamIDStatements: streamID,
} }
_, err := db.Exec(outputRoomEventsSchema) _, err := db.Exec(outputRoomEventsSchema)
@ -159,10 +157,8 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event
if err != nil { if err != nil {
return err return err
} }
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
_, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
return err return err
})
} }
// selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos. // selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos.
@ -308,9 +304,8 @@ func (s *outputRoomEventsStatements) InsertEvent(
if err != nil { if err != nil {
return 0, err return 0, err
} }
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
_, ierr := insertStmt.ExecContext( _, err = insertStmt.ExecContext(
ctx, ctx,
streamPos, streamPos,
event.RoomID(), event.RoomID(),
@ -326,8 +321,6 @@ func (s *outputRoomEventsStatements) InsertEvent(
excludeFromSync, excludeFromSync,
excludeFromSync, excludeFromSync,
) )
return ierr
})
return streamPos, err return streamPos, err
} }

View file

@ -67,7 +67,6 @@ const selectMaxPositionInTopologySQL = "" +
type outputRoomEventsTopologyStatements struct { type outputRoomEventsTopologyStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.Writer
insertEventInTopologyStmt *sql.Stmt insertEventInTopologyStmt *sql.Stmt
selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeASCStmt *sql.Stmt
selectEventIDsInRangeDESCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt
@ -75,10 +74,9 @@ type outputRoomEventsTopologyStatements struct {
selectMaxPositionInTopologyStmt *sql.Stmt selectMaxPositionInTopologyStmt *sql.Stmt
} }
func NewSqliteTopologyTable(db *sql.DB, writer sqlutil.Writer) (tables.Topology, error) { func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
s := &outputRoomEventsTopologyStatements{ s := &outputRoomEventsTopologyStatements{
db: db, db: db,
writer: writer,
} }
_, err := db.Exec(outputRoomEventsTopologySchema) _, err := db.Exec(outputRoomEventsTopologySchema)
if err != nil { if err != nil {
@ -107,13 +105,11 @@ func NewSqliteTopologyTable(db *sql.DB, writer sqlutil.Writer) (tables.Topology,
func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( func (s *outputRoomEventsTopologyStatements) InsertEventInTopology(
ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition, ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition,
) (err error) { ) (err error) {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt) stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt)
_, err := stmt.ExecContext( _, err = stmt.ExecContext(
ctx, event.EventID(), event.Depth(), event.RoomID(), pos, ctx, event.EventID(), event.Depth(), event.RoomID(), pos,
) )
return err return
})
} }
func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(

View file

@ -73,16 +73,14 @@ const deleteSendToDeviceMessagesSQL = `
type sendToDeviceStatements struct { type sendToDeviceStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.Writer
insertSendToDeviceMessageStmt *sql.Stmt insertSendToDeviceMessageStmt *sql.Stmt
selectSendToDeviceMessagesStmt *sql.Stmt selectSendToDeviceMessagesStmt *sql.Stmt
countSendToDeviceMessagesStmt *sql.Stmt countSendToDeviceMessagesStmt *sql.Stmt
} }
func NewSqliteSendToDeviceTable(db *sql.DB, writer sqlutil.Writer) (tables.SendToDevice, error) { func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
s := &sendToDeviceStatements{ s := &sendToDeviceStatements{
db: db, db: db,
writer: writer,
} }
_, err := db.Exec(sendToDeviceSchema) _, err := db.Exec(sendToDeviceSchema)
if err != nil { if err != nil {
@ -103,10 +101,8 @@ func NewSqliteSendToDeviceTable(db *sql.DB, writer sqlutil.Writer) (tables.SendT
func (s *sendToDeviceStatements) InsertSendToDeviceMessage( func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
ctx context.Context, txn *sql.Tx, userID, deviceID, content string, ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
) (err error) { ) (err error) {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { _, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
_, err := sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) return
return err
})
} }
func (s *sendToDeviceStatements) CountSendToDeviceMessages( func (s *sendToDeviceStatements) CountSendToDeviceMessages(
@ -163,10 +159,8 @@ func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
for k, v := range nids { for k, v := range nids {
params[k+1] = v params[k+1] = v
} }
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { _, err = txn.ExecContext(ctx, query, params...)
_, err := txn.ExecContext(ctx, query, params...) return
return err
})
} }
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
@ -177,8 +171,6 @@ func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
for k, v := range nids { for k, v := range nids {
params[k] = v params[k] = v
} }
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { _, err = txn.ExecContext(ctx, query, params...)
_, err := txn.ExecContext(ctx, query, params...) return
return err
})
} }

View file

@ -28,14 +28,12 @@ const selectStreamIDStmt = "" +
type streamIDStatements struct { type streamIDStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.Writer
increaseStreamIDStmt *sql.Stmt increaseStreamIDStmt *sql.Stmt
selectStreamIDStmt *sql.Stmt selectStreamIDStmt *sql.Stmt
} }
func (s *streamIDStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
s.db = db s.db = db
s.writer = writer
_, err = db.Exec(streamIDTableSchema) _, err = db.Exec(streamIDTableSchema)
if err != nil { if err != nil {
return return
@ -52,14 +50,9 @@ func (s *streamIDStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err err
func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error { if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil {
if _, ierr := increaseStmt.ExecContext(ctx, "global"); err != nil { return
return ierr
} }
if serr := selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil { err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos)
return serr
}
return nil
})
return return
} }

View file

@ -56,43 +56,44 @@ func (d *SyncServerDatasource) prepare() (err error) {
if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil { if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil {
return err return err
} }
if err = d.streamID.prepare(d.db, d.writer); err != nil { if err = d.streamID.prepare(d.db); err != nil {
return err return err
} }
accountData, err := NewSqliteAccountDataTable(d.db, d.writer, &d.streamID) accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID)
if err != nil { if err != nil {
return err return err
} }
events, err := NewSqliteEventsTable(d.db, d.writer, &d.streamID) events, err := NewSqliteEventsTable(d.db, &d.streamID)
if err != nil { if err != nil {
return err return err
} }
roomState, err := NewSqliteCurrentRoomStateTable(d.db, d.writer, &d.streamID) roomState, err := NewSqliteCurrentRoomStateTable(d.db, &d.streamID)
if err != nil { if err != nil {
return err return err
} }
invites, err := NewSqliteInvitesTable(d.db, d.writer, &d.streamID) invites, err := NewSqliteInvitesTable(d.db, &d.streamID)
if err != nil { if err != nil {
return err return err
} }
topology, err := NewSqliteTopologyTable(d.db, d.writer) topology, err := NewSqliteTopologyTable(d.db)
if err != nil { if err != nil {
return err return err
} }
bwExtrem, err := NewSqliteBackwardsExtremitiesTable(d.db, d.writer) bwExtrem, err := NewSqliteBackwardsExtremitiesTable(d.db)
if err != nil { if err != nil {
return err return err
} }
sendToDevice, err := NewSqliteSendToDeviceTable(d.db, d.writer) sendToDevice, err := NewSqliteSendToDeviceTable(d.db)
if err != nil { if err != nil {
return err return err
} }
filter, err := NewSqliteFilterTable(d.db, d.writer) filter, err := NewSqliteFilterTable(d.db)
if err != nil { if err != nil {
return err return err
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Writer: sqlutil.NewExclusiveWriter(),
Invites: invites, Invites: invites,
AccountData: accountData, AccountData: accountData,
OutputEvents: events, OutputEvents: events,
@ -101,7 +102,6 @@ func (d *SyncServerDatasource) prepare() (err error) {
Topology: topology, Topology: topology,
Filter: filter, Filter: filter,
SendToDevice: sendToDevice, SendToDevice: sendToDevice,
SendToDeviceWriter: sqlutil.NewExclusiveWriter(),
EDUCache: cache.New(), EDUCache: cache.New(),
} }
return nil return nil