Component-wide TransactionWriters (#1290)

* Offset updates take place using TransactionWriter

* Refactor TransactionWriter in current state server

* Refactor TransactionWriter in federation sender

* Refactor TransactionWriter in key server

* Refactor TransactionWriter in media API

* Refactor TransactionWriter in server key API

* Refactor TransactionWriter in sync API

* Refactor TransactionWriter in user API

* Fix deadlocking Sync API tests

* Un-deadlock device database

* Fix appservice API

* Rename TransactionWriters to Writers

* Move writers up a layer in sync API

* Document sqlutil.Writer interface

* Add note to Writer documentation
This commit is contained in:
Neil Alexander 2020-08-21 10:42:08 +01:00 committed by GitHub
parent 5aaf32bbed
commit 9d53351dc2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
56 changed files with 483 additions and 483 deletions

View file

@ -32,6 +32,7 @@ type Database struct {
events eventsStatements events eventsStatements
txnID txnStatements txnID txnStatements
db *sql.DB db *sql.DB
writer sqlutil.Writer
} }
// NewDatabase opens a new database // NewDatabase opens a new database
@ -41,10 +42,11 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
if result.db, err = sqlutil.Open(dbProperties); err != nil { if result.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err return nil, err
} }
result.writer = sqlutil.NewDummyWriter()
if err = result.prepare(); err != nil { if err = result.prepare(); err != nil {
return nil, err return nil, err
} }
if err = result.PartitionOffsetStatements.Prepare(result.db, "appservice"); err != nil { if err = result.PartitionOffsetStatements.Prepare(result.db, result.writer, "appservice"); err != nil {
return nil, err return nil, err
} }
return &result, nil return &result, nil

View file

@ -67,7 +67,7 @@ const (
type eventsStatements struct { type eventsStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter writer sqlutil.Writer
selectEventsByApplicationServiceIDStmt *sql.Stmt selectEventsByApplicationServiceIDStmt *sql.Stmt
countEventsByApplicationServiceIDStmt *sql.Stmt countEventsByApplicationServiceIDStmt *sql.Stmt
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
@ -75,9 +75,9 @@ type eventsStatements struct {
deleteEventsBeforeAndIncludingIDStmt *sql.Stmt deleteEventsBeforeAndIncludingIDStmt *sql.Stmt
} }
func (s *eventsStatements) prepare(db *sql.DB) (err error) { func (s *eventsStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db s.db = db
s.writer = sqlutil.NewTransactionWriter() s.writer = writer
_, err = db.Exec(appserviceEventsSchema) _, err = db.Exec(appserviceEventsSchema)
if err != nil { if err != nil {
return return

View file

@ -32,6 +32,7 @@ type Database struct {
events eventsStatements events eventsStatements
txnID txnStatements txnID txnStatements
db *sql.DB db *sql.DB
writer sqlutil.Writer
} }
// NewDatabase opens a new database // NewDatabase opens a new database
@ -41,21 +42,22 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
if result.db, err = sqlutil.Open(dbProperties); err != nil { if result.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err return nil, err
} }
result.writer = sqlutil.NewExclusiveWriter()
if err = result.prepare(); err != nil { if err = result.prepare(); err != nil {
return nil, err return nil, err
} }
if err = result.PartitionOffsetStatements.Prepare(result.db, "appservice"); err != nil { if err = result.PartitionOffsetStatements.Prepare(result.db, result.writer, "appservice"); err != nil {
return nil, err return nil, err
} }
return &result, nil return &result, nil
} }
func (d *Database) prepare() error { func (d *Database) prepare() error {
if err := d.events.prepare(d.db); err != nil { if err := d.events.prepare(d.db, d.writer); err != nil {
return err return err
} }
return d.txnID.prepare(d.db) return d.txnID.prepare(d.db, d.writer)
} }
// StoreEvent takes in a gomatrixserverlib.HeaderedEvent and stores it in the database // StoreEvent takes in a gomatrixserverlib.HeaderedEvent and stores it in the database

View file

@ -38,13 +38,13 @@ const selectTxnIDSQL = `
type txnStatements struct { type txnStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter writer sqlutil.Writer
selectTxnIDStmt *sql.Stmt selectTxnIDStmt *sql.Stmt
} }
func (s *txnStatements) prepare(db *sql.DB) (err error) { func (s *txnStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db s.db = db
s.writer = sqlutil.NewTransactionWriter() s.writer = writer
_, err = db.Exec(txnIDSchema) _, err = db.Exec(txnIDSchema)
if err != nil { if err != nil {
return return

View file

@ -11,6 +11,7 @@ import (
type Database struct { type Database struct {
shared.Database shared.Database
db *sql.DB db *sql.DB
writer sqlutil.Writer
sqlutil.PartitionOffsetStatements sqlutil.PartitionOffsetStatements
} }
@ -21,7 +22,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
if d.db, err = sqlutil.Open(dbProperties); err != nil { if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err return nil, err
} }
if err = d.PartitionOffsetStatements.Prepare(d.db, "currentstate"); err != nil { d.writer = sqlutil.NewDummyWriter()
if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "currentstate"); err != nil {
return nil, err return nil, err
} }
currRoomState, err := NewPostgresCurrentRoomStateTable(d.db) currRoomState, err := NewPostgresCurrentRoomStateTable(d.db)
@ -30,6 +32,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Writer: d.writer,
CurrentRoomState: currRoomState, CurrentRoomState: currRoomState,
} }
return &d, nil return &d, nil

View file

@ -27,6 +27,7 @@ import (
type Database struct { type Database struct {
DB *sql.DB DB *sql.DB
Writer sqlutil.Writer
CurrentRoomState tables.CurrentRoomState CurrentRoomState tables.CurrentRoomState
} }
@ -59,7 +60,7 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda
func (d *Database) StoreStateEvents(ctx context.Context, addStateEvents []gomatrixserverlib.HeaderedEvent, func (d *Database) StoreStateEvents(ctx context.Context, addStateEvents []gomatrixserverlib.HeaderedEvent,
removeStateEventIDs []string) error { removeStateEventIDs []string) error {
return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
for _, eventID := range removeStateEventIDs { for _, eventID := range removeStateEventIDs {
if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil { if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil {

View file

@ -83,7 +83,7 @@ const selectKnownUsersSQL = "" +
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter writer sqlutil.Writer
upsertRoomStateStmt *sql.Stmt upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt
selectRoomIDsWithMembershipStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt
@ -96,7 +96,7 @@ type currentRoomStateStatements struct {
func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
s := &currentRoomStateStatements{ s := &currentRoomStateStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(), writer: sqlutil.NewExclusiveWriter(),
} }
_, err := db.Exec(currentRoomStateSchema) _, err := db.Exec(currentRoomStateSchema)
if err != nil { if err != nil {

View file

@ -11,6 +11,7 @@ import (
type Database struct { type Database struct {
shared.Database shared.Database
db *sql.DB db *sql.DB
writer sqlutil.Writer
sqlutil.PartitionOffsetStatements sqlutil.PartitionOffsetStatements
} }
@ -22,7 +23,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
if d.db, err = sqlutil.Open(dbProperties); err != nil { if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err return nil, err
} }
if err = d.PartitionOffsetStatements.Prepare(d.db, "currentstate"); err != nil { d.writer = sqlutil.NewExclusiveWriter()
if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "currentstate"); err != nil {
return nil, err return nil, err
} }
currRoomState, err := NewSqliteCurrentRoomStateTable(d.db) currRoomState, err := NewSqliteCurrentRoomStateTable(d.db)
@ -31,6 +33,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Writer: d.writer,
CurrentRoomState: currRoomState, CurrentRoomState: currRoomState,
} }
return &d, nil return &d, nil

View file

@ -28,6 +28,7 @@ type Database struct {
shared.Database shared.Database
sqlutil.PartitionOffsetStatements sqlutil.PartitionOffsetStatements
db *sql.DB db *sql.DB
writer sqlutil.Writer
} }
// NewDatabase opens a new database // NewDatabase opens a new database
@ -37,6 +38,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
if d.db, err = sqlutil.Open(dbProperties); err != nil { if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err return nil, err
} }
d.writer = sqlutil.NewDummyWriter()
joinedHosts, err := NewPostgresJoinedHostsTable(d.db) joinedHosts, err := NewPostgresJoinedHostsTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
@ -63,6 +65,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Writer: d.writer,
FederationSenderJoinedHosts: joinedHosts, FederationSenderJoinedHosts: joinedHosts,
FederationSenderQueuePDUs: queuePDUs, FederationSenderQueuePDUs: queuePDUs,
FederationSenderQueueEDUs: queueEDUs, FederationSenderQueueEDUs: queueEDUs,
@ -70,7 +73,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
FederationSenderRooms: rooms, FederationSenderRooms: rooms,
FederationSenderBlacklist: blacklist, FederationSenderBlacklist: blacklist,
} }
if err = d.PartitionOffsetStatements.Prepare(d.db, "federationsender"); err != nil { if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "federationsender"); err != nil {
return nil, err return nil, err
} }
return &d, nil return &d, nil

View file

@ -28,6 +28,7 @@ import (
type Database struct { type Database struct {
DB *sql.DB DB *sql.DB
Writer sqlutil.Writer
FederationSenderQueuePDUs tables.FederationSenderQueuePDUs FederationSenderQueuePDUs tables.FederationSenderQueuePDUs
FederationSenderQueueEDUs tables.FederationSenderQueueEDUs FederationSenderQueueEDUs tables.FederationSenderQueueEDUs
FederationSenderQueueJSON tables.FederationSenderQueueJSON FederationSenderQueueJSON tables.FederationSenderQueueJSON
@ -64,7 +65,7 @@ func (d *Database) UpdateRoom(
addHosts []types.JoinedHost, addHosts []types.JoinedHost,
removeHosts []string, removeHosts []string,
) (joinedHosts []types.JoinedHost, err error) { ) (joinedHosts []types.JoinedHost, err error) {
err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err = d.FederationSenderRooms.InsertRoom(ctx, txn, roomID) err = d.FederationSenderRooms.InsertRoom(ctx, txn, roomID)
if err != nil { if err != nil {
return err return err
@ -133,7 +134,12 @@ func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string)
func (d *Database) StoreJSON( func (d *Database) StoreJSON(
ctx context.Context, js string, ctx context.Context, js string,
) (*Receipt, error) { ) (*Receipt, error) {
nid, err := d.FederationSenderQueueJSON.InsertQueueJSON(ctx, nil, js) var nid int64
var err error
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
nid, err = d.FederationSenderQueueJSON.InsertQueueJSON(ctx, txn, js)
return nil
})
if err != nil { if err != nil {
return nil, fmt.Errorf("d.insertQueueJSON: %w", err) return nil, fmt.Errorf("d.insertQueueJSON: %w", err)
} }
@ -143,11 +149,15 @@ func (d *Database) StoreJSON(
} }
func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
return d.FederationSenderBlacklist.InsertBlacklist(context.TODO(), nil, serverName) return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationSenderBlacklist.InsertBlacklist(context.TODO(), txn, serverName)
})
} }
func (d *Database) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error { func (d *Database) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error {
return d.FederationSenderBlacklist.DeleteBlacklist(context.TODO(), nil, serverName) return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationSenderBlacklist.DeleteBlacklist(context.TODO(), txn, serverName)
})
} }
func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) { func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) {

View file

@ -42,7 +42,6 @@ const deleteBlacklistSQL = "" +
type blacklistStatements struct { type blacklistStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter
insertBlacklistStmt *sql.Stmt insertBlacklistStmt *sql.Stmt
selectBlacklistStmt *sql.Stmt selectBlacklistStmt *sql.Stmt
deleteBlacklistStmt *sql.Stmt deleteBlacklistStmt *sql.Stmt
@ -51,7 +50,6 @@ type blacklistStatements struct {
func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) { func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
s = &blacklistStatements{ s = &blacklistStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
} }
_, err = db.Exec(blacklistSchema) _, err = db.Exec(blacklistSchema)
if err != nil { if err != nil {
@ -75,11 +73,9 @@ func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
func (s *blacklistStatements) InsertBlacklist( func (s *blacklistStatements) InsertBlacklist(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt) stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt)
_, err := stmt.ExecContext(ctx, serverName) _, err := stmt.ExecContext(ctx, serverName)
return err return err
})
} }
// selectRoomForUpdate locks the row for the room and returns the last_event_id. // selectRoomForUpdate locks the row for the room and returns the last_event_id.
@ -105,9 +101,7 @@ func (s *blacklistStatements) SelectBlacklist(
func (s *blacklistStatements) DeleteBlacklist( func (s *blacklistStatements) DeleteBlacklist(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt) stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt)
_, err := stmt.ExecContext(ctx, serverName) _, err := stmt.ExecContext(ctx, serverName)
return err return err
})
} }

View file

@ -65,7 +65,6 @@ const selectJoinedHostsForRoomsSQL = "" +
type joinedHostsStatements struct { type joinedHostsStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter
insertJoinedHostsStmt *sql.Stmt insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt deleteJoinedHostsStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt
@ -76,7 +75,6 @@ type joinedHostsStatements struct {
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
s = &joinedHostsStatements{ s = &joinedHostsStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
} }
_, err = db.Exec(joinedHostsSchema) _, err = db.Exec(joinedHostsSchema)
if err != nil { if err != nil {
@ -103,17 +101,14 @@ func (s *joinedHostsStatements) InsertJoinedHosts(
roomID, eventID string, roomID, eventID string,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt)
_, err := stmt.ExecContext(ctx, roomID, eventID, serverName) _, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
return err return err
})
} }
func (s *joinedHostsStatements) DeleteJoinedHosts( func (s *joinedHostsStatements) DeleteJoinedHosts(
ctx context.Context, txn *sql.Tx, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
for _, eventID := range eventIDs { for _, eventID := range eventIDs {
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt) stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt)
if _, err := stmt.ExecContext(ctx, eventID); err != nil { if _, err := stmt.ExecContext(ctx, eventID); err != nil {
@ -121,7 +116,6 @@ func (s *joinedHostsStatements) DeleteJoinedHosts(
} }
} }
return nil return nil
})
} }
func (s *joinedHostsStatements) SelectJoinedHostsWithTx( func (s *joinedHostsStatements) SelectJoinedHostsWithTx(

View file

@ -64,7 +64,6 @@ const selectQueueServerNamesSQL = "" +
type queueEDUsStatements struct { type queueEDUsStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter
insertQueueEDUStmt *sql.Stmt insertQueueEDUStmt *sql.Stmt
selectQueueEDUStmt *sql.Stmt selectQueueEDUStmt *sql.Stmt
selectQueueEDUReferenceJSONCountStmt *sql.Stmt selectQueueEDUReferenceJSONCountStmt *sql.Stmt
@ -75,7 +74,6 @@ type queueEDUsStatements struct {
func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) { func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
s = &queueEDUsStatements{ s = &queueEDUsStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
} }
_, err = db.Exec(queueEDUsSchema) _, err = db.Exec(queueEDUsSchema)
if err != nil { if err != nil {
@ -106,7 +104,6 @@ func (s *queueEDUsStatements) InsertQueueEDU(
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
nid int64, nid int64,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
_, err := stmt.ExecContext( _, err := stmt.ExecContext(
ctx, ctx,
@ -115,7 +112,6 @@ func (s *queueEDUsStatements) InsertQueueEDU(
nid, // JSON blob NID nid, // JSON blob NID
) )
return err return err
})
} }
func (s *queueEDUsStatements) DeleteQueueEDUs( func (s *queueEDUsStatements) DeleteQueueEDUs(
@ -135,11 +131,9 @@ func (s *queueEDUsStatements) DeleteQueueEDUs(
params[k+1] = v params[k+1] = v
} }
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, deleteStmt) stmt := sqlutil.TxStmt(txn, deleteStmt)
_, err := stmt.ExecContext(ctx, params...) _, err = stmt.ExecContext(ctx, params...)
return err return err
})
} }
func (s *queueEDUsStatements) SelectQueueEDUs( func (s *queueEDUsStatements) SelectQueueEDUs(

View file

@ -50,7 +50,6 @@ const selectJSONSQL = "" +
type queueJSONStatements struct { type queueJSONStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter
insertJSONStmt *sql.Stmt insertJSONStmt *sql.Stmt
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic //deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic //selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
@ -59,7 +58,6 @@ type queueJSONStatements struct {
func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) { func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) {
s = &queueJSONStatements{ s = &queueJSONStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
} }
_, err = db.Exec(queueJSONSchema) _, err = db.Exec(queueJSONSchema)
if err != nil { if err != nil {
@ -74,18 +72,15 @@ func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) {
func (s *queueJSONStatements) InsertQueueJSON( func (s *queueJSONStatements) InsertQueueJSON(
ctx context.Context, txn *sql.Tx, json string, ctx context.Context, txn *sql.Tx, json string,
) (lastid int64, err error) { ) (lastid int64, err error) {
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
res, err := stmt.ExecContext(ctx, json) res, err := stmt.ExecContext(ctx, json)
if err != nil { if err != nil {
return fmt.Errorf("stmt.QueryContext: %w", err) return 0, fmt.Errorf("stmt.QueryContext: %w", err)
} }
lastid, err = res.LastInsertId() lastid, err = res.LastInsertId()
if err != nil { if err != nil {
return fmt.Errorf("res.LastInsertId: %w", err) return 0, fmt.Errorf("res.LastInsertId: %w", err)
} }
return nil
})
return return
} }
@ -103,11 +98,9 @@ func (s *queueJSONStatements) DeleteQueueJSON(
iNIDs[k] = v iNIDs[k] = v
} }
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, deleteStmt) stmt := sqlutil.TxStmt(txn, deleteStmt)
_, err = stmt.ExecContext(ctx, iNIDs...) _, err = stmt.ExecContext(ctx, iNIDs...)
return err return err
})
} }
func (s *queueJSONStatements) SelectQueueJSON( func (s *queueJSONStatements) SelectQueueJSON(

View file

@ -71,7 +71,6 @@ const selectQueuePDUsServerNamesSQL = "" +
type queuePDUsStatements struct { type queuePDUsStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter
insertQueuePDUStmt *sql.Stmt insertQueuePDUStmt *sql.Stmt
selectQueueNextTransactionIDStmt *sql.Stmt selectQueueNextTransactionIDStmt *sql.Stmt
selectQueuePDUsByTransactionStmt *sql.Stmt selectQueuePDUsByTransactionStmt *sql.Stmt
@ -84,7 +83,6 @@ type queuePDUsStatements struct {
func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) { func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
s = &queuePDUsStatements{ s = &queuePDUsStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
} }
_, err = db.Exec(queuePDUsSchema) _, err = db.Exec(queuePDUsSchema)
if err != nil { if err != nil {
@ -121,7 +119,6 @@ func (s *queuePDUsStatements) InsertQueuePDU(
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
nid int64, nid int64,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt) stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt)
_, err := stmt.ExecContext( _, err := stmt.ExecContext(
ctx, ctx,
@ -130,7 +127,6 @@ func (s *queuePDUsStatements) InsertQueuePDU(
nid, // JSON blob NID nid, // JSON blob NID
) )
return err return err
})
} }
func (s *queuePDUsStatements) DeleteQueuePDUs( func (s *queuePDUsStatements) DeleteQueuePDUs(
@ -150,11 +146,9 @@ func (s *queuePDUsStatements) DeleteQueuePDUs(
params[k+1] = v params[k+1] = v
} }
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, deleteStmt) stmt := sqlutil.TxStmt(txn, deleteStmt)
_, err := stmt.ExecContext(ctx, params...) _, err = stmt.ExecContext(ctx, params...)
return err return err
})
} }
func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID( func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID(

View file

@ -44,7 +44,6 @@ const updateRoomSQL = "" +
type roomStatements struct { type roomStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter
insertRoomStmt *sql.Stmt insertRoomStmt *sql.Stmt
selectRoomForUpdateStmt *sql.Stmt selectRoomForUpdateStmt *sql.Stmt
updateRoomStmt *sql.Stmt updateRoomStmt *sql.Stmt
@ -53,7 +52,6 @@ type roomStatements struct {
func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) { func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) {
s = &roomStatements{ s = &roomStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
} }
_, err = db.Exec(roomSchema) _, err = db.Exec(roomSchema)
if err != nil { if err != nil {
@ -77,10 +75,8 @@ func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) {
func (s *roomStatements) InsertRoom( func (s *roomStatements) InsertRoom(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
_, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID) _, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID)
return err return err
})
} }
// selectRoomForUpdate locks the row for the room and returns the last_event_id. // selectRoomForUpdate locks the row for the room and returns the last_event_id.
@ -103,9 +99,7 @@ func (s *roomStatements) SelectRoomForUpdate(
func (s *roomStatements) UpdateRoom( func (s *roomStatements) UpdateRoom(
ctx context.Context, txn *sql.Tx, roomID, lastEventID string, ctx context.Context, txn *sql.Tx, roomID, lastEventID string,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.updateRoomStmt) stmt := sqlutil.TxStmt(txn, s.updateRoomStmt)
_, err := stmt.ExecContext(ctx, roomID, lastEventID) _, err := stmt.ExecContext(ctx, roomID, lastEventID)
return err return err
})
} }

View file

@ -30,6 +30,7 @@ type Database struct {
shared.Database shared.Database
sqlutil.PartitionOffsetStatements sqlutil.PartitionOffsetStatements
db *sql.DB db *sql.DB
writer sqlutil.Writer
} }
// NewDatabase opens a new database // NewDatabase opens a new database
@ -39,6 +40,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
if d.db, err = sqlutil.Open(dbProperties); err != nil { if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err return nil, err
} }
d.writer = sqlutil.NewExclusiveWriter()
joinedHosts, err := NewSQLiteJoinedHostsTable(d.db) joinedHosts, err := NewSQLiteJoinedHostsTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
@ -65,6 +67,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Writer: d.writer,
FederationSenderJoinedHosts: joinedHosts, FederationSenderJoinedHosts: joinedHosts,
FederationSenderQueuePDUs: queuePDUs, FederationSenderQueuePDUs: queuePDUs,
FederationSenderQueueEDUs: queueEDUs, FederationSenderQueueEDUs: queueEDUs,
@ -72,7 +75,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
FederationSenderRooms: rooms, FederationSenderRooms: rooms,
FederationSenderBlacklist: blacklist, FederationSenderBlacklist: blacklist,
} }
if err = d.PartitionOffsetStatements.Prepare(d.db, "federationsender"); err != nil { if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "federationsender"); err != nil {
return nil, err return nil, err
} }
return &d, nil return &d, nil

View file

@ -53,6 +53,8 @@ const upsertPartitionOffsetsSQL = "" +
// PartitionOffsetStatements represents a set of statements that can be run on a partition_offsets table. // PartitionOffsetStatements represents a set of statements that can be run on a partition_offsets table.
type PartitionOffsetStatements struct { type PartitionOffsetStatements struct {
db *sql.DB
writer Writer
selectPartitionOffsetsStmt *sql.Stmt selectPartitionOffsetsStmt *sql.Stmt
upsertPartitionOffsetStmt *sql.Stmt upsertPartitionOffsetStmt *sql.Stmt
} }
@ -60,7 +62,9 @@ type PartitionOffsetStatements struct {
// Prepare converts the raw SQL statements into prepared statements. // Prepare converts the raw SQL statements into prepared statements.
// Takes a prefix to prepend to the table name used to store the partition offsets. // Takes a prefix to prepend to the table name used to store the partition offsets.
// This allows multiple components to share the same database schema. // This allows multiple components to share the same database schema.
func (s *PartitionOffsetStatements) Prepare(db *sql.DB, prefix string) (err error) { func (s *PartitionOffsetStatements) Prepare(db *sql.DB, writer Writer, prefix string) (err error) {
s.db = db
s.writer = writer
_, err = db.Exec(strings.Replace(partitionOffsetsSchema, "${prefix}", prefix, -1)) _, err = db.Exec(strings.Replace(partitionOffsetsSchema, "${prefix}", prefix, -1))
if err != nil { if err != nil {
return return
@ -121,6 +125,9 @@ func (s *PartitionOffsetStatements) selectPartitionOffsets(
func (s *PartitionOffsetStatements) upsertPartitionOffset( func (s *PartitionOffsetStatements) upsertPartitionOffset(
ctx context.Context, topic string, partition int32, offset int64, ctx context.Context, topic string, partition int32, offset int64,
) error { ) error {
_, err := s.upsertPartitionOffsetStmt.ExecContext(ctx, topic, partition, offset) return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
stmt := TxStmt(txn, s.upsertPartitionOffsetStmt)
_, err := stmt.ExecContext(ctx, topic, partition, offset)
return err return err
})
} }

View file

@ -103,7 +103,3 @@ func SQLiteDriverName() string {
} }
return "sqlite3" return "sqlite3"
} }
type TransactionWriter interface {
Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error
}

View file

@ -0,0 +1,46 @@
package sqlutil
import "database/sql"
// The Writer interface is designed to solve the problem of how
// to handle database writes for database engines that don't allow
// concurrent writes, e.g. SQLite.
//
// The interface has a single Do function which takes an optional
// database parameter, an optional transaction parameter and a
// required function parameter. The Writer will call the function
// provided when it is safe to do so, optionally providing a
// transaction to use.
//
// Depending on the combination of parameters provided, the Writer
// will behave in one of three ways:
//
// 1. `db` provided, `txn` provided:
//
// The Writer will call f() when it is safe to do so. The supplied
// "txn" will ALWAYS be passed through to f(). Use this when you
// already have a transaction open.
//
// 2. `db` provided, `txn` not provided (nil):
//
// The Writer will open a new transaction on the provided database
// and then will call f() when it is safe to do so. The new
// transaction will ALWAYS be passed through to f(). Use this if
// you plan to perform more than one SQL query within f().
//
// 3. `db` not provided (nil), `txn` not provided (nil):
//
// The Writer will call f() when it is safe to do so, but will
// not make any attempt to open a new database transaction or to
// pass through an existing one. The "txn" parameter within f()
// will ALWAYS be nil in this mode. This is useful if you just
// want to perform a single query on an already-prepared statement
// without the overhead of opening a new transaction to do it in.
//
// You MUST take particular care not to call Do() from within f()
// on the same Writer, or it will likely result in a deadlock.
type Writer interface {
// Queue up one or more database write operations within the
// provided function to be executed when it is safe to do so.
Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error
}

View file

@ -4,15 +4,21 @@ import (
"database/sql" "database/sql"
) )
type DummyTransactionWriter struct { // DummyWriter implements sqlutil.Writer.
// The DummyWriter is designed to allow reuse of the sqlutil.Writer
// interface but, unlike ExclusiveWriter, it will not guarantee
// writer exclusivity. This is fine in PostgreSQL where overlapping
// transactions and writes are acceptable.
type DummyWriter struct {
} }
func NewDummyTransactionWriter() TransactionWriter { // NewDummyWriter returns a new dummy writer.
return &DummyTransactionWriter{} func NewDummyWriter() Writer {
return &DummyWriter{}
} }
func (w *DummyTransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error { func (w *DummyWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
if txn == nil { if db != nil && txn == nil {
return WithTransaction(db, func(txn *sql.Tx) error { return WithTransaction(db, func(txn *sql.Tx) error {
return f(txn) return f(txn)
}) })

View file

@ -7,16 +7,17 @@ import (
"go.uber.org/atomic" "go.uber.org/atomic"
) )
// ExclusiveTransactionWriter allows queuing database writes so that you don't // ExclusiveWriter implements sqlutil.Writer.
// ExclusiveWriter allows queuing database writes so that you don't
// contend on database locks in, e.g. SQLite. Only one task will run // contend on database locks in, e.g. SQLite. Only one task will run
// at a time on a given ExclusiveTransactionWriter. // at a time on a given ExclusiveWriter.
type ExclusiveTransactionWriter struct { type ExclusiveWriter struct {
running atomic.Bool running atomic.Bool
todo chan transactionWriterTask todo chan transactionWriterTask
} }
func NewTransactionWriter() TransactionWriter { func NewExclusiveWriter() Writer {
return &ExclusiveTransactionWriter{ return &ExclusiveWriter{
todo: make(chan transactionWriterTask), todo: make(chan transactionWriterTask),
} }
} }
@ -34,7 +35,7 @@ type transactionWriterTask struct {
// txn parameter if one is supplied, and if not, will take out a // txn parameter if one is supplied, and if not, will take out a
// new transaction from the database supplied in the database // new transaction from the database supplied in the database
// parameter. Either way, this will block until the task is done. // parameter. Either way, this will block until the task is done.
func (w *ExclusiveTransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error { func (w *ExclusiveWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
if w.todo == nil { if w.todo == nil {
return errors.New("not initialised") return errors.New("not initialised")
} }
@ -55,20 +56,20 @@ func (w *ExclusiveTransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql
// of these goroutines will run at a time. A transaction will be // of these goroutines will run at a time. A transaction will be
// opened using the database object from the task and then this will // opened using the database object from the task and then this will
// be passed as a parameter to the task function. // be passed as a parameter to the task function.
func (w *ExclusiveTransactionWriter) run() { func (w *ExclusiveWriter) run() {
if !w.running.CAS(false, true) { if !w.running.CAS(false, true) {
return return
} }
defer w.running.Store(false) defer w.running.Store(false)
for task := range w.todo { for task := range w.todo {
if task.txn != nil { if task.db != nil && task.txn != nil {
task.wait <- task.f(task.txn) task.wait <- task.f(task.txn)
} else if task.db != nil { } else if task.db != nil && task.txn == nil {
task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error { task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error {
return task.f(txn) return task.f(txn)
}) })
} else { } else {
panic("expected database or transaction but got neither") task.wait <- task.f(nil)
} }
close(task.wait) close(task.wait)
} }

View file

@ -63,7 +63,7 @@ const deleteAllDeviceKeysSQL = "" +
type deviceKeysStatements struct { type deviceKeysStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter writer sqlutil.Writer
upsertDeviceKeysStmt *sql.Stmt upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt
@ -71,10 +71,10 @@ type deviceKeysStatements struct {
deleteAllDeviceKeysStmt *sql.Stmt deleteAllDeviceKeysStmt *sql.Stmt
} }
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { func NewSqliteDeviceKeysTable(db *sql.DB, writer sqlutil.Writer) (tables.DeviceKeys, error) {
s := &deviceKeysStatements{ s := &deviceKeysStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(), writer: writer,
} }
_, err := db.Exec(deviceKeysSchema) _, err := db.Exec(deviceKeysSchema)
if err != nil { if err != nil {

View file

@ -52,15 +52,15 @@ const selectKeyChangesSQL = "" +
type keyChangesStatements struct { type keyChangesStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter writer sqlutil.Writer
upsertKeyChangeStmt *sql.Stmt upsertKeyChangeStmt *sql.Stmt
selectKeyChangesStmt *sql.Stmt selectKeyChangesStmt *sql.Stmt
} }
func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { func NewSqliteKeyChangesTable(db *sql.DB, writer sqlutil.Writer) (tables.KeyChanges, error) {
s := &keyChangesStatements{ s := &keyChangesStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(), writer: writer,
} }
_, err := db.Exec(keyChangesSchema) _, err := db.Exec(keyChangesSchema)
if err != nil { if err != nil {

View file

@ -60,7 +60,7 @@ const selectKeyByAlgorithmSQL = "" +
type oneTimeKeysStatements struct { type oneTimeKeysStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter writer sqlutil.Writer
upsertKeysStmt *sql.Stmt upsertKeysStmt *sql.Stmt
selectKeysStmt *sql.Stmt selectKeysStmt *sql.Stmt
selectKeysCountStmt *sql.Stmt selectKeysCountStmt *sql.Stmt
@ -68,10 +68,10 @@ type oneTimeKeysStatements struct {
deleteOneTimeKeyStmt *sql.Stmt deleteOneTimeKeyStmt *sql.Stmt
} }
func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { func NewSqliteOneTimeKeysTable(db *sql.DB, writer sqlutil.Writer) (tables.OneTimeKeys, error) {
s := &oneTimeKeysStatements{ s := &oneTimeKeysStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(), writer: writer,
} }
_, err := db.Exec(oneTimeKeysSchema) _, err := db.Exec(oneTimeKeysSchema)
if err != nil { if err != nil {

View file

@ -20,6 +20,7 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -49,13 +50,18 @@ const selectStaleDeviceListsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
type staleDeviceListsStatements struct { type staleDeviceListsStatements struct {
db *sql.DB
writer sqlutil.Writer
upsertStaleDeviceListStmt *sql.Stmt upsertStaleDeviceListStmt *sql.Stmt
selectStaleDeviceListsWithDomainsStmt *sql.Stmt selectStaleDeviceListsWithDomainsStmt *sql.Stmt
selectStaleDeviceListsStmt *sql.Stmt selectStaleDeviceListsStmt *sql.Stmt
} }
func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { func NewSqliteStaleDeviceListsTable(db *sql.DB, writer sqlutil.Writer) (tables.StaleDeviceLists, error) {
s := &staleDeviceListsStatements{} s := &staleDeviceListsStatements{
db: db,
writer: writer,
}
_, err := db.Exec(staleDeviceListsSchema) _, err := db.Exec(staleDeviceListsSchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -77,8 +83,11 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context,
if err != nil { if err != nil {
return err return err
} }
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.upsertStaleDeviceListStmt)
_, err = stmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix())
return err return err
})
} }
func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {

View file

@ -25,19 +25,20 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
otk, err := NewSqliteOneTimeKeysTable(db) writer := sqlutil.NewExclusiveWriter()
otk, err := NewSqliteOneTimeKeysTable(db, writer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
dk, err := NewSqliteDeviceKeysTable(db) dk, err := NewSqliteDeviceKeysTable(db, writer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
kc, err := NewSqliteKeyChangesTable(db) kc, err := NewSqliteKeyChangesTable(db, writer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
sdl, err := NewSqliteStaleDeviceListsTable(db) sdl, err := NewSqliteStaleDeviceListsTable(db, writer)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -62,14 +62,14 @@ SELECT content_type, file_size_bytes, creation_ts, upload_name, base64hash, user
type mediaStatements struct { type mediaStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter writer sqlutil.Writer
insertMediaStmt *sql.Stmt insertMediaStmt *sql.Stmt
selectMediaStmt *sql.Stmt selectMediaStmt *sql.Stmt
} }
func (s *mediaStatements) prepare(db *sql.DB) (err error) { func (s *mediaStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db s.db = db
s.writer = sqlutil.NewTransactionWriter() s.writer = writer
_, err = db.Exec(mediaSchema) _, err = db.Exec(mediaSchema)
if err != nil { if err != nil {

View file

@ -17,6 +17,8 @@ package sqlite3
import ( import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
type statements struct { type statements struct {
@ -24,11 +26,11 @@ type statements struct {
thumbnail thumbnailStatements thumbnail thumbnailStatements
} }
func (s *statements) prepare(db *sql.DB) (err error) { func (s *statements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
if err = s.media.prepare(db); err != nil { if err = s.media.prepare(db, writer); err != nil {
return return
} }
if err = s.thumbnail.prepare(db); err != nil { if err = s.thumbnail.prepare(db, writer); err != nil {
return return
} }

View file

@ -31,16 +31,19 @@ import (
type Database struct { type Database struct {
statements statements statements statements
db *sql.DB db *sql.DB
writer sqlutil.Writer
} }
// Open opens a postgres database. // Open opens a postgres database.
func Open(dbProperties *config.DatabaseOptions) (*Database, error) { func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
var d Database d := Database{
writer: sqlutil.NewExclusiveWriter(),
}
var err error var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil { if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err return nil, err
} }
if err = d.statements.prepare(d.db); err != nil { if err = d.statements.prepare(d.db, d.writer); err != nil {
return nil, err return nil, err
} }
return &d, nil return &d, nil

View file

@ -21,6 +21,7 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -57,16 +58,20 @@ SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method
` `
type thumbnailStatements struct { type thumbnailStatements struct {
db *sql.DB
writer sqlutil.Writer
insertThumbnailStmt *sql.Stmt insertThumbnailStmt *sql.Stmt
selectThumbnailStmt *sql.Stmt selectThumbnailStmt *sql.Stmt
selectThumbnailsStmt *sql.Stmt selectThumbnailsStmt *sql.Stmt
} }
func (s *thumbnailStatements) prepare(db *sql.DB) (err error) { func (s *thumbnailStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
_, err = db.Exec(thumbnailSchema) _, err = db.Exec(thumbnailSchema)
if err != nil { if err != nil {
return return
} }
s.db = db
s.writer = writer
return statementList{ return statementList{
{&s.insertThumbnailStmt, insertThumbnailSQL}, {&s.insertThumbnailStmt, insertThumbnailSQL},
@ -79,7 +84,9 @@ func (s *thumbnailStatements) insertThumbnail(
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata, ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
) error { ) error {
thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000) thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
_, err := s.insertThumbnailStmt.ExecContext( return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertThumbnailStmt)
_, err := stmt.ExecContext(
ctx, ctx,
thumbnailMetadata.MediaMetadata.MediaID, thumbnailMetadata.MediaMetadata.MediaID,
thumbnailMetadata.MediaMetadata.Origin, thumbnailMetadata.MediaMetadata.Origin,
@ -91,6 +98,7 @@ func (s *thumbnailStatements) insertThumbnail(
thumbnailMetadata.ThumbnailSize.ResizeMethod, thumbnailMetadata.ThumbnailSize.ResizeMethod,
) )
return err return err
})
} }
func (s *thumbnailStatements) selectThumbnail( func (s *thumbnailStatements) selectThumbnail(

View file

@ -98,7 +98,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: db, DB: db,
Writer: sqlutil.NewDummyTransactionWriter(), Writer: sqlutil.NewDummyWriter(),
EventTypesTable: eventTypes, EventTypesTable: eventTypes,
EventStateKeysTable: eventStateKeys, EventStateKeysTable: eventStateKeys,
EventJSONTable: eventJSON, EventJSONTable: eventJSON,

View file

@ -27,7 +27,7 @@ const redactionsArePermanent = false
type Database struct { type Database struct {
DB *sql.DB DB *sql.DB
Writer sqlutil.TransactionWriter Writer sqlutil.Writer
EventsTable tables.Events EventsTable tables.Events
EventJSONTable tables.EventJSON EventJSONTable tables.EventJSON
EventTypesTable tables.EventTypes EventTypesTable tables.EventTypes

View file

@ -41,7 +41,7 @@ type Database struct {
invites tables.Invites invites tables.Invites
membership tables.Membership membership tables.Membership
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter writer sqlutil.Writer
} }
// Open a sqlite database. // Open a sqlite database.
@ -52,7 +52,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
if d.db, err = sqlutil.Open(dbProperties); err != nil { if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err return nil, err
} }
d.writer = sqlutil.NewTransactionWriter() d.writer = sqlutil.NewExclusiveWriter()
//d.db.Exec("PRAGMA journal_mode=WAL;") //d.db.Exec("PRAGMA journal_mode=WAL;")
//d.db.Exec("PRAGMA read_uncommitted = true;") //d.db.Exec("PRAGMA read_uncommitted = true;")
@ -120,7 +120,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Writer: sqlutil.NewTransactionWriter(), Writer: sqlutil.NewExclusiveWriter(),
EventsTable: d.events, EventsTable: d.events,
EventTypesTable: d.eventTypes, EventTypesTable: d.eventTypes,
EventStateKeysTable: d.eventStateKeys, EventStateKeysTable: d.eventStateKeys,

View file

@ -30,6 +30,7 @@ import (
// A Database implements gomatrixserverlib.KeyDatabase and is used to store // A Database implements gomatrixserverlib.KeyDatabase and is used to store
// the public keys for other matrix servers. // the public keys for other matrix servers.
type Database struct { type Database struct {
writer sqlutil.Writer
statements serverKeyStatements statements serverKeyStatements
} }
@ -47,8 +48,10 @@ func NewDatabase(
if err != nil { if err != nil {
return nil, err return nil, err
} }
d := &Database{} d := &Database{
err = d.statements.prepare(db) writer: sqlutil.NewExclusiveWriter(),
}
err = d.statements.prepare(db, d.writer)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -63,14 +63,14 @@ const upsertServerKeysSQL = "" +
type serverKeyStatements struct { type serverKeyStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter writer sqlutil.Writer
bulkSelectServerKeysStmt *sql.Stmt bulkSelectServerKeysStmt *sql.Stmt
upsertServerKeysStmt *sql.Stmt upsertServerKeysStmt *sql.Stmt
} }
func (s *serverKeyStatements) prepare(db *sql.DB) (err error) { func (s *serverKeyStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db s.db = db
s.writer = sqlutil.NewTransactionWriter() s.writer = writer
_, err = db.Exec(serverKeysSchema) _, err = db.Exec(serverKeysSchema)
if err != nil { if err != nil {
return return

View file

@ -31,6 +31,7 @@ import (
type SyncServerDatasource struct { type SyncServerDatasource struct {
shared.Database shared.Database
db *sql.DB db *sql.DB
writer sqlutil.Writer
sqlutil.PartitionOffsetStatements sqlutil.PartitionOffsetStatements
} }
@ -41,7 +42,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
if d.db, err = sqlutil.Open(dbProperties); err != nil { if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err return nil, err
} }
if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil { d.writer = sqlutil.NewDummyWriter()
if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil {
return nil, err return nil, err
} }
accountData, err := NewPostgresAccountDataTable(d.db) accountData, err := NewPostgresAccountDataTable(d.db)
@ -78,6 +80,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Writer: sqlutil.NewDummyWriter(),
Invites: invites, Invites: invites,
AccountData: accountData, AccountData: accountData,
OutputEvents: events, OutputEvents: events,
@ -86,7 +89,6 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
BackwardExtremities: backwardExtremities, BackwardExtremities: backwardExtremities,
Filter: filter, Filter: filter,
SendToDevice: sendToDevice, SendToDevice: sendToDevice,
SendToDeviceWriter: sqlutil.NewTransactionWriter(),
EDUCache: cache.New(), EDUCache: cache.New(),
} }
return &d, nil return &d, nil

View file

@ -37,6 +37,7 @@ import (
// For now this contains the shared functions // For now this contains the shared functions
type Database struct { type Database struct {
DB *sql.DB DB *sql.DB
Writer sqlutil.Writer
Invites tables.Invites Invites tables.Invites
AccountData tables.AccountData AccountData tables.AccountData
OutputEvents tables.Events OutputEvents tables.Events
@ -45,7 +46,6 @@ type Database struct {
BackwardExtremities tables.BackwardsExtremities BackwardExtremities tables.BackwardsExtremities
SendToDevice tables.SendToDevice SendToDevice tables.SendToDevice
Filter tables.Filter Filter tables.Filter
SendToDeviceWriter sqlutil.TransactionWriter
EDUCache *cache.EDUCache EDUCache *cache.EDUCache
} }
@ -129,10 +129,7 @@ func (d *Database) GetStateEvent(
func (d *Database) GetStateEventsForRoom( func (d *Database) GetStateEventsForRoom(
ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter, ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter,
) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) { ) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) {
err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilter)
stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, stateFilter)
return err
})
return return
} }
@ -171,9 +168,9 @@ func (d *Database) SyncStreamPosition(ctx context.Context) (types.StreamPosition
func (d *Database) AddInviteEvent( func (d *Database) AddInviteEvent(
ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent,
) (sp types.StreamPosition, err error) { ) (sp types.StreamPosition, err error) {
err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { _ = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent) sp, err = d.Invites.InsertInviteEvent(ctx, nil, inviteEvent)
return err return nil
}) })
return return
} }
@ -182,8 +179,12 @@ func (d *Database) AddInviteEvent(
// Returns an error if there was a problem communicating with the database. // Returns an error if there was a problem communicating with the database.
func (d *Database) RetireInviteEvent( func (d *Database) RetireInviteEvent(
ctx context.Context, inviteEventID string, ctx context.Context, inviteEventID string,
) (types.StreamPosition, error) { ) (sp types.StreamPosition, err error) {
return d.Invites.DeleteInviteEvent(ctx, inviteEventID) _ = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
sp, err = d.Invites.DeleteInviteEvent(ctx, inviteEventID)
return nil
})
return
} }
// GetAccountDataInRange returns all account data for a given user inserted or // GetAccountDataInRange returns all account data for a given user inserted or
@ -207,7 +208,7 @@ func (d *Database) GetAccountDataInRange(
func (d *Database) UpsertAccountData( func (d *Database) UpsertAccountData(
ctx context.Context, userID, roomID, dataType string, ctx context.Context, userID, roomID, dataType string,
) (sp types.StreamPosition, err error) { ) (sp types.StreamPosition, err error) {
err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType) sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType)
return err return err
}) })
@ -237,6 +238,7 @@ func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.Strea
// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of // handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of
// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table // the events listed in the event's 'prev_events'. This function also updates the backwards extremities table
// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. // to account for the fact that the given event is no longer a backwards extremity, but may be marked as such.
// This function should always be called within a sqlutil.Writer for safety in SQLite.
func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error { func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error {
if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil { if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil {
return err return err
@ -275,7 +277,7 @@ func (d *Database) WriteEvent(
addStateEventIDs, removeStateEventIDs []string, addStateEventIDs, removeStateEventIDs []string,
transactionID *api.TransactionID, excludeFromSync bool, transactionID *api.TransactionID, excludeFromSync bool,
) (pduPosition types.StreamPosition, returnErr error) { ) (pduPosition types.StreamPosition, returnErr error) {
returnErr = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error var err error
pos, err := d.OutputEvents.InsertEvent( pos, err := d.OutputEvents.InsertEvent(
ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync,
@ -304,6 +306,7 @@ func (d *Database) WriteEvent(
return pduPosition, returnErr return pduPosition, returnErr
} }
// This function should always be called within a sqlutil.Writer for safety in SQLite.
func (d *Database) updateRoomState( func (d *Database) updateRoomState(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
removedEventIDs []string, removedEventIDs []string,
@ -1114,7 +1117,7 @@ func (d *Database) StoreNewSendForDeviceMessage(
} }
// Delegate the database write task to the SendToDeviceWriter. It'll guarantee // Delegate the database write task to the SendToDeviceWriter. It'll guarantee
// that we don't lock the table for writes in more than one place. // that we don't lock the table for writes in more than one place.
err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.AddSendToDeviceEvent( return d.AddSendToDeviceEvent(
ctx, txn, userID, deviceID, string(j), ctx, txn, userID, deviceID, string(j),
) )
@ -1179,7 +1182,7 @@ func (d *Database) CleanSendToDeviceUpdates(
// If we need to write to the database then we'll ask the SendToDeviceWriter to // If we need to write to the database then we'll ask the SendToDeviceWriter to
// do that for us. It'll guarantee that we don't lock the table for writes in // do that for us. It'll guarantee that we don't lock the table for writes in
// more than one place. // more than one place.
err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// Delete any send-to-device messages marked for deletion. // Delete any send-to-device messages marked for deletion.
if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil { if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil {
return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e) return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)

View file

@ -20,7 +20,6 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -51,7 +50,6 @@ const selectMaxAccountDataIDSQL = "" +
type accountDataStatements struct { type accountDataStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter
streamIDStatements *streamIDStatements streamIDStatements *streamIDStatements
insertAccountDataStmt *sql.Stmt insertAccountDataStmt *sql.Stmt
selectMaxAccountDataIDStmt *sql.Stmt selectMaxAccountDataIDStmt *sql.Stmt
@ -61,7 +59,6 @@ type accountDataStatements struct {
func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) { func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
s := &accountDataStatements{ s := &accountDataStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID, streamIDStatements: streamID,
} }
_, err := db.Exec(accountDataSchema) _, err := db.Exec(accountDataSchema)
@ -84,15 +81,12 @@ func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
userID, roomID, dataType string, userID, roomID, dataType string,
) (pos types.StreamPosition, err error) { ) (pos types.StreamPosition, err error) {
return pos, s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
var err error
pos, err = s.streamIDStatements.nextStreamID(ctx, txn) pos, err = s.streamIDStatements.nextStreamID(ctx, txn)
if err != nil { if err != nil {
return err return
} }
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType) _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
return err return
})
} }
func (s *accountDataStatements) SelectAccountDataInRange( func (s *accountDataStatements) SelectAccountDataInRange(

View file

@ -19,7 +19,6 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
) )
@ -49,7 +48,6 @@ const deleteBackwardExtremitySQL = "" +
type backwardExtremitiesStatements struct { type backwardExtremitiesStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter
insertBackwardExtremityStmt *sql.Stmt insertBackwardExtremityStmt *sql.Stmt
selectBackwardExtremitiesForRoomStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt
deleteBackwardExtremityStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt
@ -58,7 +56,6 @@ type backwardExtremitiesStatements struct {
func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
s := &backwardExtremitiesStatements{ s := &backwardExtremitiesStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
} }
_, err := db.Exec(backwardExtremitiesSchema) _, err := db.Exec(backwardExtremitiesSchema)
if err != nil { if err != nil {
@ -79,10 +76,8 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string, ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
) (err error) { ) (err error) {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { _, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
_, err := txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
return err return err
})
} }
func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
@ -110,8 +105,6 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
ctx context.Context, txn *sql.Tx, roomID, knownEventID string, ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
) (err error) { ) (err error) {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { _, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
_, err := txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
return err return err
})
} }

View file

@ -85,7 +85,6 @@ const selectEventsWithEventIDsSQL = "" +
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter
streamIDStatements *streamIDStatements streamIDStatements *streamIDStatements
upsertRoomStateStmt *sql.Stmt upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt
@ -98,7 +97,6 @@ type currentRoomStateStatements struct {
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) { func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
s := &currentRoomStateStatements{ s := &currentRoomStateStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID, streamIDStatements: streamID,
} }
_, err := db.Exec(currentRoomStateSchema) _, err := db.Exec(currentRoomStateSchema)
@ -200,11 +198,9 @@ func (s *currentRoomStateStatements) SelectCurrentState(
func (s *currentRoomStateStatements) DeleteRoomStateByEventID( func (s *currentRoomStateStatements) DeleteRoomStateByEventID(
ctx context.Context, txn *sql.Tx, eventID string, ctx context.Context, txn *sql.Tx, eventID string,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
_, err := stmt.ExecContext(ctx, eventID) _, err := stmt.ExecContext(ctx, eventID)
return err return err
})
} }
func (s *currentRoomStateStatements) UpsertRoomState( func (s *currentRoomStateStatements) UpsertRoomState(
@ -225,9 +221,8 @@ func (s *currentRoomStateStatements) UpsertRoomState(
} }
// upsert state event // upsert state event
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
_, err := stmt.ExecContext( _, err = stmt.ExecContext(
ctx, ctx,
event.RoomID(), event.RoomID(),
event.EventID(), event.EventID(),
@ -240,7 +235,6 @@ func (s *currentRoomStateStatements) UpsertRoomState(
addedAt, addedAt,
) )
return err return err
})
} }
func minOfInts(a, b int) int { func minOfInts(a, b int) int {

View file

@ -20,7 +20,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -52,7 +51,6 @@ const insertFilterSQL = "" +
type filterStatements struct { type filterStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter
selectFilterStmt *sql.Stmt selectFilterStmt *sql.Stmt
selectFilterIDByContentStmt *sql.Stmt selectFilterIDByContentStmt *sql.Stmt
insertFilterStmt *sql.Stmt insertFilterStmt *sql.Stmt
@ -65,7 +63,6 @@ func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) {
} }
s := &filterStatements{ s := &filterStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
} }
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
return nil, err return nil, err
@ -114,7 +111,6 @@ func (s *filterStatements) InsertFilter(
return "", err return "", err
} }
err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
// Check if filter already exists in the database using its localpart and content // Check if filter already exists in the database using its localpart and content
// //
// This can result in a race condition when two clients try to insert the // This can result in a race condition when two clients try to insert the
@ -123,24 +119,22 @@ func (s *filterStatements) InsertFilter(
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
localpart, filterJSON).Scan(&existingFilterID) localpart, filterJSON).Scan(&existingFilterID)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return err return "", err
} }
// If it does, return the existing ID // If it does, return the existing ID
if existingFilterID != "" { if existingFilterID != "" {
return nil return existingFilterID, nil
} }
// Otherwise insert the filter and return the new ID // Otherwise insert the filter and return the new ID
res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart) res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
if err != nil { if err != nil {
return err return "", err
} }
rowid, err := res.LastInsertId() rowid, err := res.LastInsertId()
if err != nil { if err != nil {
return err return "", err
} }
filterID = fmt.Sprintf("%d", rowid) filterID = fmt.Sprintf("%d", rowid)
return nil
})
return return
} }

View file

@ -59,7 +59,6 @@ const selectMaxInviteIDSQL = "" +
type inviteEventsStatements struct { type inviteEventsStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter
streamIDStatements *streamIDStatements streamIDStatements *streamIDStatements
insertInviteEventStmt *sql.Stmt insertInviteEventStmt *sql.Stmt
selectInviteEventsInRangeStmt *sql.Stmt selectInviteEventsInRangeStmt *sql.Stmt
@ -70,7 +69,6 @@ type inviteEventsStatements struct {
func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) { func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
s := &inviteEventsStatements{ s := &inviteEventsStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID, streamIDStatements: streamID,
} }
_, err := db.Exec(inviteEventsSchema) _, err := db.Exec(inviteEventsSchema)
@ -95,20 +93,19 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Inv
func (s *inviteEventsStatements) InsertInviteEvent( func (s *inviteEventsStatements) InsertInviteEvent(
ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent, ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent,
) (streamPos types.StreamPosition, err error) { ) (streamPos types.StreamPosition, err error) {
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
var err error
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
if err != nil { if err != nil {
return err return
} }
var headeredJSON []byte var headeredJSON []byte
headeredJSON, err = json.Marshal(inviteEvent) headeredJSON, err = json.Marshal(inviteEvent)
if err != nil { if err != nil {
return err return
} }
_, err = txn.Stmt(s.insertInviteEventStmt).ExecContext( stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
_, err = stmt.ExecContext(
ctx, ctx,
streamPos, streamPos,
inviteEvent.RoomID(), inviteEvent.RoomID(),
@ -116,24 +113,17 @@ func (s *inviteEventsStatements) InsertInviteEvent(
*inviteEvent.StateKey(), *inviteEvent.StateKey(),
headeredJSON, headeredJSON,
) )
return err
})
return return
} }
func (s *inviteEventsStatements) DeleteInviteEvent( func (s *inviteEventsStatements) DeleteInviteEvent(
ctx context.Context, inviteEventID string, ctx context.Context, inviteEventID string,
) (types.StreamPosition, error) { ) (types.StreamPosition, error) {
var streamPos types.StreamPosition streamPos, err := s.streamIDStatements.nextStreamID(ctx, nil)
err := s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
var err error
streamPos, err = s.streamIDStatements.nextStreamID(ctx, nil)
if err != nil { if err != nil {
return err return streamPos, err
} }
_, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID) _, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID)
return err
})
return streamPos, err return streamPos, err
} }

View file

@ -105,7 +105,6 @@ const selectStateInRangeSQL = "" +
type outputRoomEventsStatements struct { type outputRoomEventsStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter
streamIDStatements *streamIDStatements streamIDStatements *streamIDStatements
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt selectEventsStmt *sql.Stmt
@ -120,7 +119,6 @@ type outputRoomEventsStatements struct {
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) { func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
s := &outputRoomEventsStatements{ s := &outputRoomEventsStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID, streamIDStatements: streamID,
} }
_, err := db.Exec(outputRoomEventsSchema) _, err := db.Exec(outputRoomEventsSchema)
@ -159,10 +157,8 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event
if err != nil { if err != nil {
return err return err
} }
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
_, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
return err return err
})
} }
// selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos. // selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos.
@ -304,15 +300,12 @@ func (s *outputRoomEventsStatements) InsertEvent(
return 0, err return 0, err
} }
var streamPos types.StreamPosition streamPos, err := s.streamIDStatements.nextStreamID(ctx, txn)
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
if err != nil { if err != nil {
return err return 0, err
} }
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
_, ierr := insertStmt.ExecContext( _, err = insertStmt.ExecContext(
ctx, ctx,
streamPos, streamPos,
event.RoomID(), event.RoomID(),
@ -328,8 +321,6 @@ func (s *outputRoomEventsStatements) InsertEvent(
excludeFromSync, excludeFromSync,
excludeFromSync, excludeFromSync,
) )
return ierr
})
return streamPos, err return streamPos, err
} }

View file

@ -67,7 +67,6 @@ const selectMaxPositionInTopologySQL = "" +
type outputRoomEventsTopologyStatements struct { type outputRoomEventsTopologyStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter
insertEventInTopologyStmt *sql.Stmt insertEventInTopologyStmt *sql.Stmt
selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeASCStmt *sql.Stmt
selectEventIDsInRangeDESCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt
@ -78,7 +77,6 @@ type outputRoomEventsTopologyStatements struct {
func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
s := &outputRoomEventsTopologyStatements{ s := &outputRoomEventsTopologyStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
} }
_, err := db.Exec(outputRoomEventsTopologySchema) _, err := db.Exec(outputRoomEventsTopologySchema)
if err != nil { if err != nil {
@ -107,13 +105,11 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( func (s *outputRoomEventsTopologyStatements) InsertEventInTopology(
ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition, ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition,
) (err error) { ) (err error) {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt) stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt)
_, err := stmt.ExecContext( _, err = stmt.ExecContext(
ctx, event.EventID(), event.Depth(), event.RoomID(), pos, ctx, event.EventID(), event.Depth(), event.RoomID(), pos,
) )
return err return
})
} }
func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(

View file

@ -73,7 +73,6 @@ const deleteSendToDeviceMessagesSQL = `
type sendToDeviceStatements struct { type sendToDeviceStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter
insertSendToDeviceMessageStmt *sql.Stmt insertSendToDeviceMessageStmt *sql.Stmt
selectSendToDeviceMessagesStmt *sql.Stmt selectSendToDeviceMessagesStmt *sql.Stmt
countSendToDeviceMessagesStmt *sql.Stmt countSendToDeviceMessagesStmt *sql.Stmt
@ -82,7 +81,6 @@ type sendToDeviceStatements struct {
func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
s := &sendToDeviceStatements{ s := &sendToDeviceStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
} }
_, err := db.Exec(sendToDeviceSchema) _, err := db.Exec(sendToDeviceSchema)
if err != nil { if err != nil {
@ -103,10 +101,8 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
func (s *sendToDeviceStatements) InsertSendToDeviceMessage( func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
ctx context.Context, txn *sql.Tx, userID, deviceID, content string, ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
) (err error) { ) (err error) {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { _, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
_, err := sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) return
return err
})
} }
func (s *sendToDeviceStatements) CountSendToDeviceMessages( func (s *sendToDeviceStatements) CountSendToDeviceMessages(
@ -163,10 +159,8 @@ func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
for k, v := range nids { for k, v := range nids {
params[k+1] = v params[k+1] = v
} }
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { _, err = txn.ExecContext(ctx, query, params...)
_, err := txn.ExecContext(ctx, query, params...) return
return err
})
} }
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
@ -177,8 +171,6 @@ func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
for k, v := range nids { for k, v := range nids {
params[k] = v params[k] = v
} }
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { _, err = txn.ExecContext(ctx, query, params...)
_, err := txn.ExecContext(ctx, query, params...) return
return err
})
} }

View file

@ -28,14 +28,12 @@ const selectStreamIDStmt = "" +
type streamIDStatements struct { type streamIDStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter
increaseStreamIDStmt *sql.Stmt increaseStreamIDStmt *sql.Stmt
selectStreamIDStmt *sql.Stmt selectStreamIDStmt *sql.Stmt
} }
func (s *streamIDStatements) prepare(db *sql.DB) (err error) { func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
s.db = db s.db = db
s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(streamIDTableSchema) _, err = db.Exec(streamIDTableSchema)
if err != nil { if err != nil {
return return
@ -52,14 +50,9 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error { if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil {
if _, ierr := increaseStmt.ExecContext(ctx, "global"); err != nil { return
return ierr }
} err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos)
if serr := selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil {
return serr
}
return nil
})
return return
} }

View file

@ -32,6 +32,7 @@ import (
type SyncServerDatasource struct { type SyncServerDatasource struct {
shared.Database shared.Database
db *sql.DB db *sql.DB
writer sqlutil.Writer
sqlutil.PartitionOffsetStatements sqlutil.PartitionOffsetStatements
streamID streamIDStatements streamID streamIDStatements
} }
@ -44,6 +45,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
if d.db, err = sqlutil.Open(dbProperties); err != nil { if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err return nil, err
} }
d.writer = sqlutil.NewExclusiveWriter()
if err = d.prepare(); err != nil { if err = d.prepare(); err != nil {
return nil, err return nil, err
} }
@ -51,7 +53,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
} }
func (d *SyncServerDatasource) prepare() (err error) { func (d *SyncServerDatasource) prepare() (err error) {
if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil { if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil {
return err return err
} }
if err = d.streamID.prepare(d.db); err != nil { if err = d.streamID.prepare(d.db); err != nil {
@ -91,6 +93,7 @@ func (d *SyncServerDatasource) prepare() (err error) {
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Writer: sqlutil.NewExclusiveWriter(),
Invites: invites, Invites: invites,
AccountData: accountData, AccountData: accountData,
OutputEvents: events, OutputEvents: events,
@ -99,7 +102,6 @@ func (d *SyncServerDatasource) prepare() (err error) {
Topology: topology, Topology: topology,
Filter: filter, Filter: filter,
SendToDevice: sendToDevice, SendToDevice: sendToDevice,
SendToDeviceWriter: sqlutil.NewTransactionWriter(),
EDUCache: cache.New(), EDUCache: cache.New(),
} }
return nil return nil

View file

@ -35,6 +35,7 @@ import (
// Database represents an account database // Database represents an account database
type Database struct { type Database struct {
db *sql.DB db *sql.DB
writer sqlutil.Writer
sqlutil.PartitionOffsetStatements sqlutil.PartitionOffsetStatements
accounts accountsStatements accounts accountsStatements
profiles profilesStatements profiles profilesStatements
@ -49,27 +50,27 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil { if err != nil {
return nil, err return nil, err
} }
partitions := sqlutil.PartitionOffsetStatements{} d := &Database{
if err = partitions.Prepare(db, "account"); err != nil { serverName: serverName,
db: db,
writer: sqlutil.NewDummyWriter(),
}
if err = d.PartitionOffsetStatements.Prepare(db, d.writer, "account"); err != nil {
return nil, err return nil, err
} }
a := accountsStatements{} if err = d.accounts.prepare(db, serverName); err != nil {
if err = a.prepare(db, serverName); err != nil {
return nil, err return nil, err
} }
p := profilesStatements{} if err = d.profiles.prepare(db); err != nil {
if err = p.prepare(db); err != nil {
return nil, err return nil, err
} }
ac := accountDataStatements{} if err = d.accountDatas.prepare(db); err != nil {
if err = ac.prepare(db); err != nil {
return nil, err return nil, err
} }
t := threepidStatements{} if err = d.threepids.prepare(db); err != nil {
if err = t.prepare(db); err != nil {
return nil, err return nil, err
} }
return &Database{db, partitions, a, p, ac, t, serverName}, nil return d, nil
} }
// GetAccountByPassword returns the account associated with the given localpart and password. // GetAccountByPassword returns the account associated with the given localpart and password.

View file

@ -51,15 +51,15 @@ const selectAccountDataByTypeSQL = "" +
type accountDataStatements struct { type accountDataStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter writer sqlutil.Writer
insertAccountDataStmt *sql.Stmt insertAccountDataStmt *sql.Stmt
selectAccountDataStmt *sql.Stmt selectAccountDataStmt *sql.Stmt
selectAccountDataByTypeStmt *sql.Stmt selectAccountDataByTypeStmt *sql.Stmt
} }
func (s *accountDataStatements) prepare(db *sql.DB) (err error) { func (s *accountDataStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db s.db = db
s.writer = sqlutil.NewTransactionWriter() s.writer = writer
_, err = db.Exec(accountDataSchema) _, err = db.Exec(accountDataSchema)
if err != nil { if err != nil {
return return

View file

@ -59,7 +59,7 @@ const selectNewNumericLocalpartSQL = "" +
type accountsStatements struct { type accountsStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter writer sqlutil.Writer
insertAccountStmt *sql.Stmt insertAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt
@ -67,9 +67,9 @@ type accountsStatements struct {
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { func (s *accountsStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
s.db = db s.db = db
s.writer = sqlutil.NewTransactionWriter() s.writer = writer
_, err = db.Exec(accountsSchema) _, err = db.Exec(accountsSchema)
if err != nil { if err != nil {
return return

View file

@ -53,7 +53,7 @@ const selectProfilesBySearchSQL = "" +
type profilesStatements struct { type profilesStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter writer sqlutil.Writer
insertProfileStmt *sql.Stmt insertProfileStmt *sql.Stmt
selectProfileByLocalpartStmt *sql.Stmt selectProfileByLocalpartStmt *sql.Stmt
setAvatarURLStmt *sql.Stmt setAvatarURLStmt *sql.Stmt
@ -61,9 +61,9 @@ type profilesStatements struct {
selectProfilesBySearchStmt *sql.Stmt selectProfilesBySearchStmt *sql.Stmt
} }
func (s *profilesStatements) prepare(db *sql.DB) (err error) { func (s *profilesStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db s.db = db
s.writer = sqlutil.NewTransactionWriter() s.writer = writer
_, err = db.Exec(profilesSchema) _, err = db.Exec(profilesSchema)
if err != nil { if err != nil {
return return

View file

@ -34,6 +34,8 @@ import (
// Database represents an account database // Database represents an account database
type Database struct { type Database struct {
db *sql.DB db *sql.DB
writer sqlutil.Writer
sqlutil.PartitionOffsetStatements sqlutil.PartitionOffsetStatements
accounts accountsStatements accounts accountsStatements
profiles profilesStatements profiles profilesStatements
@ -53,35 +55,28 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil { if err != nil {
return nil, err return nil, err
} }
partitions := sqlutil.PartitionOffsetStatements{} d := &Database{
if err = partitions.Prepare(db, "account"); err != nil {
return nil, err
}
a := accountsStatements{}
if err = a.prepare(db, serverName); err != nil {
return nil, err
}
p := profilesStatements{}
if err = p.prepare(db); err != nil {
return nil, err
}
ac := accountDataStatements{}
if err = ac.prepare(db); err != nil {
return nil, err
}
t := threepidStatements{}
if err = t.prepare(db); err != nil {
return nil, err
}
return &Database{
db: db,
PartitionOffsetStatements: partitions,
accounts: a,
profiles: p,
accountDatas: ac,
threepids: t,
serverName: serverName, serverName: serverName,
}, nil db: db,
writer: sqlutil.NewExclusiveWriter(),
}
partitions := sqlutil.PartitionOffsetStatements{}
if err = partitions.Prepare(db, d.writer, "account"); err != nil {
return nil, err
}
if err = d.accounts.prepare(db, d.writer, serverName); err != nil {
return nil, err
}
if err = d.profiles.prepare(db, d.writer); err != nil {
return nil, err
}
if err = d.accountDatas.prepare(db, d.writer); err != nil {
return nil, err
}
if err = d.threepids.prepare(db, d.writer); err != nil {
return nil, err
}
return d, nil
} }
// GetAccountByPassword returns the account associated with the given localpart and password. // GetAccountByPassword returns the account associated with the given localpart and password.

View file

@ -54,16 +54,16 @@ const deleteThreePIDSQL = "" +
type threepidStatements struct { type threepidStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter writer sqlutil.Writer
selectLocalpartForThreePIDStmt *sql.Stmt selectLocalpartForThreePIDStmt *sql.Stmt
selectThreePIDsForLocalpartStmt *sql.Stmt selectThreePIDsForLocalpartStmt *sql.Stmt
insertThreePIDStmt *sql.Stmt insertThreePIDStmt *sql.Stmt
deleteThreePIDStmt *sql.Stmt deleteThreePIDStmt *sql.Stmt
} }
func (s *threepidStatements) prepare(db *sql.DB) (err error) { func (s *threepidStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db s.db = db
s.writer = sqlutil.NewTransactionWriter() s.writer = writer
_, err = db.Exec(threepidSchema) _, err = db.Exec(threepidSchema)
if err != nil { if err != nil {
return return

View file

@ -78,7 +78,7 @@ const selectDevicesByIDSQL = "" +
type devicesStatements struct { type devicesStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter writer sqlutil.Writer
insertDeviceStmt *sql.Stmt insertDeviceStmt *sql.Stmt
selectDevicesCountStmt *sql.Stmt selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt
@ -91,9 +91,9 @@ type devicesStatements struct {
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
s.db = db s.db = db
s.writer = sqlutil.NewTransactionWriter() s.writer = writer
_, err = db.Exec(devicesSchema) _, err = db.Exec(devicesSchema)
if err != nil { if err != nil {
return return
@ -138,19 +138,13 @@ func (s *devicesStatements) insertDevice(
) (*api.Device, error) { ) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64 var sessionID int64
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt) countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil { if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
return err return nil, err
} }
sessionID++ sessionID++
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil { if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
return err
}
return nil
})
if err != nil {
return nil, err return nil, err
} }
return &api.Device{ return &api.Device{
@ -164,11 +158,9 @@ func (s *devicesStatements) insertDevice(
func (s *devicesStatements) deleteDevice( func (s *devicesStatements) deleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string, ctx context.Context, txn *sql.Tx, id, localpart string,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
_, err := stmt.ExecContext(ctx, id, localpart) _, err := stmt.ExecContext(ctx, id, localpart)
return err return err
})
} }
func (s *devicesStatements) deleteDevices( func (s *devicesStatements) deleteDevices(
@ -179,7 +171,6 @@ func (s *devicesStatements) deleteDevices(
if err != nil { if err != nil {
return err return err
} }
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, prep) stmt := sqlutil.TxStmt(txn, prep)
params := make([]interface{}, len(devices)+1) params := make([]interface{}, len(devices)+1)
params[0] = localpart params[0] = localpart
@ -188,27 +179,22 @@ func (s *devicesStatements) deleteDevices(
} }
_, err = stmt.ExecContext(ctx, params...) _, err = stmt.ExecContext(ctx, params...)
return err return err
})
} }
func (s *devicesStatements) deleteDevicesByLocalpart( func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart) _, err := stmt.ExecContext(ctx, localpart)
return err return err
})
} }
func (s *devicesStatements) updateDeviceName( func (s *devicesStatements) updateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
return err return err
})
} }
func (s *devicesStatements) selectDeviceByToken( func (s *devicesStatements) selectDeviceByToken(

View file

@ -34,6 +34,7 @@ var deviceIDByteLength = 6
// Database represents a device database. // Database represents a device database.
type Database struct { type Database struct {
db *sql.DB db *sql.DB
writer sqlutil.Writer
devices devicesStatements devices devicesStatements
} }
@ -43,11 +44,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil { if err != nil {
return nil, err return nil, err
} }
writer := sqlutil.NewExclusiveWriter()
d := devicesStatements{} d := devicesStatements{}
if err = d.prepare(db, serverName); err != nil { if err = d.prepare(db, writer, serverName); err != nil {
return nil, err return nil, err
} }
return &Database{db, d}, nil return &Database{db, writer, d}, nil
} }
// GetDeviceByAccessToken returns the device matching the given access token. // GetDeviceByAccessToken returns the device matching the given access token.
@ -88,7 +90,7 @@ func (d *Database) CreateDevice(
displayName *string, displayName *string,
) (dev *api.Device, returnErr error) { ) (dev *api.Device, returnErr error) {
if deviceID != nil { if deviceID != nil {
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error var err error
// Revoke existing tokens for this device // Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
@ -108,7 +110,7 @@ func (d *Database) CreateDevice(
return return
} }
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName) dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
return err return err
@ -138,7 +140,7 @@ func generateDeviceID() (string, error) {
func (d *Database) UpdateDevice( func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string, ctx context.Context, localpart, deviceID string, displayName *string,
) error { ) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
}) })
} }
@ -150,7 +152,7 @@ func (d *Database) UpdateDevice(
func (d *Database) RemoveDevice( func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string, ctx context.Context, deviceID, localpart string,
) error { ) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err return err
} }
@ -165,7 +167,7 @@ func (d *Database) RemoveDevice(
func (d *Database) RemoveDevices( func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string, ctx context.Context, localpart string, devices []string,
) error { ) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err return err
} }
@ -179,7 +181,7 @@ func (d *Database) RemoveDevices(
func (d *Database) RemoveAllDevices( func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) error { ) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
return err return err
} }