Use TransactionWriter in other component SQLite (#1209)

* Use TransactionWriter on other component SQLites

* Fix sync API tests

* Fix panic in media API

* Fix a couple of transactions

* Fix wrong query, add some logging output

* Add debug logging into StoreEvent

* Adjust InsertRoomNID

* Update logging
This commit is contained in:
Neil Alexander 2020-07-21 15:48:21 +01:00 committed by GitHub
parent 1d72ce8b7a
commit b6bc132485
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
27 changed files with 439 additions and 245 deletions

View file

@ -21,6 +21,7 @@ import (
"encoding/json" "encoding/json"
"time" "time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -65,6 +66,8 @@ const (
) )
type eventsStatements struct { type eventsStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
selectEventsByApplicationServiceIDStmt *sql.Stmt selectEventsByApplicationServiceIDStmt *sql.Stmt
countEventsByApplicationServiceIDStmt *sql.Stmt countEventsByApplicationServiceIDStmt *sql.Stmt
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
@ -73,6 +76,8 @@ type eventsStatements struct {
} }
func (s *eventsStatements) prepare(db *sql.DB) (err error) { func (s *eventsStatements) prepare(db *sql.DB) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(appserviceEventsSchema) _, err = db.Exec(appserviceEventsSchema)
if err != nil { if err != nil {
return return
@ -217,13 +222,15 @@ func (s *eventsStatements) insertEvent(
return err return err
} }
_, err = s.insertEventStmt.ExecContext( return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
ctx, _, err := s.insertEventStmt.ExecContext(
appServiceID, ctx,
eventJSON, appServiceID,
-1, // No transaction ID yet eventJSON,
) -1, // No transaction ID yet
return )
return err
})
} }
// updateTxnIDForEvents sets the transactionID for a collection of events. Done // updateTxnIDForEvents sets the transactionID for a collection of events. Done
@ -234,8 +241,10 @@ func (s *eventsStatements) updateTxnIDForEvents(
appserviceID string, appserviceID string,
maxID, txnID int, maxID, txnID int,
) (err error) { ) (err error) {
_, err = s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID) return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
return _, err := s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID)
return err
})
} }
// deleteEventsBeforeAndIncludingID removes events matching given IDs from the database. // deleteEventsBeforeAndIncludingID removes events matching given IDs from the database.
@ -244,6 +253,8 @@ func (s *eventsStatements) deleteEventsBeforeAndIncludingID(
appserviceID string, appserviceID string,
eventTableID int, eventTableID int,
) (err error) { ) (err error) {
_, err = s.deleteEventsBeforeAndIncludingIDStmt.ExecContext(ctx, appserviceID, eventTableID) return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
return _, err := s.deleteEventsBeforeAndIncludingIDStmt.ExecContext(ctx, appserviceID, eventTableID)
return err
})
} }

View file

@ -18,6 +18,8 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
const txnIDSchema = ` const txnIDSchema = `
@ -35,10 +37,14 @@ const selectTxnIDSQL = `
` `
type txnStatements struct { type txnStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
selectTxnIDStmt *sql.Stmt selectTxnIDStmt *sql.Stmt
} }
func (s *txnStatements) prepare(db *sql.DB) (err error) { func (s *txnStatements) prepare(db *sql.DB) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(txnIDSchema) _, err = db.Exec(txnIDSchema)
if err != nil { if err != nil {
return return
@ -55,6 +61,9 @@ func (s *txnStatements) prepare(db *sql.DB) (err error) {
func (s *txnStatements) selectTxnID( func (s *txnStatements) selectTxnID(
ctx context.Context, ctx context.Context,
) (txnID int, err error) { ) (txnID int, err error) {
err = s.selectTxnIDStmt.QueryRowContext(ctx).Scan(&txnID) err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
err := s.selectTxnIDStmt.QueryRowContext(ctx).Scan(&txnID)
return err
})
return return
} }

View file

@ -68,6 +68,7 @@ const selectBulkStateContentWildSQL = "" +
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter
upsertRoomStateStmt *sql.Stmt upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt
selectRoomIDsWithMembershipStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt
@ -76,7 +77,8 @@ type currentRoomStateStatements struct {
func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
s := &currentRoomStateStatements{ s := &currentRoomStateStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
} }
_, err := db.Exec(currentRoomStateSchema) _, err := db.Exec(currentRoomStateSchema)
if err != nil { if err != nil {
@ -125,9 +127,11 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
func (s *currentRoomStateStatements) DeleteRoomStateByEventID( func (s *currentRoomStateStatements) DeleteRoomStateByEventID(
ctx context.Context, txn *sql.Tx, eventID string, ctx context.Context, txn *sql.Tx, eventID string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
_, err := stmt.ExecContext(ctx, eventID) stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
return err _, err := stmt.ExecContext(ctx, eventID)
return err
})
} }
func (s *currentRoomStateStatements) UpsertRoomState( func (s *currentRoomStateStatements) UpsertRoomState(
@ -140,18 +144,20 @@ func (s *currentRoomStateStatements) UpsertRoomState(
} }
// upsert state event // upsert state event
stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
_, err = stmt.ExecContext( stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
ctx, _, err = stmt.ExecContext(
event.RoomID(), ctx,
event.EventID(), event.RoomID(),
event.Type(), event.EventID(),
event.Sender(), event.Type(),
*event.StateKey(), event.Sender(),
headeredJSON, *event.StateKey(),
contentVal, headeredJSON,
) contentVal,
return err )
return err
})
} }
func (s *currentRoomStateStatements) SelectEventsWithEventIDs( func (s *currentRoomStateStatements) SelectEventsWithEventIDs(

View file

@ -20,6 +20,7 @@ import (
"database/sql" "database/sql"
"time" "time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -60,11 +61,16 @@ SELECT content_type, file_size_bytes, creation_ts, upload_name, base64hash, user
` `
type mediaStatements struct { type mediaStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertMediaStmt *sql.Stmt insertMediaStmt *sql.Stmt
selectMediaStmt *sql.Stmt selectMediaStmt *sql.Stmt
} }
func (s *mediaStatements) prepare(db *sql.DB) (err error) { func (s *mediaStatements) prepare(db *sql.DB) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(mediaSchema) _, err = db.Exec(mediaSchema)
if err != nil { if err != nil {
return return
@ -80,18 +86,21 @@ func (s *mediaStatements) insertMedia(
ctx context.Context, mediaMetadata *types.MediaMetadata, ctx context.Context, mediaMetadata *types.MediaMetadata,
) error { ) error {
mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000) mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
_, err := s.insertMediaStmt.ExecContext( return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
ctx, stmt := sqlutil.TxStmt(txn, s.insertMediaStmt)
mediaMetadata.MediaID, _, err := stmt.ExecContext(
mediaMetadata.Origin, ctx,
mediaMetadata.ContentType, mediaMetadata.MediaID,
mediaMetadata.FileSizeBytes, mediaMetadata.Origin,
mediaMetadata.CreationTimestamp, mediaMetadata.ContentType,
mediaMetadata.UploadName, mediaMetadata.FileSizeBytes,
mediaMetadata.Base64Hash, mediaMetadata.CreationTimestamp,
mediaMetadata.UserID, mediaMetadata.UploadName,
) mediaMetadata.Base64Hash,
return err mediaMetadata.UserID,
)
return err
})
} }
func (s *mediaStatements) selectMedia( func (s *mediaStatements) selectMedia(

View file

@ -18,6 +18,7 @@ package internal
import ( import (
"context" "context"
"fmt"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -65,13 +66,13 @@ func (r *RoomserverInternalAPI) processRoomEvent(
// Store the event. // Store the event.
roomNID, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) roomNID, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs)
if err != nil { if err != nil {
return return "", fmt.Errorf("r.DB.StoreEvent: %w", err)
} }
// if storing this event results in it being redacted then do so. // if storing this event results in it being redacted then do so.
if redactedEventID == event.EventID() { if redactedEventID == event.EventID() {
r, rerr := eventutil.RedactEvent(redactionEvent, &event) r, rerr := eventutil.RedactEvent(redactionEvent, &event)
if rerr != nil { if rerr != nil {
return "", rerr return "", fmt.Errorf("eventutil.RedactEvent: %w", rerr)
} }
event = *r event = *r
} }
@ -93,7 +94,7 @@ func (r *RoomserverInternalAPI) processRoomEvent(
// Lets calculate one. // Lets calculate one.
err = r.calculateAndSetState(ctx, input, roomNID, &stateAtEvent, event) err = r.calculateAndSetState(ctx, input, roomNID, &stateAtEvent, event)
if err != nil { if err != nil {
return return "", fmt.Errorf("r.calculateAndSetState: %w", err)
} }
} }
@ -105,7 +106,7 @@ func (r *RoomserverInternalAPI) processRoomEvent(
input.SendAsServer, // send as server input.SendAsServer, // send as server
input.TransactionID, // transaction ID input.TransactionID, // transaction ID
); err != nil { ); err != nil {
return return "", fmt.Errorf("r.updateLatestEvents: %w", err)
} }
// processing this event resulted in an event (which may not be the one we're processing) // processing this event resulted in an event (which may not be the one we're processing)
@ -123,7 +124,7 @@ func (r *RoomserverInternalAPI) processRoomEvent(
}, },
}) })
if err != nil { if err != nil {
return return "", fmt.Errorf("r.WriteOutputEvents: %w", err)
} }
} }

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -362,7 +363,7 @@ func (d *Database) StoreEvent(
ctx, txn, txnAndSessionID.TransactionID, ctx, txn, txnAndSessionID.TransactionID,
txnAndSessionID.SessionID, event.Sender(), event.EventID(), txnAndSessionID.SessionID, event.Sender(), event.EventID(),
); err != nil { ); err != nil {
return err return fmt.Errorf("d.TransactionsTable.InsertTransaction: %w", err)
} }
} }
@ -377,15 +378,15 @@ func (d *Database) StoreEvent(
// room. // room.
var roomVersion gomatrixserverlib.RoomVersion var roomVersion gomatrixserverlib.RoomVersion
if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil { if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil {
return err return fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err)
} }
if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion); err != nil { if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion); err != nil {
return err return fmt.Errorf("d.assignRoomNID: %w", err)
} }
if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil { if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil {
return err return fmt.Errorf("d.assignEventTypeNID: %w", err)
} }
eventStateKey := event.StateKey() eventStateKey := event.StateKey()
@ -393,7 +394,7 @@ func (d *Database) StoreEvent(
// Otherwise set the numeric ID for the state_key to 0. // Otherwise set the numeric ID for the state_key to 0.
if eventStateKey != nil { if eventStateKey != nil {
if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil { if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil {
return err return fmt.Errorf("d.assignStateKeyNID: %w", err)
} }
} }
@ -411,17 +412,20 @@ func (d *Database) StoreEvent(
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// We've already inserted the event so select the numeric event ID // We've already inserted the event so select the numeric event ID
eventNID, stateNID, err = d.EventsTable.SelectEvent(ctx, txn, event.EventID()) eventNID, stateNID, err = d.EventsTable.SelectEvent(ctx, txn, event.EventID())
if err != nil {
return fmt.Errorf("d.EventsTable.SelectEvent: %w", err)
}
} }
if err != nil { if err != nil {
return err return fmt.Errorf("d.EventsTable.InsertEvent: %w", err)
} }
} }
if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil {
return err return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err)
} }
redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, eventNID, event) redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, eventNID, event)
return err return nil
}) })
if err != nil { if err != nil {
return 0, types.StateAtEvent{}, nil, "", err return 0, types.StateAtEvent{}, nil, "", err

View file

@ -287,7 +287,8 @@ func (s *eventStatements) UpdateEventState(
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error { ) error {
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
_, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) stmt := sqlutil.TxStmt(txn, s.updateEventStateStmt)
_, err := stmt.ExecContext(ctx, int64(stateNID), int64(eventNID))
return err return err
}) })
} }

View file

@ -71,7 +71,8 @@ func (s *publishedStatements) UpsertRoomPublished(
ctx context.Context, roomID string, published bool, ctx context.Context, roomID string, published bool,
) (err error) { ) (err error) {
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
_, err := s.upsertPublishedStmt.ExecContext(ctx, roomID, published) stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt)
_, err := stmt.ExecContext(ctx, roomID, published)
return err return err
}) })
} }

View file

@ -87,7 +87,8 @@ func (s *roomAliasesStatements) InsertRoomAlias(
ctx context.Context, alias string, roomID string, creatorUserID string, ctx context.Context, alias string, roomID string, creatorUserID string,
) (err error) { ) (err error) {
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
_, err := s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID) stmt := sqlutil.TxStmt(txn, s.insertRoomAliasStmt)
_, err := stmt.ExecContext(ctx, alias, roomID, creatorUserID)
return err return err
}) })
} }
@ -139,7 +140,8 @@ func (s *roomAliasesStatements) DeleteRoomAlias(
ctx context.Context, alias string, ctx context.Context, alias string,
) (err error) { ) (err error) {
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
_, err := s.deleteRoomAliasStmt.ExecContext(ctx, alias) stmt := sqlutil.TxStmt(txn, s.deleteRoomAliasStmt)
_, err := stmt.ExecContext(ctx, alias)
return err return err
}) })
} }

View file

@ -20,6 +20,7 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
@ -98,17 +99,23 @@ func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
func (s *roomStatements) InsertRoomNID( func (s *roomStatements) InsertRoomNID(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, roomVersion gomatrixserverlib.RoomVersion, roomID string, roomVersion gomatrixserverlib.RoomVersion,
) (types.RoomNID, error) { ) (roomNID types.RoomNID, err error) {
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt) insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt)
_, err := insertStmt.ExecContext(ctx, roomID, roomVersion) _, err = insertStmt.ExecContext(ctx, roomID, roomVersion)
return err if err != nil {
return fmt.Errorf("insertStmt.ExecContext: %w", err)
}
roomNID, err = s.SelectRoomNID(ctx, txn, roomID)
if err != nil {
return fmt.Errorf("s.SelectRoomNID: %w", err)
}
return nil
}) })
if err == nil { if err != nil {
return s.SelectRoomNID(ctx, txn, roomID)
} else {
return types.RoomNID(0), err return types.RoomNID(0), err
} }
return
} }
func (s *roomStatements) SelectRoomNID( func (s *roomStatements) SelectRoomNID(

View file

@ -63,12 +63,14 @@ const upsertServerKeysSQL = "" +
type serverKeyStatements struct { type serverKeyStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter
bulkSelectServerKeysStmt *sql.Stmt bulkSelectServerKeysStmt *sql.Stmt
upsertServerKeysStmt *sql.Stmt upsertServerKeysStmt *sql.Stmt
} }
func (s *serverKeyStatements) prepare(db *sql.DB) (err error) { func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
s.db = db s.db = db
s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(serverKeysSchema) _, err = db.Exec(serverKeysSchema)
if err != nil { if err != nil {
return return
@ -136,16 +138,19 @@ func (s *serverKeyStatements) upsertServerKeys(
request gomatrixserverlib.PublicKeyLookupRequest, request gomatrixserverlib.PublicKeyLookupRequest,
key gomatrixserverlib.PublicKeyLookupResult, key gomatrixserverlib.PublicKeyLookupResult,
) error { ) error {
_, err := s.upsertServerKeysStmt.ExecContext( return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
ctx, stmt := sqlutil.TxStmt(txn, s.upsertServerKeysStmt)
string(request.ServerName), _, err := stmt.ExecContext(
string(request.KeyID), ctx,
nameAndKeyID(request), string(request.ServerName),
key.ValidUntilTS, string(request.KeyID),
key.ExpiredTS, nameAndKeyID(request),
key.Key.Encode(), key.ValidUntilTS,
) key.ExpiredTS,
return err key.Key.Encode(),
)
return err
})
} }
func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string { func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string {

View file

@ -281,16 +281,16 @@ func (d *Database) WriteEvent(
ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync,
) )
if err != nil { if err != nil {
return err return fmt.Errorf("d.OutputEvents.InsertEvent: %w", err)
} }
pduPosition = pos pduPosition = pos
if err = d.Topology.InsertEventInTopology(ctx, txn, ev, pos); err != nil { if err = d.Topology.InsertEventInTopology(ctx, txn, ev, pos); err != nil {
return err return fmt.Errorf("d.Topology.InsertEventInTopology: %w", err)
} }
if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil { if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil {
return err return fmt.Errorf("d.handleBackwardExtremities: %w", err)
} }
if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 {
@ -313,7 +313,7 @@ func (d *Database) updateRoomState(
// remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
for _, eventID := range removedEventIDs { for _, eventID := range removedEventIDs {
if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil { if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil {
return err return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateByEventID: %w", err)
} }
} }
@ -326,13 +326,13 @@ func (d *Database) updateRoomState(
if event.Type() == "m.room.member" { if event.Type() == "m.room.member" {
value, err := event.Membership() value, err := event.Membership()
if err != nil { if err != nil {
return err return fmt.Errorf("event.Membership: %w", err)
} }
membership = &value membership = &value
} }
if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil { if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil {
return err return fmt.Errorf("d.CurrentRoomState.UpsertRoomState: %w", err)
} }
} }

View file

@ -20,6 +20,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -49,6 +50,8 @@ const selectMaxAccountDataIDSQL = "" +
"SELECT MAX(id) FROM syncapi_account_data_type" "SELECT MAX(id) FROM syncapi_account_data_type"
type accountDataStatements struct { type accountDataStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
streamIDStatements *streamIDStatements streamIDStatements *streamIDStatements
insertAccountDataStmt *sql.Stmt insertAccountDataStmt *sql.Stmt
selectMaxAccountDataIDStmt *sql.Stmt selectMaxAccountDataIDStmt *sql.Stmt
@ -57,6 +60,8 @@ type accountDataStatements struct {
func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) { func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
s := &accountDataStatements{ s := &accountDataStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID, streamIDStatements: streamID,
} }
_, err := db.Exec(accountDataSchema) _, err := db.Exec(accountDataSchema)
@ -79,12 +84,15 @@ func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
userID, roomID, dataType string, userID, roomID, dataType string,
) (pos types.StreamPosition, err error) { ) (pos types.StreamPosition, err error) {
pos, err = s.streamIDStatements.nextStreamID(ctx, txn) return pos, s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
if err != nil { var err error
return pos, err = s.streamIDStatements.nextStreamID(ctx, txn)
} if err != nil {
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType) return err
return }
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
return err
})
} }
func (s *accountDataStatements) SelectAccountDataInRange( func (s *accountDataStatements) SelectAccountDataInRange(

View file

@ -19,6 +19,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
) )
@ -47,13 +48,18 @@ const deleteBackwardExtremitySQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
type backwardExtremitiesStatements struct { type backwardExtremitiesStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertBackwardExtremityStmt *sql.Stmt insertBackwardExtremityStmt *sql.Stmt
selectBackwardExtremitiesForRoomStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt
deleteBackwardExtremityStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt
} }
func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
s := &backwardExtremitiesStatements{} s := &backwardExtremitiesStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(backwardExtremitiesSchema) _, err := db.Exec(backwardExtremitiesSchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -73,8 +79,10 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string, ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
return _, err := txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
return err
})
} }
func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
@ -102,6 +110,8 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
ctx context.Context, txn *sql.Tx, roomID, knownEventID string, ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
return _, err := txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
return err
})
} }

View file

@ -84,6 +84,8 @@ const selectEventsWithEventIDsSQL = "" +
" FROM syncapi_current_room_state WHERE event_id IN ($1)" " FROM syncapi_current_room_state WHERE event_id IN ($1)"
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
streamIDStatements *streamIDStatements streamIDStatements *streamIDStatements
upsertRoomStateStmt *sql.Stmt upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt
@ -95,6 +97,8 @@ type currentRoomStateStatements struct {
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) { func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
s := &currentRoomStateStatements{ s := &currentRoomStateStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID, streamIDStatements: streamID,
} }
_, err := db.Exec(currentRoomStateSchema) _, err := db.Exec(currentRoomStateSchema)
@ -196,9 +200,11 @@ func (s *currentRoomStateStatements) SelectCurrentState(
func (s *currentRoomStateStatements) DeleteRoomStateByEventID( func (s *currentRoomStateStatements) DeleteRoomStateByEventID(
ctx context.Context, txn *sql.Tx, eventID string, ctx context.Context, txn *sql.Tx, eventID string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
_, err := stmt.ExecContext(ctx, eventID) stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
return err _, err := stmt.ExecContext(ctx, eventID)
return err
})
} }
func (s *currentRoomStateStatements) UpsertRoomState( func (s *currentRoomStateStatements) UpsertRoomState(
@ -219,20 +225,22 @@ func (s *currentRoomStateStatements) UpsertRoomState(
} }
// upsert state event // upsert state event
stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
_, err = stmt.ExecContext( stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
ctx, _, err := stmt.ExecContext(
event.RoomID(), ctx,
event.EventID(), event.RoomID(),
event.Type(), event.EventID(),
event.Sender(), event.Type(),
containsURL, event.Sender(),
*event.StateKey(), containsURL,
headeredJSON, *event.StateKey(),
membership, headeredJSON,
addedAt, membership,
) addedAt,
return err )
return err
})
} }
func (s *currentRoomStateStatements) SelectEventsWithEventIDs( func (s *currentRoomStateStatements) SelectEventsWithEventIDs(

View file

@ -20,6 +20,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -50,6 +51,8 @@ const insertFilterSQL = "" +
"INSERT INTO syncapi_filter (filter, localpart) VALUES ($1, $2)" "INSERT INTO syncapi_filter (filter, localpart) VALUES ($1, $2)"
type filterStatements struct { type filterStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
selectFilterStmt *sql.Stmt selectFilterStmt *sql.Stmt
selectFilterIDByContentStmt *sql.Stmt selectFilterIDByContentStmt *sql.Stmt
insertFilterStmt *sql.Stmt insertFilterStmt *sql.Stmt
@ -60,7 +63,10 @@ func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
s := &filterStatements{} s := &filterStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
return nil, err return nil, err
} }
@ -108,30 +114,33 @@ func (s *filterStatements) InsertFilter(
return "", err return "", err
} }
// Check if filter already exists in the database using its localpart and content err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
// // Check if filter already exists in the database using its localpart and content
// This can result in a race condition when two clients try to insert the //
// same filter and localpart at the same time, however this is not a // This can result in a race condition when two clients try to insert the
// problem as both calls will result in the same filterID // same filter and localpart at the same time, however this is not a
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, // problem as both calls will result in the same filterID
localpart, filterJSON).Scan(&existingFilterID) err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
if err != nil && err != sql.ErrNoRows { localpart, filterJSON).Scan(&existingFilterID)
return "", err if err != nil && err != sql.ErrNoRows {
} return err
// If it does, return the existing ID }
if existingFilterID != "" { // If it does, return the existing ID
return existingFilterID, err if existingFilterID != "" {
} return nil
}
// Otherwise insert the filter and return the new ID // Otherwise insert the filter and return the new ID
res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart) res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
if err != nil { if err != nil {
return "", err return err
} }
rowid, err := res.LastInsertId() rowid, err := res.LastInsertId()
if err != nil { if err != nil {
return "", err return err
} }
filterID = fmt.Sprintf("%d", rowid) filterID = fmt.Sprintf("%d", rowid)
return nil
})
return return
} }

View file

@ -58,6 +58,8 @@ const selectMaxInviteIDSQL = "" +
"SELECT MAX(id) FROM syncapi_invite_events" "SELECT MAX(id) FROM syncapi_invite_events"
type inviteEventsStatements struct { type inviteEventsStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
streamIDStatements *streamIDStatements streamIDStatements *streamIDStatements
insertInviteEventStmt *sql.Stmt insertInviteEventStmt *sql.Stmt
selectInviteEventsInRangeStmt *sql.Stmt selectInviteEventsInRangeStmt *sql.Stmt
@ -67,6 +69,8 @@ type inviteEventsStatements struct {
func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) { func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
s := &inviteEventsStatements{ s := &inviteEventsStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID, streamIDStatements: streamID,
} }
_, err := db.Exec(inviteEventsSchema) _, err := db.Exec(inviteEventsSchema)
@ -91,36 +95,45 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Inv
func (s *inviteEventsStatements) InsertInviteEvent( func (s *inviteEventsStatements) InsertInviteEvent(
ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent, ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent,
) (streamPos types.StreamPosition, err error) { ) (streamPos types.StreamPosition, err error) {
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
if err != nil { var err error
return streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
} if err != nil {
return err
}
var headeredJSON []byte var headeredJSON []byte
headeredJSON, err = json.Marshal(inviteEvent) headeredJSON, err = json.Marshal(inviteEvent)
if err != nil { if err != nil {
return return err
} }
_, err = txn.Stmt(s.insertInviteEventStmt).ExecContext( _, err = txn.Stmt(s.insertInviteEventStmt).ExecContext(
ctx, ctx,
streamPos, streamPos,
inviteEvent.RoomID(), inviteEvent.RoomID(),
inviteEvent.EventID(), inviteEvent.EventID(),
*inviteEvent.StateKey(), *inviteEvent.StateKey(),
headeredJSON, headeredJSON,
) )
return err
})
return return
} }
func (s *inviteEventsStatements) DeleteInviteEvent( func (s *inviteEventsStatements) DeleteInviteEvent(
ctx context.Context, inviteEventID string, ctx context.Context, inviteEventID string,
) (types.StreamPosition, error) { ) (types.StreamPosition, error) {
streamPos, err := s.streamIDStatements.nextStreamID(ctx, nil) var streamPos types.StreamPosition
if err != nil { err := s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
return streamPos, err var err error
} streamPos, err = s.streamIDStatements.nextStreamID(ctx, nil)
_, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID) if err != nil {
return err
}
_, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID)
return err
})
return streamPos, err return streamPos, err
} }

View file

@ -104,6 +104,8 @@ const selectStateInRangeSQL = "" +
" LIMIT $8" // limit " LIMIT $8" // limit
type outputRoomEventsStatements struct { type outputRoomEventsStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
streamIDStatements *streamIDStatements streamIDStatements *streamIDStatements
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt selectEventsStmt *sql.Stmt
@ -117,6 +119,8 @@ type outputRoomEventsStatements struct {
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) { func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
s := &outputRoomEventsStatements{ s := &outputRoomEventsStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID, streamIDStatements: streamID,
} }
_, err := db.Exec(outputRoomEventsSchema) _, err := db.Exec(outputRoomEventsSchema)
@ -155,8 +159,10 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event
if err != nil { if err != nil {
return err return err
} }
_, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
return err _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
return err
})
} }
// selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos. // selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos.
@ -267,7 +273,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, event *gomatrixserverlib.HeaderedEvent, addState, removeState []string,
transactionID *api.TransactionID, excludeFromSync bool, transactionID *api.TransactionID, excludeFromSync bool,
) (streamPos types.StreamPosition, err error) { ) (types.StreamPosition, error) {
var txnID *string var txnID *string
var sessionID *int64 var sessionID *int64
if transactionID != nil { if transactionID != nil {
@ -284,43 +290,47 @@ func (s *outputRoomEventsStatements) InsertEvent(
} }
var headeredJSON []byte var headeredJSON []byte
headeredJSON, err = json.Marshal(event) headeredJSON, err := json.Marshal(event)
if err != nil { if err != nil {
return return 0, err
}
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
if err != nil {
return
} }
addStateJSON, err := json.Marshal(addState) addStateJSON, err := json.Marshal(addState)
if err != nil { if err != nil {
return return 0, err
} }
removeStateJSON, err := json.Marshal(removeState) removeStateJSON, err := json.Marshal(removeState)
if err != nil { if err != nil {
return return 0, err
} }
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) var streamPos types.StreamPosition
_, err = insertStmt.ExecContext( err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
ctx, streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
streamPos, if err != nil {
event.RoomID(), return err
event.EventID(), }
headeredJSON,
event.Type(), insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
event.Sender(), _, ierr := insertStmt.ExecContext(
containsURL, ctx,
string(addStateJSON), streamPos,
string(removeStateJSON), event.RoomID(),
sessionID, event.EventID(),
txnID, headeredJSON,
excludeFromSync, event.Type(),
excludeFromSync, event.Sender(),
) containsURL,
return string(addStateJSON),
string(removeStateJSON),
sessionID,
txnID,
excludeFromSync,
excludeFromSync,
)
return ierr
})
return streamPos, err
} }
func (s *outputRoomEventsStatements) SelectRecentEvents( func (s *outputRoomEventsStatements) SelectRecentEvents(

View file

@ -66,6 +66,8 @@ const selectMaxPositionInTopologySQL = "" +
" WHERE room_id = $1 ORDER BY stream_position DESC" " WHERE room_id = $1 ORDER BY stream_position DESC"
type outputRoomEventsTopologyStatements struct { type outputRoomEventsTopologyStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertEventInTopologyStmt *sql.Stmt insertEventInTopologyStmt *sql.Stmt
selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeASCStmt *sql.Stmt
selectEventIDsInRangeDESCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt
@ -74,7 +76,10 @@ type outputRoomEventsTopologyStatements struct {
} }
func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
s := &outputRoomEventsTopologyStatements{} s := &outputRoomEventsTopologyStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(outputRoomEventsTopologySchema) _, err := db.Exec(outputRoomEventsTopologySchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -102,11 +107,13 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( func (s *outputRoomEventsTopologyStatements) InsertEventInTopology(
ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition, ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt) return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
_, err = stmt.ExecContext( stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt)
ctx, event.EventID(), event.Depth(), event.RoomID(), pos, _, err := stmt.ExecContext(
) ctx, event.EventID(), event.Depth(), event.RoomID(), pos,
return )
return err
})
} }
func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(

View file

@ -72,13 +72,18 @@ const deleteSendToDeviceMessagesSQL = `
` `
type sendToDeviceStatements struct { type sendToDeviceStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertSendToDeviceMessageStmt *sql.Stmt insertSendToDeviceMessageStmt *sql.Stmt
selectSendToDeviceMessagesStmt *sql.Stmt selectSendToDeviceMessagesStmt *sql.Stmt
countSendToDeviceMessagesStmt *sql.Stmt countSendToDeviceMessagesStmt *sql.Stmt
} }
func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
s := &sendToDeviceStatements{} s := &sendToDeviceStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(sendToDeviceSchema) _, err := db.Exec(sendToDeviceSchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -98,8 +103,10 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
func (s *sendToDeviceStatements) InsertSendToDeviceMessage( func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
ctx context.Context, txn *sql.Tx, userID, deviceID, content string, ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
) (err error) { ) (err error) {
_, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
return _, err := sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
return err
})
} }
func (s *sendToDeviceStatements) CountSendToDeviceMessages( func (s *sendToDeviceStatements) CountSendToDeviceMessages(
@ -156,8 +163,10 @@ func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
for k, v := range nids { for k, v := range nids {
params[k+1] = v params[k+1] = v
} }
_, err = txn.ExecContext(ctx, query, params...) return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
return _, err := txn.ExecContext(ctx, query, params...)
return err
})
} }
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
@ -168,6 +177,8 @@ func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
for k, v := range nids { for k, v := range nids {
params[k] = v params[k] = v
} }
_, err = txn.ExecContext(ctx, query, params...) return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
return _, err := txn.ExecContext(ctx, query, params...)
return err
})
} }

View file

@ -27,11 +27,15 @@ const selectStreamIDStmt = "" +
"SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1" "SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1"
type streamIDStatements struct { type streamIDStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
increaseStreamIDStmt *sql.Stmt increaseStreamIDStmt *sql.Stmt
selectStreamIDStmt *sql.Stmt selectStreamIDStmt *sql.Stmt
} }
func (s *streamIDStatements) prepare(db *sql.DB) (err error) { func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(streamIDTableSchema) _, err = db.Exec(streamIDTableSchema)
if err != nil { if err != nil {
return return
@ -48,11 +52,14 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil { err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
return if _, ierr := increaseStmt.ExecContext(ctx, "global"); err != nil {
} return ierr
if err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil { }
return if serr := selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil {
} return serr
}
return nil
})
return return
} }

View file

@ -5,6 +5,7 @@ import (
"crypto/ed25519" "crypto/ed25519"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os"
"testing" "testing"
"time" "time"
@ -52,7 +53,13 @@ func MustCreateEvent(t *testing.T, roomID string, prevs []gomatrixserverlib.Head
} }
func MustCreateDatabase(t *testing.T) storage.Database { func MustCreateDatabase(t *testing.T) storage.Database {
db, err := sqlite3.NewDatabase("file::memory:") dbname := fmt.Sprintf("test_%s.db", t.Name())
if _, err := os.Stat(dbname); err == nil {
if err = os.Remove(dbname); err != nil {
t.Fatalf("tried to delete stale test database but failed: %s", err)
}
}
db, err := sqlite3.NewDatabase(fmt.Sprintf("file:%s", dbname))
if err != nil { if err != nil {
t.Fatalf("NewSyncServerDatasource returned %s", err) t.Fatalf("NewSyncServerDatasource returned %s", err)
} }

View file

@ -18,6 +18,8 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
const accountDataSchema = ` const accountDataSchema = `
@ -48,12 +50,16 @@ const selectAccountDataByTypeSQL = "" +
"SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
type accountDataStatements struct { type accountDataStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertAccountDataStmt *sql.Stmt insertAccountDataStmt *sql.Stmt
selectAccountDataStmt *sql.Stmt selectAccountDataStmt *sql.Stmt
selectAccountDataByTypeStmt *sql.Stmt selectAccountDataByTypeStmt *sql.Stmt
} }
func (s *accountDataStatements) prepare(db *sql.DB) (err error) { func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(accountDataSchema) _, err = db.Exec(accountDataSchema)
if err != nil { if err != nil {
return return
@ -73,8 +79,10 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) insertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
return _, err := txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
return err
})
} }
func (s *accountDataStatements) selectAccountData( func (s *accountDataStatements) selectAccountData(

View file

@ -20,6 +20,7 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -57,6 +58,8 @@ const selectNewNumericLocalpartSQL = "" +
// TODO: Update password // TODO: Update password
type accountsStatements struct { type accountsStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertAccountStmt *sql.Stmt insertAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt
@ -65,6 +68,8 @@ type accountsStatements struct {
} }
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(accountsSchema) _, err = db.Exec(accountsSchema)
if err != nil { if err != nil {
return return
@ -94,12 +99,15 @@ func (s *accountsStatements) insertAccount(
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt stmt := s.insertAccountStmt
var err error err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
if appserviceID == "" { var err error
_, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil) if appserviceID == "" {
} else { _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
_, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) } else {
} _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
}
return err
})
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -19,6 +19,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
const profilesSchema = ` const profilesSchema = `
@ -46,6 +47,8 @@ const setDisplayNameSQL = "" +
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
type profilesStatements struct { type profilesStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertProfileStmt *sql.Stmt insertProfileStmt *sql.Stmt
selectProfileByLocalpartStmt *sql.Stmt selectProfileByLocalpartStmt *sql.Stmt
setAvatarURLStmt *sql.Stmt setAvatarURLStmt *sql.Stmt
@ -53,6 +56,8 @@ type profilesStatements struct {
} }
func (s *profilesStatements) prepare(db *sql.DB) (err error) { func (s *profilesStatements) prepare(db *sql.DB) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(profilesSchema) _, err = db.Exec(profilesSchema)
if err != nil { if err != nil {
return return
@ -75,8 +80,10 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
func (s *profilesStatements) insertProfile( func (s *profilesStatements) insertProfile(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "") return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
return _, err := txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return err
})
} }
func (s *profilesStatements) selectProfileByLocalpart( func (s *profilesStatements) selectProfileByLocalpart(

View file

@ -53,6 +53,8 @@ const deleteThreePIDSQL = "" +
"DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2" "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
type threepidStatements struct { type threepidStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
selectLocalpartForThreePIDStmt *sql.Stmt selectLocalpartForThreePIDStmt *sql.Stmt
selectThreePIDsForLocalpartStmt *sql.Stmt selectThreePIDsForLocalpartStmt *sql.Stmt
insertThreePIDStmt *sql.Stmt insertThreePIDStmt *sql.Stmt
@ -60,6 +62,8 @@ type threepidStatements struct {
} }
func (s *threepidStatements) prepare(db *sql.DB) (err error) { func (s *threepidStatements) prepare(db *sql.DB) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(threepidSchema) _, err = db.Exec(threepidSchema)
if err != nil { if err != nil {
return return
@ -118,13 +122,18 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
func (s *threepidStatements) insertThreePID( func (s *threepidStatements) insertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
_, err = stmt.ExecContext(ctx, threepid, medium, localpart) stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
return _, err := stmt.ExecContext(ctx, threepid, medium, localpart)
return err
})
} }
func (s *threepidStatements) deleteThreePID( func (s *threepidStatements) deleteThreePID(
ctx context.Context, threepid string, medium string) (err error) { ctx context.Context, threepid string, medium string) (err error) {
_, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium) return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
return stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt)
_, err := stmt.ExecContext(ctx, threepid, medium)
return err
})
} }

View file

@ -74,6 +74,7 @@ const deleteDevicesSQL = "" +
type devicesStatements struct { type devicesStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter
insertDeviceStmt *sql.Stmt insertDeviceStmt *sql.Stmt
selectDevicesCountStmt *sql.Stmt selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt
@ -87,6 +88,7 @@ type devicesStatements struct {
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.db = db s.db = db
s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(devicesSchema) _, err = db.Exec(devicesSchema)
if err != nil { if err != nil {
return return
@ -128,13 +130,19 @@ func (s *devicesStatements) insertDevice(
) (*api.Device, error) { ) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64 var sessionID int64
countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt) err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil { insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
return nil, err if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
} return err
sessionID++ }
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil { sessionID++
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
return err
}
return nil
})
if err != nil {
return nil, err return nil, err
} }
return &api.Device{ return &api.Device{
@ -148,9 +156,11 @@ func (s *devicesStatements) insertDevice(
func (s *devicesStatements) deleteDevice( func (s *devicesStatements) deleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string, ctx context.Context, txn *sql.Tx, id, localpart string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
_, err := stmt.ExecContext(ctx, id, localpart) stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
return err _, err := stmt.ExecContext(ctx, id, localpart)
return err
})
} }
func (s *devicesStatements) deleteDevices( func (s *devicesStatements) deleteDevices(
@ -161,31 +171,37 @@ func (s *devicesStatements) deleteDevices(
if err != nil { if err != nil {
return err return err
} }
stmt := sqlutil.TxStmt(txn, prep) return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
params := make([]interface{}, len(devices)+1) stmt := sqlutil.TxStmt(txn, prep)
params[0] = localpart params := make([]interface{}, len(devices)+1)
for i, v := range devices { params[0] = localpart
params[i+1] = v for i, v := range devices {
} params[i+1] = v
params = append(params, params...) }
_, err = stmt.ExecContext(ctx, params...) params = append(params, params...)
return err _, err = stmt.ExecContext(ctx, params...)
return err
})
} }
func (s *devicesStatements) deleteDevicesByLocalpart( func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
_, err := stmt.ExecContext(ctx, localpart) stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
return err _, err := stmt.ExecContext(ctx, localpart)
return err
})
} }
func (s *devicesStatements) updateDeviceName( func (s *devicesStatements) updateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
return err _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
return err
})
} }
func (s *devicesStatements) selectDeviceByToken( func (s *devicesStatements) selectDeviceByToken(