From 3d58417555ef6a1308fec8b83e752016765b72f2 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 19 Aug 2020 10:57:29 +0100 Subject: [PATCH 1/6] Roomserver database-wide TransactionWriters (#1282) * Database-wide TransactionWriter * Fix deadlocking Sync API tests * Undo non-roomserver changes for now --- .../storage/sqlite3/event_json_table.go | 4 +-- .../storage/sqlite3/event_state_keys_table.go | 4 +-- .../storage/sqlite3/event_types_table.go | 4 +-- roomserver/storage/sqlite3/events_table.go | 4 +-- roomserver/storage/sqlite3/invite_table.go | 6 ++-- .../storage/sqlite3/membership_table.go | 4 +-- .../storage/sqlite3/previous_events_table.go | 4 +-- roomserver/storage/sqlite3/published_table.go | 4 +-- .../storage/sqlite3/redactions_table.go | 4 +-- .../storage/sqlite3/room_aliases_table.go | 4 +-- roomserver/storage/sqlite3/rooms_table.go | 4 +-- .../storage/sqlite3/state_block_table.go | 4 +-- .../storage/sqlite3/state_snapshot_table.go | 4 +-- roomserver/storage/sqlite3/storage.go | 29 ++++++++++--------- .../storage/sqlite3/transactions_table.go | 4 +-- 15 files changed, 44 insertions(+), 43 deletions(-) diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index 64795d024..e8118ad76 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -54,10 +54,10 @@ type eventJSONStatements struct { bulkSelectEventJSONStmt *sql.Stmt } -func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) { +func NewSqliteEventJSONTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.EventJSON, error) { s := &eventJSONStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), + writer: writer, } _, err := db.Exec(eventJSONSchema) if err != nil { diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index 3e9f2e613..c8ad052bf 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -71,10 +71,10 @@ type eventStateKeyStatements struct { bulkSelectEventStateKeyStmt *sql.Stmt } -func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { +func NewSqliteEventStateKeysTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.EventStateKeys, error) { s := &eventStateKeyStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), + writer: writer, } _, err := db.Exec(eventStateKeysSchema) if err != nil { diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index fd4a2e42f..4a645789d 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -85,10 +85,10 @@ type eventTypeStatements struct { bulkSelectEventTypeNIDStmt *sql.Stmt } -func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) { +func NewSqliteEventTypesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.EventTypes, error) { s := &eventTypeStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), + writer: writer, } _, err := db.Exec(eventTypesSchema) if err != nil { diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index b3cfee07e..3ac30ca3d 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -115,10 +115,10 @@ type eventStatements struct { selectRoomNIDForEventNIDStmt *sql.Stmt } -func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) { +func NewSqliteEventsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Events, error) { s := &eventStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), + writer: writer, } _, err := db.Exec(eventsSchema) if err != nil { diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go index e806eab6d..1305f4a8a 100644 --- a/roomserver/storage/sqlite3/invite_table.go +++ b/roomserver/storage/sqlite3/invite_table.go @@ -71,10 +71,10 @@ type inviteStatements struct { selectInvitesAboutToRetireStmt *sql.Stmt } -func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) { +func NewSqliteInvitesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Invites, error) { s := &inviteStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), + writer: writer, } _, err := db.Exec(inviteSchema) if err != nil { @@ -124,7 +124,7 @@ func (s *inviteStatements) UpdateInviteRetired( if err != nil { return err } - defer (func() { err = rows.Close() })() + defer internal.CloseAndLogIfError(ctx, rows, "UpdateInviteRetired: rows.close() failed") for rows.Next() { var inviteEventID string if err = rows.Scan(&inviteEventID); err != nil { diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 6dd8bd83f..7b69cee32 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -88,10 +88,10 @@ type membershipStatements struct { updateMembershipStmt *sql.Stmt } -func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { +func NewSqliteMembershipTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Membership, error) { s := &membershipStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), + writer: writer, } _, err := db.Exec(membershipSchema) if err != nil { diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go index 28b5d18f0..ff804861c 100644 --- a/roomserver/storage/sqlite3/previous_events_table.go +++ b/roomserver/storage/sqlite3/previous_events_table.go @@ -59,10 +59,10 @@ type previousEventStatements struct { selectPreviousEventExistsStmt *sql.Stmt } -func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { +func NewSqlitePrevEventsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.PreviousEvents, error) { s := &previousEventStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), + writer: writer, } _, err := db.Exec(previousEventSchema) if err != nil { diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go index 85f1e0a49..a4a47aec9 100644 --- a/roomserver/storage/sqlite3/published_table.go +++ b/roomserver/storage/sqlite3/published_table.go @@ -51,10 +51,10 @@ type publishedStatements struct { selectPublishedStmt *sql.Stmt } -func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) { +func NewSqlitePublishedTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Published, error) { s := &publishedStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), + writer: writer, } _, err := db.Exec(publishedSchema) if err != nil { diff --git a/roomserver/storage/sqlite3/redactions_table.go b/roomserver/storage/sqlite3/redactions_table.go index d2bd2a204..ad900a4ec 100644 --- a/roomserver/storage/sqlite3/redactions_table.go +++ b/roomserver/storage/sqlite3/redactions_table.go @@ -60,10 +60,10 @@ type redactionStatements struct { markRedactionValidatedStmt *sql.Stmt } -func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) { +func NewSqliteRedactionsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Redactions, error) { s := &redactionStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), + writer: writer, } _, err := db.Exec(redactionsSchema) if err != nil { diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index 4a5357776..deba3ff55 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -65,10 +65,10 @@ type roomAliasesStatements struct { deleteRoomAliasStmt *sql.Stmt } -func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { +func NewSqliteRoomAliasesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.RoomAliases, error) { s := &roomAliasesStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), + writer: writer, } _, err := db.Exec(roomAliasesSchema) if err != nil { diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index bb30a63b3..8bbec5080 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -76,10 +76,10 @@ type roomStatements struct { selectRoomVersionForRoomNIDStmt *sql.Stmt } -func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { +func NewSqliteRoomsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Rooms, error) { s := &roomStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), + writer: writer, } _, err := db.Exec(roomsSchema) if err != nil { diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 3d716b642..3e28e450b 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -81,10 +81,10 @@ type stateBlockStatements struct { bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt } -func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) { +func NewSqliteStateBlockTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.StateBlock, error) { s := &stateBlockStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), + writer: writer, } _, err := db.Exec(stateDataSchema) if err != nil { diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 48f1210be..799904ff6 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -55,10 +55,10 @@ type stateSnapshotStatements struct { bulkSelectStateBlockNIDsStmt *sql.Stmt } -func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { +func NewSqliteStateSnapshotTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.StateSnapshot, error) { s := &stateSnapshotStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), + writer: writer, } _, err := db.Exec(stateSnapshotSchema) if err != nil { diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 048de1928..ae3140d7d 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -51,6 +51,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) { if d.db, err = sqlutil.Open(dbProperties); err != nil { return nil, err } + writer := sqlutil.NewTransactionWriter() //d.db.Exec("PRAGMA journal_mode=WAL;") //d.db.Exec("PRAGMA read_uncommitted = true;") @@ -60,59 +61,59 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) { // which it will never obtain. d.db.SetMaxOpenConns(20) - d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db) + d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db, writer) if err != nil { return nil, err } - d.eventTypes, err = NewSqliteEventTypesTable(d.db) + d.eventTypes, err = NewSqliteEventTypesTable(d.db, writer) if err != nil { return nil, err } - d.eventJSON, err = NewSqliteEventJSONTable(d.db) + d.eventJSON, err = NewSqliteEventJSONTable(d.db, writer) if err != nil { return nil, err } - d.events, err = NewSqliteEventsTable(d.db) + d.events, err = NewSqliteEventsTable(d.db, writer) if err != nil { return nil, err } - d.rooms, err = NewSqliteRoomsTable(d.db) + d.rooms, err = NewSqliteRoomsTable(d.db, writer) if err != nil { return nil, err } - d.transactions, err = NewSqliteTransactionsTable(d.db) + d.transactions, err = NewSqliteTransactionsTable(d.db, writer) if err != nil { return nil, err } - stateBlock, err := NewSqliteStateBlockTable(d.db) + stateBlock, err := NewSqliteStateBlockTable(d.db, writer) if err != nil { return nil, err } - stateSnapshot, err := NewSqliteStateSnapshotTable(d.db) + stateSnapshot, err := NewSqliteStateSnapshotTable(d.db, writer) if err != nil { return nil, err } - d.prevEvents, err = NewSqlitePrevEventsTable(d.db) + d.prevEvents, err = NewSqlitePrevEventsTable(d.db, writer) if err != nil { return nil, err } - roomAliases, err := NewSqliteRoomAliasesTable(d.db) + roomAliases, err := NewSqliteRoomAliasesTable(d.db, writer) if err != nil { return nil, err } - d.invites, err = NewSqliteInvitesTable(d.db) + d.invites, err = NewSqliteInvitesTable(d.db, writer) if err != nil { return nil, err } - d.membership, err = NewSqliteMembershipTable(d.db) + d.membership, err = NewSqliteMembershipTable(d.db, writer) if err != nil { return nil, err } - published, err := NewSqlitePublishedTable(d.db) + published, err := NewSqlitePublishedTable(d.db, writer) if err != nil { return nil, err } - redactions, err := NewSqliteRedactionsTable(d.db) + redactions, err := NewSqliteRedactionsTable(d.db, writer) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/transactions_table.go b/roomserver/storage/sqlite3/transactions_table.go index 2f6cff95a..65c18a8a9 100644 --- a/roomserver/storage/sqlite3/transactions_table.go +++ b/roomserver/storage/sqlite3/transactions_table.go @@ -50,10 +50,10 @@ type transactionStatements struct { selectTransactionEventIDStmt *sql.Stmt } -func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) { +func NewSqliteTransactionsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Transactions, error) { s := &transactionStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), + writer: writer, } _, err := db.Exec(transactionsSchema) if err != nil { From 775b04d776ddc06fdee5ece6a407008f00edb7f2 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 19 Aug 2020 13:24:54 +0100 Subject: [PATCH 2/6] Roomserver updater changes (#1283) * Take input transaction when setting up updaters * Fix nil pointer exceptions * Rename room recent events updater to latest events updater * Contd rename room recent events updater to latest events updater * Remove unnecessary interfaces for latest events and membership updaters --- roomserver/internal/input_latest_events.go | 3 +- roomserver/internal/input_membership.go | 11 ++-- roomserver/storage/interface.go | 5 +- ...ts_updater.go => latest_events_updater.go} | 34 ++++------ .../storage/shared/membership_updater.go | 46 ++++---------- roomserver/storage/shared/storage.go | 16 +++-- roomserver/storage/sqlite3/events_table.go | 14 ++++- roomserver/storage/sqlite3/storage.go | 8 +-- roomserver/types/types.go | 63 ------------------- 9 files changed, 63 insertions(+), 137 deletions(-) rename roomserver/storage/shared/{room_recent_events_updater.go => latest_events_updater.go} (68%) diff --git a/roomserver/internal/input_latest_events.go b/roomserver/internal/input_latest_events.go index 66316ac4f..0158c8f7f 100644 --- a/roomserver/internal/input_latest_events.go +++ b/roomserver/internal/input_latest_events.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -92,7 +93,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents( type latestEventsUpdater struct { ctx context.Context api *RoomserverInternalAPI - updater types.RoomRecentEventsUpdater + updater *shared.LatestEventsUpdater roomNID types.RoomNID stateAtEvent types.StateAtEvent event gomatrixserverlib.Event diff --git a/roomserver/internal/input_membership.go b/roomserver/internal/input_membership.go index af0c7f8b3..bcecfca0e 100644 --- a/roomserver/internal/input_membership.go +++ b/roomserver/internal/input_membership.go @@ -19,6 +19,7 @@ import ( "fmt" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -29,7 +30,7 @@ import ( // consumers about the invites added or retired by the change in current state. func (r *RoomserverInternalAPI) updateMemberships( ctx context.Context, - updater types.RoomRecentEventsUpdater, + updater *shared.LatestEventsUpdater, removed, added []types.StateEntry, ) ([]api.OutputEvent, error) { changes := membershipChanges(removed, added) @@ -77,7 +78,7 @@ func (r *RoomserverInternalAPI) updateMemberships( } func (r *RoomserverInternalAPI) updateMembership( - updater types.RoomRecentEventsUpdater, + updater *shared.LatestEventsUpdater, targetUserNID types.EventStateKeyNID, remove, add *gomatrixserverlib.Event, updates []api.OutputEvent, @@ -141,7 +142,7 @@ func (r *RoomserverInternalAPI) isLocalTarget(event *gomatrixserverlib.Event) bo } func updateToInviteMembership( - mu types.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, + mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, roomVersion gomatrixserverlib.RoomVersion, ) ([]api.OutputEvent, error) { // We may have already sent the invite to the user, either because we are @@ -171,7 +172,7 @@ func updateToInviteMembership( } func updateToJoinMembership( - mu types.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, + mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, ) ([]api.OutputEvent, error) { // If the user is already marked as being joined, we call SetToJoin to update // the event ID then we can return immediately. Retired is ignored as there @@ -207,7 +208,7 @@ func updateToJoinMembership( } func updateToLeaveMembership( - mu types.MembershipUpdater, add *gomatrixserverlib.Event, + mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, newMembership string, updates []api.OutputEvent, ) ([]api.OutputEvent, error) { // If the user is already neither joined, nor invited to the room then we diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index afe5bcb1f..988fc908d 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -18,6 +18,7 @@ import ( "context" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -86,7 +87,7 @@ type Database interface { // The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error. // Returns the latest events in the room and the last eventID sent to the log along with an updater. // If this returns an error then no further action is required. - GetLatestEventsForUpdate(ctx context.Context, roomNID types.RoomNID) (types.RoomRecentEventsUpdater, error) + GetLatestEventsForUpdate(ctx context.Context, roomNID types.RoomNID) (*shared.LatestEventsUpdater, error) // Look up event ID by transaction's info. // This is used to determine if the room event is processed/processing already. // Returns an empty string if no such event exists. @@ -123,7 +124,7 @@ type Database interface { // Returns an error if there was a problem talking to the database. RemoveRoomAlias(ctx context.Context, alias string) error // Build a membership updater for the target user in a room. - MembershipUpdater(ctx context.Context, roomID, targetUserID string, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion) (types.MembershipUpdater, error) + MembershipUpdater(ctx context.Context, roomID, targetUserID string, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion) (*shared.MembershipUpdater, error) // Lookup the membership of a given user in a given room. // Returns the numeric ID of the latest membership event sent from this user // in this room, along a boolean set to true if the user is still in this room, diff --git a/roomserver/storage/shared/room_recent_events_updater.go b/roomserver/storage/shared/latest_events_updater.go similarity index 68% rename from roomserver/storage/shared/room_recent_events_updater.go rename to roomserver/storage/shared/latest_events_updater.go index 8131f712d..21b168a4f 100644 --- a/roomserver/storage/shared/room_recent_events_updater.go +++ b/roomserver/storage/shared/latest_events_updater.go @@ -8,7 +8,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) -type roomRecentEventsUpdater struct { +type LatestEventsUpdater struct { transaction d *Database roomNID types.RoomNID @@ -17,11 +17,7 @@ type roomRecentEventsUpdater struct { currentStateSnapshotNID types.StateSnapshotNID } -func NewRoomRecentEventsUpdater(d *Database, ctx context.Context, roomNID types.RoomNID, useTxns bool) (types.RoomRecentEventsUpdater, error) { - txn, err := d.DB.Begin() - if err != nil { - return nil, err - } +func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomNID types.RoomNID) (*LatestEventsUpdater, error) { eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomNID) if err != nil { @@ -41,38 +37,34 @@ func NewRoomRecentEventsUpdater(d *Database, ctx context.Context, roomNID types. return nil, err } } - if !useTxns { - txn.Commit() // nolint: errcheck - txn = nil - } - return &roomRecentEventsUpdater{ + return &LatestEventsUpdater{ transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, }, nil } // RoomVersion implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) { +func (u *LatestEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) { version, _ = u.d.GetRoomVersionForRoomNID(u.ctx, u.roomNID) return } // LatestEvents implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) LatestEvents() []types.StateAtEventAndReference { +func (u *LatestEventsUpdater) LatestEvents() []types.StateAtEventAndReference { return u.latestEvents } // LastEventIDSent implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) LastEventIDSent() string { +func (u *LatestEventsUpdater) LastEventIDSent() string { return u.lastEventIDSent } // CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { +func (u *LatestEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { return u.currentStateSnapshotNID } // StorePreviousEvents implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { +func (u *LatestEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { for _, ref := range previousEventReferences { if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { return err @@ -82,7 +74,7 @@ func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, p } // IsReferenced implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { +func (u *LatestEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) if err == nil { return true, nil @@ -94,7 +86,7 @@ func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib. } // SetLatestEvents implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) SetLatestEvents( +func (u *LatestEventsUpdater) SetLatestEvents( roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, currentStateSnapshotNID types.StateSnapshotNID, ) error { @@ -106,15 +98,15 @@ func (u *roomRecentEventsUpdater) SetLatestEvents( } // HasEventBeenSent implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { +func (u *LatestEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID) } // MarkEventAsSent implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { +func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, u.txn, eventNID) } -func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (types.MembershipUpdater, error) { +func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal) } diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index 5ddf6d84d..5955844f9 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -9,7 +9,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) -type membershipUpdater struct { +type MembershipUpdater struct { transaction d *Database roomNID types.RoomNID @@ -18,21 +18,9 @@ type membershipUpdater struct { } func NewMembershipUpdater( - ctx context.Context, d *Database, roomID, targetUserID string, + ctx context.Context, d *Database, txn *sql.Tx, roomID, targetUserID string, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion, - useTxns bool, -) (types.MembershipUpdater, error) { - txn, err := d.DB.Begin() - if err != nil { - return nil, err - } - succeeded := false - defer func() { - if !succeeded { - txn.Rollback() // nolint: errcheck - } - }() - +) (*MembershipUpdater, error) { roomNID, err := d.assignRoomNID(ctx, txn, roomID, roomVersion) if err != nil { return nil, err @@ -43,17 +31,7 @@ func NewMembershipUpdater( return nil, err } - updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID, targetLocal) - if err != nil { - return nil, err - } - - succeeded = true - if !useTxns { - txn.Commit() // nolint: errcheck - updater.transaction.txn = nil - } - return updater, nil + return d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID, targetLocal) } func (d *Database) membershipUpdaterTxn( @@ -62,7 +40,7 @@ func (d *Database) membershipUpdaterTxn( roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, targetLocal bool, -) (*membershipUpdater, error) { +) (*MembershipUpdater, error) { if err := d.MembershipTable.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { return nil, err @@ -73,28 +51,28 @@ func (d *Database) membershipUpdaterTxn( return nil, err } - return &membershipUpdater{ + return &MembershipUpdater{ transaction{ctx, txn}, d, roomNID, targetUserNID, membership, }, nil } // IsInvite implements types.MembershipUpdater -func (u *membershipUpdater) IsInvite() bool { +func (u *MembershipUpdater) IsInvite() bool { return u.membership == tables.MembershipStateInvite } // IsJoin implements types.MembershipUpdater -func (u *membershipUpdater) IsJoin() bool { +func (u *MembershipUpdater) IsJoin() bool { return u.membership == tables.MembershipStateJoin } // IsLeave implements types.MembershipUpdater -func (u *membershipUpdater) IsLeave() bool { +func (u *MembershipUpdater) IsLeave() bool { return u.membership == tables.MembershipStateLeaveOrBan } // SetToInvite implements types.MembershipUpdater -func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { +func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) if err != nil { return false, err @@ -116,7 +94,7 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er } // SetToJoin implements types.MembershipUpdater -func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) { +func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) { var inviteEventIDs []string senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) @@ -153,7 +131,7 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd } // SetToLeave implements types.MembershipUpdater -func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) { +func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) { senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) if err != nil { return nil, err diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 5494e4654..00179e336 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -332,14 +332,22 @@ func (d *Database) GetTransactionEventID( func (d *Database) MembershipUpdater( ctx context.Context, roomID, targetUserID string, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion, -) (types.MembershipUpdater, error) { - return NewMembershipUpdater(ctx, d, roomID, targetUserID, targetLocal, roomVersion, true) +) (*MembershipUpdater, error) { + txn, err := d.DB.Begin() + if err != nil { + return nil, err + } + return NewMembershipUpdater(ctx, d, txn, roomID, targetUserID, targetLocal, roomVersion) } func (d *Database) GetLatestEventsForUpdate( ctx context.Context, roomNID types.RoomNID, -) (types.RoomRecentEventsUpdater, error) { - return NewRoomRecentEventsUpdater(d, ctx, roomNID, true) +) (*LatestEventsUpdater, error) { + txn, err := d.DB.Begin() + if err != nil { + return nil, err + } + return NewLatestEventsUpdater(ctx, d, txn, roomNID) } func (d *Database) StoreEvent( diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 3ac30ca3d..0e39755cb 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -326,9 +326,13 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( iEventNIDs[k] = v } selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } ////////////// - rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...) + rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...) if err != nil { return nil, err } @@ -372,7 +376,7 @@ func (s *eventStatements) BulkSelectEventReference( iEventNIDs[k] = v } selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) - selectPrep, err := txn.Prepare(selectOrig) + selectPrep, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } @@ -471,7 +475,11 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, iEventIDs[i] = v } sqlStr := strings.Replace(selectMaxEventDepthSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) - err := txn.QueryRowContext(ctx, sqlStr, iEventIDs...).Scan(&result) + sqlPrep, err := s.db.Prepare(sqlStr) + if err != nil { + return 0, err + } + err = sqlutil.TxStmt(txn, sqlPrep).QueryRowContext(ctx, iEventIDs...).Scan(&result) if err != nil { return 0, err } diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index ae3140d7d..724316373 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -139,25 +139,25 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) { func (d *Database) GetLatestEventsForUpdate( ctx context.Context, roomNID types.RoomNID, -) (types.RoomRecentEventsUpdater, error) { +) (*shared.LatestEventsUpdater, error) { // TODO: Do not use transactions. We should be holding open this transaction but we cannot have // multiple write transactions on sqlite. The code will perform additional // write transactions independent of this one which will consistently cause // 'database is locked' errors. As sqlite doesn't support multi-process on the // same DB anyway, and we only execute updates sequentially, the only worries // are for rolling back when things go wrong. (atomicity) - return shared.NewRoomRecentEventsUpdater(&d.Database, ctx, roomNID, false) + return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomNID) } func (d *Database) MembershipUpdater( ctx context.Context, roomID, targetUserID string, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion, -) (updater types.MembershipUpdater, err error) { +) (*shared.MembershipUpdater, error) { // TODO: Do not use transactions. We should be holding open this transaction but we cannot have // multiple write transactions on sqlite. The code will perform additional // write transactions independent of this one which will consistently cause // 'database is locked' errors. As sqlite doesn't support multi-process on the // same DB anyway, and we only execute updates sequentially, the only worries // are for rolling back when things go wrong. (atomicity) - return shared.NewMembershipUpdater(ctx, &d.Database, roomID, targetUserID, targetLocal, roomVersion, false) + return shared.NewMembershipUpdater(ctx, &d.Database, nil, roomID, targetUserID, targetLocal, roomVersion) } diff --git a/roomserver/types/types.go b/roomserver/types/types.go index 241e1e15d..cf4a86b66 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -16,7 +16,6 @@ package types import ( - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" ) @@ -140,68 +139,6 @@ type StateEntryList struct { StateEntries []StateEntry } -// A RoomRecentEventsUpdater is used to update the recent events in a room. -// (On postgresql this wraps a database transaction that holds a "FOR UPDATE" -// lock on the row in the rooms table holding the latest events for the room.) -type RoomRecentEventsUpdater interface { - // The room version of the room. - RoomVersion() gomatrixserverlib.RoomVersion - // The latest event IDs and state in the room. - LatestEvents() []StateAtEventAndReference - // The event ID of the latest event written to the output log in the room. - LastEventIDSent() string - // The current state of the room. - CurrentStateSnapshotNID() StateSnapshotNID - // Store the previous events referenced by an event. - // This adds the event NID to an entry in the database for each of the previous events. - // If there isn't an entry for one of previous events then an entry is created. - // If the entry already lists the event NID as a referrer then the entry unmodified. - // (i.e. the operation is idempotent) - StorePreviousEvents(eventNID EventNID, previousEventReferences []gomatrixserverlib.EventReference) error - // Check whether the eventReference is already referenced by another matrix event. - IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) - // Set the list of latest events for the room. - // This replaces the current list stored in the database with the given list - SetLatestEvents( - roomNID RoomNID, latest []StateAtEventAndReference, lastEventNIDSent EventNID, - currentStateSnapshotNID StateSnapshotNID, - ) error - // Check if the event has already be written to the output logs. - HasEventBeenSent(eventNID EventNID) (bool, error) - // Mark the event as having been sent to the output logs. - MarkEventAsSent(eventNID EventNID) error - // Build a membership updater for the target user in this room. - // It will share the same transaction as this updater. - MembershipUpdater(targetUserNID EventStateKeyNID, isTargetLocalUser bool) (MembershipUpdater, error) - // Implements Transaction so it can be committed or rolledback - sqlutil.Transaction -} - -// A MembershipUpdater is used to update the membership of a user in a room. -// (On postgresql this wraps a database transaction that holds a "FOR UPDATE" -// lock on the row in the membership table for this user in the room) -// The caller should call one of SetToInvite, SetToJoin or SetToLeave once to -// make the update, or none of them if no update is required. -type MembershipUpdater interface { - // True if the target user is invited to the room before updating. - IsInvite() bool - // True if the target user is joined to the room before updating. - IsJoin() bool - // True if the target user is not invited or joined to the room before updating. - IsLeave() bool - // Set the state to invite. - // Returns whether this invite needs to be sent - SetToInvite(event gomatrixserverlib.Event) (needsSending bool, err error) - // Set the state to join or updates the event ID in the database. - // Returns a list of invite event IDs that this state change retired. - SetToJoin(senderUserID string, eventID string, isUpdate bool) (inviteEventIDs []string, err error) - // Set the state to leave. - // Returns a list of invite event IDs that this state change retired. - SetToLeave(senderUserID string, eventID string) (inviteEventIDs []string, err error) - // Implements Transaction so it can be committed or rolledback. - sqlutil.Transaction -} - // A MissingEventError is an error that happened because the roomserver was // missing requested events from its database. type MissingEventError string From b24747b305a0770fdd746655e702aa1c1c049765 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 19 Aug 2020 15:38:27 +0100 Subject: [PATCH 3/6] Transaction writer changes, move roomserver writers (#1285) * Updated TransactionWriters, moved locks in roomserver, various other tweaks * Fix redaction deadlocks * Fix lint issue * Rename SQLiteTransactionWriter to ExclusiveTransactionWriter * Fix us not sending transactions through in latest events updater --- .../sqlite3/appservice_events_table.go | 2 +- .../storage/sqlite3/txn_id_counter_table.go | 2 +- .../sqlite3/current_room_state_table.go | 2 +- .../storage/postgres/blacklist_table.go | 20 ++--- .../storage/sqlite3/blacklist_table.go | 2 +- .../storage/sqlite3/joined_hosts_table.go | 2 +- .../storage/sqlite3/queue_edus_table.go | 2 +- .../storage/sqlite3/queue_json_table.go | 2 +- .../storage/sqlite3/queue_pdus_table.go | 2 +- .../storage/sqlite3/room_table.go | 2 +- internal/sqlutil/sql.go | 71 +----------------- internal/sqlutil/writer_dummy.go | 22 ++++++ internal/sqlutil/writer_exclusive.go | 75 +++++++++++++++++++ .../storage/sqlite3/device_keys_table.go | 2 +- .../storage/sqlite3/key_changes_table.go | 2 +- .../storage/sqlite3/one_time_keys_table.go | 2 +- .../storage/sqlite3/media_repository_table.go | 2 +- roomserver/internal/input_latest_events.go | 36 +++++---- roomserver/state/state.go | 22 ++++-- roomserver/storage/postgres/storage.go | 1 + .../storage/shared/latest_events_updater.go | 26 +++++-- .../storage/shared/membership_updater.go | 34 +++++---- roomserver/storage/shared/storage.go | 43 +++++++---- .../storage/sqlite3/event_json_table.go | 12 +-- .../storage/sqlite3/event_state_keys_table.go | 28 +++---- .../storage/sqlite3/event_types_table.go | 27 ++++--- roomserver/storage/sqlite3/events_table.go | 54 ++++++------- roomserver/storage/sqlite3/invite_table.go | 66 +++++++--------- .../storage/sqlite3/membership_table.go | 26 +++---- .../storage/sqlite3/previous_events_table.go | 18 ++--- roomserver/storage/sqlite3/published_table.go | 16 ++-- .../storage/sqlite3/redactions_table.go | 22 ++---- .../storage/sqlite3/room_aliases_table.go | 25 ++----- roomserver/storage/sqlite3/rooms_table.go | 46 +++++------- .../storage/sqlite3/state_block_table.go | 37 ++++----- .../storage/sqlite3/state_snapshot_table.go | 29 +++---- roomserver/storage/sqlite3/storage.go | 32 ++++---- .../storage/sqlite3/transactions_table.go | 20 ++--- .../storage/sqlite3/server_key_table.go | 2 +- syncapi/storage/shared/syncserver.go | 2 +- syncapi/storage/sqlite3/account_data_table.go | 2 +- .../sqlite3/backwards_extremities_table.go | 2 +- .../sqlite3/current_room_state_table.go | 2 +- syncapi/storage/sqlite3/filter_table.go | 2 +- syncapi/storage/sqlite3/invites_table.go | 2 +- .../sqlite3/output_room_events_table.go | 2 +- .../output_room_events_topology_table.go | 2 +- .../storage/sqlite3/send_to_device_table.go | 2 +- syncapi/storage/sqlite3/stream_id_table.go | 2 +- .../accounts/sqlite3/account_data_table.go | 2 +- .../accounts/sqlite3/accounts_table.go | 2 +- .../storage/accounts/sqlite3/profile_table.go | 2 +- .../accounts/sqlite3/threepid_table.go | 2 +- .../storage/devices/sqlite3/devices_table.go | 2 +- 54 files changed, 432 insertions(+), 434 deletions(-) create mode 100644 internal/sqlutil/writer_dummy.go create mode 100644 internal/sqlutil/writer_exclusive.go diff --git a/appservice/storage/sqlite3/appservice_events_table.go b/appservice/storage/sqlite3/appservice_events_table.go index da31f2359..5cc07ed34 100644 --- a/appservice/storage/sqlite3/appservice_events_table.go +++ b/appservice/storage/sqlite3/appservice_events_table.go @@ -67,7 +67,7 @@ const ( type eventsStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter selectEventsByApplicationServiceIDStmt *sql.Stmt countEventsByApplicationServiceIDStmt *sql.Stmt insertEventStmt *sql.Stmt diff --git a/appservice/storage/sqlite3/txn_id_counter_table.go b/appservice/storage/sqlite3/txn_id_counter_table.go index 501ab5aa7..0ae0feeea 100644 --- a/appservice/storage/sqlite3/txn_id_counter_table.go +++ b/appservice/storage/sqlite3/txn_id_counter_table.go @@ -38,7 +38,7 @@ const selectTxnIDSQL = ` type txnStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter selectTxnIDStmt *sql.Stmt } diff --git a/currentstateserver/storage/sqlite3/current_room_state_table.go b/currentstateserver/storage/sqlite3/current_room_state_table.go index 5c7e8b0a7..9d2fe6e04 100644 --- a/currentstateserver/storage/sqlite3/current_room_state_table.go +++ b/currentstateserver/storage/sqlite3/current_room_state_table.go @@ -83,7 +83,7 @@ const selectKnownUsersSQL = "" + type currentRoomStateStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt diff --git a/federationsender/storage/postgres/blacklist_table.go b/federationsender/storage/postgres/blacklist_table.go index 8de6feec3..f92c59e54 100644 --- a/federationsender/storage/postgres/blacklist_table.go +++ b/federationsender/storage/postgres/blacklist_table.go @@ -42,7 +42,6 @@ const deleteBlacklistSQL = "" + type blacklistStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertBlacklistStmt *sql.Stmt selectBlacklistStmt *sql.Stmt deleteBlacklistStmt *sql.Stmt @@ -50,8 +49,7 @@ type blacklistStatements struct { func NewPostgresBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) { s = &blacklistStatements{ - db: db, - writer: sqlutil.NewTransactionWriter(), + db: db, } _, err = db.Exec(blacklistSchema) if err != nil { @@ -75,11 +73,9 @@ func NewPostgresBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) { func (s *blacklistStatements) InsertBlacklist( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt) - _, err := stmt.ExecContext(ctx, serverName) - return err - }) + stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err } // selectRoomForUpdate locks the row for the room and returns the last_event_id. @@ -105,9 +101,7 @@ func (s *blacklistStatements) SelectBlacklist( func (s *blacklistStatements) DeleteBlacklist( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt) - _, err := stmt.ExecContext(ctx, serverName) - return err - }) + stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err } diff --git a/federationsender/storage/sqlite3/blacklist_table.go b/federationsender/storage/sqlite3/blacklist_table.go index a14fe0c40..b23bfcba4 100644 --- a/federationsender/storage/sqlite3/blacklist_table.go +++ b/federationsender/storage/sqlite3/blacklist_table.go @@ -42,7 +42,7 @@ const deleteBlacklistSQL = "" + type blacklistStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertBlacklistStmt *sql.Stmt selectBlacklistStmt *sql.Stmt deleteBlacklistStmt *sql.Stmt diff --git a/federationsender/storage/sqlite3/joined_hosts_table.go b/federationsender/storage/sqlite3/joined_hosts_table.go index 53736fa16..5dc18f4ec 100644 --- a/federationsender/storage/sqlite3/joined_hosts_table.go +++ b/federationsender/storage/sqlite3/joined_hosts_table.go @@ -65,7 +65,7 @@ const selectJoinedHostsForRoomsSQL = "" + type joinedHostsStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertJoinedHostsStmt *sql.Stmt deleteJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt diff --git a/federationsender/storage/sqlite3/queue_edus_table.go b/federationsender/storage/sqlite3/queue_edus_table.go index cd11a0ea8..2abcc105d 100644 --- a/federationsender/storage/sqlite3/queue_edus_table.go +++ b/federationsender/storage/sqlite3/queue_edus_table.go @@ -64,7 +64,7 @@ const selectQueueServerNamesSQL = "" + type queueEDUsStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertQueueEDUStmt *sql.Stmt selectQueueEDUStmt *sql.Stmt selectQueueEDUReferenceJSONCountStmt *sql.Stmt diff --git a/federationsender/storage/sqlite3/queue_json_table.go b/federationsender/storage/sqlite3/queue_json_table.go index 46dfd9ab1..867ffd44b 100644 --- a/federationsender/storage/sqlite3/queue_json_table.go +++ b/federationsender/storage/sqlite3/queue_json_table.go @@ -50,7 +50,7 @@ const selectJSONSQL = "" + type queueJSONStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertJSONStmt *sql.Stmt //deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic //selectJSONStmt *sql.Stmt - prepared at runtime due to variadic diff --git a/federationsender/storage/sqlite3/queue_pdus_table.go b/federationsender/storage/sqlite3/queue_pdus_table.go index 1474bfc02..538ba3db8 100644 --- a/federationsender/storage/sqlite3/queue_pdus_table.go +++ b/federationsender/storage/sqlite3/queue_pdus_table.go @@ -71,7 +71,7 @@ const selectQueuePDUsServerNamesSQL = "" + type queuePDUsStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertQueuePDUStmt *sql.Stmt selectQueueNextTransactionIDStmt *sql.Stmt selectQueuePDUsByTransactionStmt *sql.Stmt diff --git a/federationsender/storage/sqlite3/room_table.go b/federationsender/storage/sqlite3/room_table.go index 517938745..9a439fada 100644 --- a/federationsender/storage/sqlite3/room_table.go +++ b/federationsender/storage/sqlite3/room_table.go @@ -44,7 +44,7 @@ const updateRoomSQL = "" + type roomStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertRoomStmt *sql.Stmt selectRoomForUpdateStmt *sql.Stmt updateRoomStmt *sql.Stmt diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index 95467c636..002d77183 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -19,8 +19,6 @@ import ( "errors" "fmt" "runtime" - - "go.uber.org/atomic" ) // ErrUserExists is returned if a username already exists in the database. @@ -52,7 +50,7 @@ func EndTransaction(txn Transaction, succeeded *bool) error { func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) { txn, err := db.Begin() if err != nil { - return + return fmt.Errorf("sqlutil.WithTransaction.Begin: %w", err) } succeeded := false defer func() { @@ -106,69 +104,6 @@ func SQLiteDriverName() string { return "sqlite3" } -// TransactionWriter allows queuing database writes so that you don't -// contend on database locks in, e.g. SQLite. Only one task will run -// at a time on a given TransactionWriter. -type TransactionWriter struct { - running atomic.Bool - todo chan transactionWriterTask -} - -func NewTransactionWriter() *TransactionWriter { - return &TransactionWriter{ - todo: make(chan transactionWriterTask), - } -} - -// transactionWriterTask represents a specific task. -type transactionWriterTask struct { - db *sql.DB - txn *sql.Tx - f func(txn *sql.Tx) error - wait chan error -} - -// Do queues a task to be run by a TransactionWriter. The function -// provided will be ran within a transaction as supplied by the -// txn parameter if one is supplied, and if not, will take out a -// new transaction from the database supplied in the database -// parameter. Either way, this will block until the task is done. -func (w *TransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error { - if w.todo == nil { - return errors.New("not initialised") - } - if !w.running.Load() { - go w.run() - } - task := transactionWriterTask{ - db: db, - txn: txn, - f: f, - wait: make(chan error, 1), - } - w.todo <- task - return <-task.wait -} - -// run processes the tasks for a given transaction writer. Only one -// of these goroutines will run at a time. A transaction will be -// opened using the database object from the task and then this will -// be passed as a parameter to the task function. -func (w *TransactionWriter) run() { - if !w.running.CAS(false, true) { - return - } - defer w.running.Store(false) - for task := range w.todo { - if task.txn != nil { - task.wait <- task.f(task.txn) - } else if task.db != nil { - task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error { - return task.f(txn) - }) - } else { - panic("expected database or transaction but got neither") - } - close(task.wait) - } +type TransactionWriter interface { + Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error } diff --git a/internal/sqlutil/writer_dummy.go b/internal/sqlutil/writer_dummy.go new file mode 100644 index 000000000..e6ab81f68 --- /dev/null +++ b/internal/sqlutil/writer_dummy.go @@ -0,0 +1,22 @@ +package sqlutil + +import ( + "database/sql" +) + +type DummyTransactionWriter struct { +} + +func NewDummyTransactionWriter() TransactionWriter { + return &DummyTransactionWriter{} +} + +func (w *DummyTransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error { + if txn == nil { + return WithTransaction(db, func(txn *sql.Tx) error { + return f(txn) + }) + } else { + return f(txn) + } +} diff --git a/internal/sqlutil/writer_exclusive.go b/internal/sqlutil/writer_exclusive.go new file mode 100644 index 000000000..2e3666aec --- /dev/null +++ b/internal/sqlutil/writer_exclusive.go @@ -0,0 +1,75 @@ +package sqlutil + +import ( + "database/sql" + "errors" + + "go.uber.org/atomic" +) + +// ExclusiveTransactionWriter allows queuing database writes so that you don't +// contend on database locks in, e.g. SQLite. Only one task will run +// at a time on a given ExclusiveTransactionWriter. +type ExclusiveTransactionWriter struct { + running atomic.Bool + todo chan transactionWriterTask +} + +func NewTransactionWriter() TransactionWriter { + return &ExclusiveTransactionWriter{ + todo: make(chan transactionWriterTask), + } +} + +// transactionWriterTask represents a specific task. +type transactionWriterTask struct { + db *sql.DB + txn *sql.Tx + f func(txn *sql.Tx) error + wait chan error +} + +// Do queues a task to be run by a TransactionWriter. The function +// provided will be ran within a transaction as supplied by the +// txn parameter if one is supplied, and if not, will take out a +// new transaction from the database supplied in the database +// parameter. Either way, this will block until the task is done. +func (w *ExclusiveTransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error { + if w.todo == nil { + return errors.New("not initialised") + } + if !w.running.Load() { + go w.run() + } + task := transactionWriterTask{ + db: db, + txn: txn, + f: f, + wait: make(chan error, 1), + } + w.todo <- task + return <-task.wait +} + +// run processes the tasks for a given transaction writer. Only one +// of these goroutines will run at a time. A transaction will be +// opened using the database object from the task and then this will +// be passed as a parameter to the task function. +func (w *ExclusiveTransactionWriter) run() { + if !w.running.CAS(false, true) { + return + } + defer w.running.Store(false) + for task := range w.todo { + if task.txn != nil { + task.wait <- task.f(task.txn) + } else if task.db != nil { + task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error { + return task.f(txn) + }) + } else { + panic("expected database or transaction but got neither") + } + close(task.wait) + } +} diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index a4d71fe13..c95790be7 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -63,7 +63,7 @@ const deleteAllDeviceKeysSQL = "" + type deviceKeysStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter upsertDeviceKeysStmt *sql.Stmt selectDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go index 02b9d193e..f451d657b 100644 --- a/keyserver/storage/sqlite3/key_changes_table.go +++ b/keyserver/storage/sqlite3/key_changes_table.go @@ -52,7 +52,7 @@ const selectKeyChangesSQL = "" + type keyChangesStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter upsertKeyChangeStmt *sql.Stmt selectKeyChangesStmt *sql.Stmt } diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go index 907966a7a..c71cc47d1 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/keyserver/storage/sqlite3/one_time_keys_table.go @@ -60,7 +60,7 @@ const selectKeyByAlgorithmSQL = "" + type oneTimeKeysStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter upsertKeysStmt *sql.Stmt selectKeysStmt *sql.Stmt selectKeysCountStmt *sql.Stmt diff --git a/mediaapi/storage/sqlite3/media_repository_table.go b/mediaapi/storage/sqlite3/media_repository_table.go index f53f164d4..ff6ddf3da 100644 --- a/mediaapi/storage/sqlite3/media_repository_table.go +++ b/mediaapi/storage/sqlite3/media_repository_table.go @@ -62,7 +62,7 @@ SELECT content_type, file_size_bytes, creation_ts, upload_name, base64hash, user type mediaStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertMediaStmt *sql.Stmt selectMediaStmt *sql.Stmt } diff --git a/roomserver/internal/input_latest_events.go b/roomserver/internal/input_latest_events.go index 0158c8f7f..3be5218d5 100644 --- a/roomserver/internal/input_latest_events.go +++ b/roomserver/internal/input_latest_events.go @@ -57,7 +57,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents( ) (err error) { updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomNID) if err != nil { - return + return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err) } succeeded := false defer func() { @@ -79,7 +79,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents( } if err = u.doUpdateLatestEvents(); err != nil { - return err + return fmt.Errorf("u.doUpdateLatestEvents: %w", err) } succeeded = true @@ -137,7 +137,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // don't need to do anything, as we've handled it already. hasBeenSent, err := u.updater.HasEventBeenSent(u.stateAtEvent.EventNID) if err != nil { - return err + return fmt.Errorf("u.updater.HasEventBeenSent: %w", err) } else if hasBeenSent { return nil } @@ -145,7 +145,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // Update the roomserver_previous_events table with references. This // is effectively tracking the structure of the DAG. if err = u.updater.StorePreviousEvents(u.stateAtEvent.EventNID, prevEvents); err != nil { - return err + return fmt.Errorf("u.updater.StorePreviousEvents: %w", err) } // Get the event reference for our new event. This will be used when @@ -156,7 +156,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // in the room. If it is then it isn't a latest event. alreadyReferenced, err := u.updater.IsReferenced(eventReference) if err != nil { - return err + return fmt.Errorf("u.updater.IsReferenced: %w", err) } // Work out what the latest events are. @@ -173,19 +173,19 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // Now that we know what the latest events are, it's time to get the // latest state. if err = u.latestState(); err != nil { - return err + return fmt.Errorf("u.latestState: %w", err) } // If we need to generate any output events then here's where we do it. // TODO: Move this! updates, err := u.api.updateMemberships(u.ctx, u.updater, u.removed, u.added) if err != nil { - return err + return fmt.Errorf("u.api.updateMemberships: %w", err) } update, err := u.makeOutputNewRoomEvent() if err != nil { - return err + return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err) } updates = append(updates, *update) @@ -198,14 +198,18 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the // necessary bookkeeping we'll keep the event sending synchronous for now. if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil { - return err + return fmt.Errorf("u.api.WriteOutputEvents: %w", err) } if err = u.updater.SetLatestEvents(u.roomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil { - return err + return fmt.Errorf("u.updater.SetLatestEvents: %w", err) } - return u.updater.MarkEventAsSent(u.stateAtEvent.EventNID) + if err = u.updater.MarkEventAsSent(u.stateAtEvent.EventNID); err != nil { + return fmt.Errorf("u.updater.MarkEventAsSent: %w", err) + } + + return nil } func (u *latestEventsUpdater) latestState() error { @@ -225,7 +229,7 @@ func (u *latestEventsUpdater) latestState() error { u.ctx, u.roomNID, latestStateAtEvents, ) if err != nil { - return err + return fmt.Errorf("roomState.CalculateAndStoreStateAfterEvents: %w", err) } // If we are overwriting the state then we should make sure that we @@ -244,7 +248,7 @@ func (u *latestEventsUpdater) latestState() error { u.ctx, u.oldStateNID, u.newStateNID, ) if err != nil { - return err + return fmt.Errorf("roomState.DifferenceBetweenStateSnapshots: %w", err) } // Also work out the state before the event removes and the event @@ -252,7 +256,11 @@ func (u *latestEventsUpdater) latestState() error { u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = roomState.DifferenceBetweeenStateSnapshots( u.ctx, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID, ) - return err + if err != nil { + return fmt.Errorf("roomState.DifferenceBetweeenStateSnapshots: %w", err) + } + + return nil } func calculateLatest( diff --git a/roomserver/state/state.go b/roomserver/state/state.go index d5be4a901..b9ad4a504 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -558,7 +558,11 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents( // 2) There weren't any prev_events for this event so the state is // empty. metrics.algorithm = "empty_state" - return metrics.stop(v.db.AddState(ctx, roomNID, nil, nil)) + stateNID, err := v.db.AddState(ctx, roomNID, nil, nil) + if err != nil { + err = fmt.Errorf("v.db.AddState: %w", err) + } + return metrics.stop(stateNID, err) } if len(prevStates) == 1 { @@ -578,22 +582,30 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents( ) if err != nil { metrics.algorithm = "_load_state_blocks" - return metrics.stop(0, err) + return metrics.stop(0, fmt.Errorf("v.db.StateBlockNIDs: %w", err)) } stateBlockNIDs := stateBlockNIDLists[0].StateBlockNIDs if len(stateBlockNIDs) < maxStateBlockNIDs { // 4) The number of state data blocks is small enough that we can just // add the state event as a block of size one to the end of the blocks. metrics.algorithm = "single_delta" - return metrics.stop(v.db.AddState( + stateNID, err := v.db.AddState( ctx, roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, - )) + ) + if err != nil { + err = fmt.Errorf("v.db.AddState: %w", err) + } + return metrics.stop(stateNID, err) } // If there are too many deltas then we need to calculate the full state // So fall through to calculateAndStoreStateAfterManyEvents } - return v.calculateAndStoreStateAfterManyEvents(ctx, roomNID, prevStates, metrics) + stateNID, err := v.calculateAndStoreStateAfterManyEvents(ctx, roomNID, prevStates, metrics) + if err != nil { + return 0, fmt.Errorf("v.calculateAndStoreStateAfterManyEvents: %w", err) + } + return stateNID, nil } // maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state. diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 52ff479ba..0b7ed225a 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -98,6 +98,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) { } d.Database = shared.Database{ DB: db, + Writer: sqlutil.NewDummyTransactionWriter(), EventTypesTable: eventTypes, EventStateKeysTable: eventStateKeys, EventJSONTable: eventJSON, diff --git a/roomserver/storage/shared/latest_events_updater.go b/roomserver/storage/shared/latest_events_updater.go index 21b168a4f..e9a0f6982 100644 --- a/roomserver/storage/shared/latest_events_updater.go +++ b/roomserver/storage/shared/latest_events_updater.go @@ -3,6 +3,7 @@ package shared import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -65,12 +66,14 @@ func (u *LatestEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { // StorePreviousEvents implements types.RoomRecentEventsUpdater func (u *LatestEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { - for _, ref := range previousEventReferences { - if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { - return err + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + for _, ref := range previousEventReferences { + if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) + } } - } - return nil + return nil + }) } // IsReferenced implements types.RoomRecentEventsUpdater @@ -82,7 +85,7 @@ func (u *LatestEventsUpdater) IsReferenced(eventReference gomatrixserverlib.Even if err == sql.ErrNoRows { return false, nil } - return false, err + return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err) } // SetLatestEvents implements types.RoomRecentEventsUpdater @@ -94,7 +97,12 @@ func (u *LatestEventsUpdater) SetLatestEvents( for i := range latest { eventNIDs[i] = latest[i].EventNID } - return u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil { + return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err) + } + return nil + }) } // HasEventBeenSent implements types.RoomRecentEventsUpdater @@ -104,7 +112,9 @@ func (u *LatestEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, e // MarkEventAsSent implements types.RoomRecentEventsUpdater func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { - return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, u.txn, eventNID) + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID) + }) } func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index 5955844f9..329813bfc 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -3,6 +3,7 @@ package shared import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -41,9 +42,14 @@ func (d *Database) membershipUpdaterTxn( targetUserNID types.EventStateKeyNID, targetLocal bool, ) (*MembershipUpdater, error) { - - if err := d.MembershipTable.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { - return nil, err + err := d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { + if err := d.MembershipTable.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { + return fmt.Errorf("d.MembershipTable.InsertMembership: %w", err) + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("u.d.Writer.Do: %w", err) } membership, err := d.MembershipTable.SelectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) @@ -75,19 +81,19 @@ func (u *MembershipUpdater) IsLeave() bool { func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) if err != nil { - return false, err + return false, fmt.Errorf("u.d.AssignStateKeyNID: %w", err) } inserted, err := u.d.InvitesTable.InsertInviteEvent( u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), ) if err != nil { - return false, err + return false, fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err) } if u.membership != tables.MembershipStateInvite { if err = u.d.MembershipTable.UpdateMembership( u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, ); err != nil { - return false, err + return false, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } return inserted, nil @@ -99,7 +105,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) if err != nil { - return nil, err + return nil, fmt.Errorf("u.d.AssignStateKeyNID: %w", err) } // If this is a join event update, there is no invite to update @@ -108,14 +114,14 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd u.ctx, u.txn, u.roomNID, u.targetUserNID, ) if err != nil { - return nil, err + return nil, fmt.Errorf("u.d.InvitesTables.UpdateInviteRetired: %w", err) } } // Look up the NID of the new join event nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) if err != nil { - return nil, err + return nil, fmt.Errorf("u.d.EventNIDs: %w", err) } if u.membership != tables.MembershipStateJoin || isUpdate { @@ -123,7 +129,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateJoin, nIDs[eventID], ); err != nil { - return nil, err + return nil, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } @@ -134,19 +140,19 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) { senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) if err != nil { - return nil, err + return nil, fmt.Errorf("u.d.AssignStateKeyNID: %w", err) } inviteEventIDs, err := u.d.InvitesTable.UpdateInviteRetired( u.ctx, u.txn, u.roomNID, u.targetUserNID, ) if err != nil { - return nil, err + return nil, fmt.Errorf("u.d.InvitesTable.updateInviteRetired: %w", err) } // Look up the NID of the new leave event nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) if err != nil { - return nil, err + return nil, fmt.Errorf("u.d.EventNIDs: %w", err) } if u.membership != tables.MembershipStateLeaveOrBan { @@ -154,7 +160,7 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateLeaveOrBan, nIDs[eventID], ); err != nil { - return nil, err + return nil, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } return inviteEventIDs, nil diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 00179e336..45020d551 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -27,6 +27,7 @@ const redactionsArePermanent = false type Database struct { DB *sql.DB + Writer sqlutil.TransactionWriter EventsTable tables.Events EventJSONTable tables.EventJSON EventTypesTable tables.EventTypes @@ -83,20 +84,23 @@ func (d *Database) AddState( stateBlockNIDs []types.StateBlockNID, state []types.StateEntry, ) (stateNID types.StateSnapshotNID, err error) { - err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if len(state) > 0 { var stateBlockNID types.StateBlockNID stateBlockNID, err = d.StateBlockTable.BulkInsertStateData(ctx, txn, state) if err != nil { - return err + return fmt.Errorf("d.StateBlockTable.BulkInsertStateData: %w", err) } stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) } stateNID, err = d.StateSnapshotTable.InsertState(ctx, txn, roomNID, stateBlockNIDs) - return err + if err != nil { + return fmt.Errorf("d.StateSnapshotTable.InsertState: %w", err) + } + return nil }) if err != nil { - return 0, err + return 0, fmt.Errorf("d.Writer.Do: %w", err) } return } @@ -110,7 +114,9 @@ func (d *Database) EventNIDs( func (d *Database) SetState( ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { - return d.EventsTable.UpdateEventState(ctx, eventNID, stateNID) + return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error { + return d.EventsTable.UpdateEventState(ctx, eventNID, stateNID) + }) } func (d *Database) StateAtEventIDs( @@ -221,7 +227,9 @@ func (d *Database) GetRoomVersionForRoomNID( } func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { - return d.RoomAliasesTable.InsertRoomAlias(ctx, alias, roomID, creatorUserID) + return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error { + return d.RoomAliasesTable.InsertRoomAlias(ctx, alias, roomID, creatorUserID) + }) } func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { @@ -239,15 +247,21 @@ func (d *Database) GetCreatorIDForAlias( } func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { - return d.RoomAliasesTable.DeleteRoomAlias(ctx, alias) + return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error { + return d.RoomAliasesTable.DeleteRoomAlias(ctx, alias) + }) } func (d *Database) GetMembership( ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, ) (membershipEventNID types.EventNID, stillInRoom bool, err error) { - requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID) + var requestSenderUserNID types.EventStateKeyNID + err = d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error { + requestSenderUserNID, err = d.assignStateKeyNID(ctx, nil, requestSenderUserID) + return err + }) if err != nil { - return + return 0, false, fmt.Errorf("d.assignStateKeyNID: %w", err) } senderMembershipEventNID, senderMembership, err := @@ -350,6 +364,7 @@ func (d *Database) GetLatestEventsForUpdate( return NewLatestEventsUpdater(ctx, d, txn, roomNID) } +// nolint:gocyclo func (d *Database) StoreEvent( ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, @@ -365,10 +380,10 @@ func (d *Database) StoreEvent( err error ) - err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if txnAndSessionID != nil { if err = d.TransactionsTable.InsertTransaction( - ctx, txn, txnAndSessionID.TransactionID, + ctx, nil, txnAndSessionID.TransactionID, txnAndSessionID.SessionID, event.Sender(), event.EventID(), ); err != nil { return fmt.Errorf("d.TransactionsTable.InsertTransaction: %w", err) @@ -433,7 +448,7 @@ func (d *Database) StoreEvent( return nil }) if err != nil { - return 0, types.StateAtEvent{}, nil, "", err + return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err) } return roomNID, types.StateAtEvent{ @@ -449,7 +464,9 @@ func (d *Database) StoreEvent( } func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool) error { - return d.PublishedTable.UpsertRoomPublished(ctx, roomID, publish) + return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error { + return d.PublishedTable.UpsertRoomPublished(ctx, roomID, publish) + }) } func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) { diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index e8118ad76..3cd44b1dc 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -49,15 +49,13 @@ const bulkSelectEventJSONSQL = ` type eventJSONStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertEventJSONStmt *sql.Stmt bulkSelectEventJSONStmt *sql.Stmt } -func NewSqliteEventJSONTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.EventJSON, error) { +func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) { s := &eventJSONStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(eventJSONSchema) if err != nil { @@ -72,10 +70,8 @@ func NewSqliteEventJSONTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tab func (s *eventJSONStatements) InsertEventJSON( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - _, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON) - return err - }) + _, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON) + return err } func (s *eventJSONStatements) BulkSelectEventJSON( diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index c8ad052bf..345df8c62 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -64,17 +64,15 @@ const bulkSelectEventStateKeyNIDSQL = ` type eventStateKeyStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertEventStateKeyNIDStmt *sql.Stmt selectEventStateKeyNIDStmt *sql.Stmt bulkSelectEventStateKeyNIDStmt *sql.Stmt bulkSelectEventStateKeyStmt *sql.Stmt } -func NewSqliteEventStateKeysTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.EventStateKeys, error) { +func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { s := &eventStateKeyStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(eventStateKeysSchema) if err != nil { @@ -91,19 +89,15 @@ func NewSqliteEventStateKeysTable(db *sql.DB, writer *sqlutil.TransactionWriter) func (s *eventStateKeyStatements) InsertEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { - var eventStateKeyNID int64 - err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt) - res, err := insertStmt.ExecContext(ctx, eventStateKey) - if err != nil { - return err - } - eventStateKeyNID, err = res.LastInsertId() - if err != nil { - return err - } - return nil - }) + insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt) + res, err := insertStmt.ExecContext(ctx, eventStateKey) + if err != nil { + return 0, err + } + eventStateKeyNID, err := res.LastInsertId() + if err != nil { + return 0, err + } return types.EventStateKeyNID(eventStateKeyNID), err } diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index 4a645789d..26e2bf843 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -18,6 +18,7 @@ package sqlite3 import ( "context" "database/sql" + "fmt" "strings" "github.com/matrix-org/dendrite/internal" @@ -78,17 +79,15 @@ const bulkSelectEventTypeNIDSQL = ` type eventTypeStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertEventTypeNIDStmt *sql.Stmt insertEventTypeNIDResultStmt *sql.Stmt selectEventTypeNIDStmt *sql.Stmt bulkSelectEventTypeNIDStmt *sql.Stmt } -func NewSqliteEventTypesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.EventTypes, error) { +func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) { s := &eventTypeStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(eventTypesSchema) if err != nil { @@ -104,18 +103,18 @@ func NewSqliteEventTypesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (ta } func (s *eventTypeStatements) InsertEventTypeNID( - ctx context.Context, tx *sql.Tx, eventType string, + ctx context.Context, txn *sql.Tx, eventType string, ) (types.EventTypeNID, error) { var eventTypeNID int64 - err := s.writer.Do(s.db, tx, func(tx *sql.Tx) error { - insertStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDStmt) - resultStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDResultStmt) - _, err := insertStmt.ExecContext(ctx, eventType) - if err != nil { - return err - } - return resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID) - }) + insertStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDStmt) + resultStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDResultStmt) + _, err := insertStmt.ExecContext(ctx, eventType) + if err != nil { + return 0, fmt.Errorf("insertStmt.ExecContext: %w", err) + } + if err = resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID); err != nil { + return 0, fmt.Errorf("resultStmt.QueryRowContext.Scan: %w", err) + } return types.EventTypeNID(eventTypeNID), err } diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 0e39755cb..26ea1d415 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -99,7 +99,6 @@ const selectRoomNIDForEventNIDSQL = "" + type eventStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertEventStmt *sql.Stmt selectEventStmt *sql.Stmt bulkSelectStateEventByIDStmt *sql.Stmt @@ -115,10 +114,9 @@ type eventStatements struct { selectRoomNIDForEventNIDStmt *sql.Stmt } -func NewSqliteEventsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Events, error) { +func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) { s := &eventStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(eventsSchema) if err != nil { @@ -155,22 +153,19 @@ func (s *eventStatements) InsertEvent( ) (types.EventNID, types.StateSnapshotNID, error) { // attempt to insert: the last_row_id is the event NID var eventNID int64 - err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) - result, err := insertStmt.ExecContext( - ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), - eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, - ) - if err != nil { - return err - } - modified, err := result.RowsAffected() - if modified == 0 && err == nil { - return sql.ErrNoRows - } - eventNID, err = result.LastInsertId() - return err - }) + insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) + result, err := insertStmt.ExecContext( + ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), + eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, + ) + if err != nil { + return 0, 0, err + } + modified, err := result.RowsAffected() + if modified == 0 && err == nil { + return 0, 0, sql.ErrNoRows + } + eventNID, err = result.LastInsertId() return types.EventNID(eventNID), 0, err } @@ -286,11 +281,8 @@ func (s *eventStatements) BulkSelectStateAtEventByID( func (s *eventStatements) UpdateEventState( ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.updateEventStateStmt) - _, err := stmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) - return err - }) + _, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) + return err } func (s *eventStatements) SelectEventSentToOutput( @@ -302,11 +294,9 @@ func (s *eventStatements) SelectEventSentToOutput( } func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt) - _, err := updateStmt.ExecContext(ctx, int64(eventNID)) - return err - }) + updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt) + _, err := updateStmt.ExecContext(ctx, int64(eventNID)) + return err } func (s *eventStatements) SelectEventID( @@ -334,7 +324,7 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...) if err != nil { - return nil, err + return nil, fmt.Errorf("sqlutil.TxStmt.QueryContext: %w", err) } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventAndReference: rows.close() failed") results := make([]types.StateAtEventAndReference, len(eventNIDs)) @@ -481,7 +471,7 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, } err = sqlutil.TxStmt(txn, sqlPrep).QueryRowContext(ctx, iEventIDs...).Scan(&result) if err != nil { - return 0, err + return 0, fmt.Errorf("sqlutil.TxStmt.QueryRowContext: %w", err) } return result, nil } diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go index 1305f4a8a..327be6a03 100644 --- a/roomserver/storage/sqlite3/invite_table.go +++ b/roomserver/storage/sqlite3/invite_table.go @@ -64,17 +64,15 @@ SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_ni type inviteStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertInviteEventStmt *sql.Stmt selectInviteActiveForUserInRoomStmt *sql.Stmt updateInviteRetiredStmt *sql.Stmt selectInvitesAboutToRetireStmt *sql.Stmt } -func NewSqliteInvitesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Invites, error) { +func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) { s := &inviteStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(inviteSchema) if err != nil { @@ -96,20 +94,17 @@ func (s *inviteStatements) InsertInviteEvent( inviteEventJSON []byte, ) (bool, error) { var count int64 - err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt) - result, err := stmt.ExecContext( - ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, - ) - if err != nil { - return err - } - count, err = result.RowsAffected() - if err != nil { - return err - } - return nil - }) + stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt) + result, err := stmt.ExecContext( + ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, + ) + if err != nil { + return false, err + } + count, err = result.RowsAffected() + if err != nil { + return false, err + } return count != 0, err } @@ -117,26 +112,23 @@ func (s *inviteStatements) UpdateInviteRetired( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventIDs []string, err error) { - err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - // gather all the event IDs we will retire - stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt) - rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) - if err != nil { - return err + // gather all the event IDs we will retire + stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt) + rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "UpdateInviteRetired: rows.close() failed") + for rows.Next() { + var inviteEventID string + if err = rows.Scan(&inviteEventID); err != nil { + return } - defer internal.CloseAndLogIfError(ctx, rows, "UpdateInviteRetired: rows.close() failed") - for rows.Next() { - var inviteEventID string - if err = rows.Scan(&inviteEventID); err != nil { - return err - } - eventIDs = append(eventIDs, inviteEventID) - } - // now retire the invites - stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt) - _, err = stmt.ExecContext(ctx, roomNID, targetUserNID) - return err - }) + eventIDs = append(eventIDs, inviteEventID) + } + // now retire the invites + stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt) + _, err = stmt.ExecContext(ctx, roomNID, targetUserNID) return } diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 7b69cee32..b3ee69c00 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -77,7 +77,6 @@ const updateMembershipSQL = "" + type membershipStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertMembershipStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt selectMembershipFromRoomAndTargetStmt *sql.Stmt @@ -88,10 +87,9 @@ type membershipStatements struct { updateMembershipStmt *sql.Stmt } -func NewSqliteMembershipTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Membership, error) { +func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { s := &membershipStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(membershipSchema) if err != nil { @@ -115,11 +113,9 @@ func (s *membershipStatements) InsertMembership( roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt) - _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget) - return err - }) + stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt) + _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget) + return err } func (s *membershipStatements) SelectMembershipForUpdate( @@ -201,11 +197,9 @@ func (s *membershipStatements) UpdateMembership( senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt) - _, err := stmt.ExecContext( - ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID, - ) - return err - }) + stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt) + _, err := stmt.ExecContext( + ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID, + ) + return err } diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go index ff804861c..d28a42c69 100644 --- a/roomserver/storage/sqlite3/previous_events_table.go +++ b/roomserver/storage/sqlite3/previous_events_table.go @@ -54,15 +54,13 @@ const selectPreviousEventExistsSQL = ` type previousEventStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertPreviousEventStmt *sql.Stmt selectPreviousEventExistsStmt *sql.Stmt } -func NewSqlitePrevEventsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.PreviousEvents, error) { +func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { s := &previousEventStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(previousEventSchema) if err != nil { @@ -82,13 +80,11 @@ func (s *previousEventStatements) InsertPreviousEvent( previousEventReferenceSHA256 []byte, eventNID types.EventNID, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) - _, err := stmt.ExecContext( - ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), - ) - return err - }) + stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) + _, err := stmt.ExecContext( + ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), + ) + return err } // Check if the event reference exists diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go index a4a47aec9..1d6ccd561 100644 --- a/roomserver/storage/sqlite3/published_table.go +++ b/roomserver/storage/sqlite3/published_table.go @@ -19,7 +19,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -45,16 +44,14 @@ const selectPublishedSQL = "" + type publishedStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter upsertPublishedStmt *sql.Stmt selectAllPublishedStmt *sql.Stmt selectPublishedStmt *sql.Stmt } -func NewSqlitePublishedTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Published, error) { +func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) { s := &publishedStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(publishedSchema) if err != nil { @@ -69,12 +66,9 @@ func NewSqlitePublishedTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tab func (s *publishedStatements) UpsertRoomPublished( ctx context.Context, roomID string, published bool, -) (err error) { - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt) - _, err := stmt.ExecContext(ctx, roomID, published) - return err - }) +) error { + _, err := s.upsertPublishedStmt.ExecContext(ctx, roomID, published) + return err } func (s *publishedStatements) SelectPublishedFromRoomID( diff --git a/roomserver/storage/sqlite3/redactions_table.go b/roomserver/storage/sqlite3/redactions_table.go index ad900a4ec..a2179357c 100644 --- a/roomserver/storage/sqlite3/redactions_table.go +++ b/roomserver/storage/sqlite3/redactions_table.go @@ -53,17 +53,15 @@ const markRedactionValidatedSQL = "" + type redactionStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertRedactionStmt *sql.Stmt selectRedactionInfoByRedactionEventIDStmt *sql.Stmt selectRedactionInfoByEventBeingRedactedStmt *sql.Stmt markRedactionValidatedStmt *sql.Stmt } -func NewSqliteRedactionsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Redactions, error) { +func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) { s := &redactionStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(redactionsSchema) if err != nil { @@ -81,11 +79,9 @@ func NewSqliteRedactionsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (ta func (s *redactionStatements) InsertRedaction( ctx context.Context, txn *sql.Tx, info tables.RedactionInfo, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt) - _, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated) - return err - }) + stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt) + _, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated) + return err } func (s *redactionStatements) SelectRedactionInfoByRedactionEventID( @@ -121,9 +117,7 @@ func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted( func (s *redactionStatements) MarkRedactionValidated( ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt) - _, err := stmt.ExecContext(ctx, redactionEventID, validated) - return err - }) + stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt) + _, err := stmt.ExecContext(ctx, redactionEventID, validated) + return err } diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index deba3ff55..a16e97aa5 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -20,7 +20,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -57,7 +56,6 @@ const deleteRoomAliasSQL = ` type roomAliasesStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertRoomAliasStmt *sql.Stmt selectRoomIDFromAliasStmt *sql.Stmt selectAliasesFromRoomIDStmt *sql.Stmt @@ -65,10 +63,9 @@ type roomAliasesStatements struct { deleteRoomAliasStmt *sql.Stmt } -func NewSqliteRoomAliasesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.RoomAliases, error) { +func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { s := &roomAliasesStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(roomAliasesSchema) if err != nil { @@ -85,12 +82,9 @@ func NewSqliteRoomAliasesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (t func (s *roomAliasesStatements) InsertRoomAlias( ctx context.Context, alias string, roomID string, creatorUserID string, -) (err error) { - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertRoomAliasStmt) - _, err := stmt.ExecContext(ctx, alias, roomID, creatorUserID) - return err - }) +) error { + _, err := s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID) + return err } func (s *roomAliasesStatements) SelectRoomIDFromAlias( @@ -138,10 +132,7 @@ func (s *roomAliasesStatements) SelectCreatorIDFromAlias( func (s *roomAliasesStatements) DeleteRoomAlias( ctx context.Context, alias string, -) (err error) { - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.deleteRoomAliasStmt) - _, err := stmt.ExecContext(ctx, alias) - return err - }) +) error { + _, err := s.deleteRoomAliasStmt.ExecContext(ctx, alias) + return err } diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 8bbec5080..6541cc0cb 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -66,7 +66,6 @@ const selectRoomVersionForRoomNIDSQL = "" + type roomStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt @@ -76,10 +75,9 @@ type roomStatements struct { selectRoomVersionForRoomNIDStmt *sql.Stmt } -func NewSqliteRoomsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Rooms, error) { +func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { s := &roomStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(roomsSchema) if err != nil { @@ -100,20 +98,14 @@ func (s *roomStatements) InsertRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, ) (roomNID types.RoomNID, err error) { - err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt) - _, err = insertStmt.ExecContext(ctx, roomID, roomVersion) - if err != nil { - return fmt.Errorf("insertStmt.ExecContext: %w", err) - } - roomNID, err = s.SelectRoomNID(ctx, txn, roomID) - if err != nil { - return fmt.Errorf("s.SelectRoomNID: %w", err) - } - return nil - }) + insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt) + _, err = insertStmt.ExecContext(ctx, roomID, roomVersion) if err != nil { - return types.RoomNID(0), err + return 0, fmt.Errorf("insertStmt.ExecContext: %w", err) + } + roomNID, err = s.SelectRoomNID(ctx, txn, roomID) + if err != nil { + return 0, fmt.Errorf("s.SelectRoomNID: %w", err) } return } @@ -170,17 +162,15 @@ func (s *roomStatements) UpdateLatestEventNIDs( lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt) - _, err := stmt.ExecContext( - ctx, - eventNIDsAsArray(eventNIDs), - int64(lastEventSentNID), - int64(stateSnapshotNID), - roomNID, - ) - return err - }) + stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt) + _, err := stmt.ExecContext( + ctx, + eventNIDsAsArray(eventNIDs), + int64(lastEventSentNID), + int64(stateSnapshotNID), + roomNID, + ) + return err } func (s *roomStatements) SelectRoomVersionForRoomID( diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 3e28e450b..8033903f5 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -74,17 +74,15 @@ const bulkSelectFilteredStateBlockEntriesSQL = "" + type stateBlockStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertStateDataStmt *sql.Stmt selectNextStateBlockNIDStmt *sql.Stmt bulkSelectStateBlockEntriesStmt *sql.Stmt bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt } -func NewSqliteStateBlockTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.StateBlock, error) { +func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) { s := &stateBlockStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(stateDataSchema) if err != nil { @@ -107,25 +105,22 @@ func (s *stateBlockStatements) BulkInsertStateData( return 0, nil } var stateBlockNID types.StateBlockNID - err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID) + err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID) + if err != nil { + return 0, err + } + for _, entry := range entries { + _, err = txn.Stmt(s.insertStateDataStmt).ExecContext( + ctx, + int64(stateBlockNID), + int64(entry.EventTypeNID), + int64(entry.EventStateKeyNID), + int64(entry.EventNID), + ) if err != nil { - return err + return 0, err } - for _, entry := range entries { - _, err := txn.Stmt(s.insertStateDataStmt).ExecContext( - ctx, - int64(stateBlockNID), - int64(entry.EventTypeNID), - int64(entry.EventStateKeyNID), - int64(entry.EventNID), - ) - if err != nil { - return err - } - } - return nil - }) + } return stateBlockNID, err } diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 799904ff6..392c2a671 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -50,15 +50,13 @@ const bulkSelectStateBlockNIDsSQL = "" + type stateSnapshotStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertStateStmt *sql.Stmt bulkSelectStateBlockNIDsStmt *sql.Stmt } -func NewSqliteStateSnapshotTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.StateSnapshot, error) { +func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { s := &stateSnapshotStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(stateSnapshotSchema) if err != nil { @@ -78,19 +76,16 @@ func (s *stateSnapshotStatements) InsertState( if err != nil { return } - err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - insertStmt := txn.Stmt(s.insertStateStmt) - res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON)) - if err != nil { - return err - } - lastRowID, err := res.LastInsertId() - if err != nil { - return err - } - stateNID = types.StateSnapshotNID(lastRowID) - return nil - }) + insertStmt := txn.Stmt(s.insertStateStmt) + res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON)) + if err != nil { + return 0, err + } + lastRowID, err := res.LastInsertId() + if err != nil { + return 0, err + } + stateNID = types.StateSnapshotNID(lastRowID) return } diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 724316373..8e3af6b7a 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -41,6 +41,7 @@ type Database struct { invites tables.Invites membership tables.Membership db *sql.DB + writer sqlutil.TransactionWriter } // Open a sqlite database. @@ -51,7 +52,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) { if d.db, err = sqlutil.Open(dbProperties); err != nil { return nil, err } - writer := sqlutil.NewTransactionWriter() + d.writer = sqlutil.NewTransactionWriter() //d.db.Exec("PRAGMA journal_mode=WAL;") //d.db.Exec("PRAGMA read_uncommitted = true;") @@ -61,64 +62,65 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) { // which it will never obtain. d.db.SetMaxOpenConns(20) - d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db, writer) + d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db) if err != nil { return nil, err } - d.eventTypes, err = NewSqliteEventTypesTable(d.db, writer) + d.eventTypes, err = NewSqliteEventTypesTable(d.db) if err != nil { return nil, err } - d.eventJSON, err = NewSqliteEventJSONTable(d.db, writer) + d.eventJSON, err = NewSqliteEventJSONTable(d.db) if err != nil { return nil, err } - d.events, err = NewSqliteEventsTable(d.db, writer) + d.events, err = NewSqliteEventsTable(d.db) if err != nil { return nil, err } - d.rooms, err = NewSqliteRoomsTable(d.db, writer) + d.rooms, err = NewSqliteRoomsTable(d.db) if err != nil { return nil, err } - d.transactions, err = NewSqliteTransactionsTable(d.db, writer) + d.transactions, err = NewSqliteTransactionsTable(d.db) if err != nil { return nil, err } - stateBlock, err := NewSqliteStateBlockTable(d.db, writer) + stateBlock, err := NewSqliteStateBlockTable(d.db) if err != nil { return nil, err } - stateSnapshot, err := NewSqliteStateSnapshotTable(d.db, writer) + stateSnapshot, err := NewSqliteStateSnapshotTable(d.db) if err != nil { return nil, err } - d.prevEvents, err = NewSqlitePrevEventsTable(d.db, writer) + d.prevEvents, err = NewSqlitePrevEventsTable(d.db) if err != nil { return nil, err } - roomAliases, err := NewSqliteRoomAliasesTable(d.db, writer) + roomAliases, err := NewSqliteRoomAliasesTable(d.db) if err != nil { return nil, err } - d.invites, err = NewSqliteInvitesTable(d.db, writer) + d.invites, err = NewSqliteInvitesTable(d.db) if err != nil { return nil, err } - d.membership, err = NewSqliteMembershipTable(d.db, writer) + d.membership, err = NewSqliteMembershipTable(d.db) if err != nil { return nil, err } - published, err := NewSqlitePublishedTable(d.db, writer) + published, err := NewSqlitePublishedTable(d.db) if err != nil { return nil, err } - redactions, err := NewSqliteRedactionsTable(d.db, writer) + redactions, err := NewSqliteRedactionsTable(d.db) if err != nil { return nil, err } d.Database = shared.Database{ DB: d.db, + Writer: sqlutil.NewTransactionWriter(), EventsTable: d.events, EventTypesTable: d.eventTypes, EventStateKeysTable: d.eventStateKeys, diff --git a/roomserver/storage/sqlite3/transactions_table.go b/roomserver/storage/sqlite3/transactions_table.go index 65c18a8a9..029122c5e 100644 --- a/roomserver/storage/sqlite3/transactions_table.go +++ b/roomserver/storage/sqlite3/transactions_table.go @@ -45,15 +45,13 @@ const selectTransactionEventIDSQL = ` type transactionStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertTransactionStmt *sql.Stmt selectTransactionEventIDStmt *sql.Stmt } -func NewSqliteTransactionsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Transactions, error) { +func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) { s := &transactionStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(transactionsSchema) if err != nil { @@ -72,14 +70,12 @@ func (s *transactionStatements) InsertTransaction( sessionID int64, userID string, eventID string, -) (err error) { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt) - _, err := stmt.ExecContext( - ctx, transactionID, sessionID, userID, eventID, - ) - return err - }) +) error { + stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt) + _, err := stmt.ExecContext( + ctx, transactionID, sessionID, userID, eventID, + ) + return err } func (s *transactionStatements) SelectTransactionEventID( diff --git a/serverkeyapi/storage/sqlite3/server_key_table.go b/serverkeyapi/storage/sqlite3/server_key_table.go index 423292a54..b829eae74 100644 --- a/serverkeyapi/storage/sqlite3/server_key_table.go +++ b/serverkeyapi/storage/sqlite3/server_key_table.go @@ -63,7 +63,7 @@ const upsertServerKeysSQL = "" + type serverKeyStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter bulkSelectServerKeysStmt *sql.Stmt upsertServerKeysStmt *sql.Stmt } diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index dd5b838ce..fdbf6758d 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -45,7 +45,7 @@ type Database struct { BackwardExtremities tables.BackwardsExtremities SendToDevice tables.SendToDevice Filter tables.Filter - SendToDeviceWriter *sqlutil.TransactionWriter + SendToDeviceWriter sqlutil.TransactionWriter EDUCache *cache.EDUCache } diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index 609cef141..248ec9267 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -51,7 +51,7 @@ const selectMaxAccountDataIDSQL = "" + type accountDataStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter streamIDStatements *streamIDStatements insertAccountDataStmt *sql.Stmt selectMaxAccountDataIDStmt *sql.Stmt diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go index 1aeb041f4..d96f2fe57 100644 --- a/syncapi/storage/sqlite3/backwards_extremities_table.go +++ b/syncapi/storage/sqlite3/backwards_extremities_table.go @@ -49,7 +49,7 @@ const deleteBackwardExtremitySQL = "" + type backwardExtremitiesStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 6edc99aa0..77a21543f 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -85,7 +85,7 @@ const selectEventsWithEventIDsSQL = "" + type currentRoomStateStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter streamIDStatements *streamIDStatements upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt diff --git a/syncapi/storage/sqlite3/filter_table.go b/syncapi/storage/sqlite3/filter_table.go index 3e8a46551..338b0b500 100644 --- a/syncapi/storage/sqlite3/filter_table.go +++ b/syncapi/storage/sqlite3/filter_table.go @@ -52,7 +52,7 @@ const insertFilterSQL = "" + type filterStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter selectFilterStmt *sql.Stmt selectFilterIDByContentStmt *sql.Stmt insertFilterStmt *sql.Stmt diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index 19e7a7c68..0bbd79f77 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -59,7 +59,7 @@ const selectMaxInviteIDSQL = "" + type inviteEventsStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter streamIDStatements *streamIDStatements insertInviteEventStmt *sql.Stmt selectInviteEventsInRangeStmt *sql.Stmt diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 12b4dbabe..0d1546507 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -105,7 +105,7 @@ const selectStateInRangeSQL = "" + type outputRoomEventsStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter streamIDStatements *streamIDStatements insertEventStmt *sql.Stmt selectEventsStmt *sql.Stmt diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index 2e71e8f33..5c4ab005f 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -67,7 +67,7 @@ const selectMaxPositionInTopologySQL = "" + type outputRoomEventsTopologyStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertEventInTopologyStmt *sql.Stmt selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go index 88b319fb3..53786589c 100644 --- a/syncapi/storage/sqlite3/send_to_device_table.go +++ b/syncapi/storage/sqlite3/send_to_device_table.go @@ -73,7 +73,7 @@ const deleteSendToDeviceMessagesSQL = ` type sendToDeviceStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertSendToDeviceMessageStmt *sql.Stmt selectSendToDeviceMessagesStmt *sql.Stmt countSendToDeviceMessagesStmt *sql.Stmt diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go index cf3eed5ba..1971e7f3b 100644 --- a/syncapi/storage/sqlite3/stream_id_table.go +++ b/syncapi/storage/sqlite3/stream_id_table.go @@ -28,7 +28,7 @@ const selectStreamIDStmt = "" + type streamIDStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter increaseStreamIDStmt *sql.Stmt selectStreamIDStmt *sql.Stmt } diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go index cb54412ab..9b40e6579 100644 --- a/userapi/storage/accounts/sqlite3/account_data_table.go +++ b/userapi/storage/accounts/sqlite3/account_data_table.go @@ -51,7 +51,7 @@ const selectAccountDataByTypeSQL = "" + type accountDataStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertAccountDataStmt *sql.Stmt selectAccountDataStmt *sql.Stmt selectAccountDataByTypeStmt *sql.Stmt diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index 27c3d845a..586bcab91 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -59,7 +59,7 @@ const selectNewNumericLocalpartSQL = "" + type accountsStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertAccountStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/accounts/sqlite3/profile_table.go index d4c404ca3..cd35d2982 100644 --- a/userapi/storage/accounts/sqlite3/profile_table.go +++ b/userapi/storage/accounts/sqlite3/profile_table.go @@ -53,7 +53,7 @@ const selectProfilesBySearchSQL = "" + type profilesStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertProfileStmt *sql.Stmt selectProfileByLocalpartStmt *sql.Stmt setAvatarURLStmt *sql.Stmt diff --git a/userapi/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/accounts/sqlite3/threepid_table.go index 0104e8346..3000d7c43 100644 --- a/userapi/storage/accounts/sqlite3/threepid_table.go +++ b/userapi/storage/accounts/sqlite3/threepid_table.go @@ -54,7 +54,7 @@ const deleteThreePIDSQL = "" + type threepidStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter selectLocalpartForThreePIDStmt *sql.Stmt selectThreePIDsForLocalpartStmt *sql.Stmt insertThreePIDStmt *sql.Stmt diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index 9b535aab9..962e63b03 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -78,7 +78,7 @@ const selectDevicesByIDSQL = "" + type devicesStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertDeviceStmt *sql.Stmt selectDevicesCountStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt From f5edfb9659f97b564b490cbbc1118b380d6244cd Mon Sep 17 00:00:00 2001 From: anandv96 <60289989+anandv96@users.noreply.github.com> Date: Thu, 20 Aug 2020 12:57:43 +0530 Subject: [PATCH 4/6] #903: Client API: mutex on (user_id, room_id) (#1286) * Client API: mutex on (user_id, room_id) * Client API: mutex on (user_id, room_id) Changed variable name used for the mutexes map Changed the place where the mutex is locked Changed unlock to a defered call instead of manually calling it --- clientapi/routing/sendevent.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index e0cd7eb5d..9cf517cff 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -16,6 +16,7 @@ package routing import ( "net/http" + "sync" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" @@ -35,6 +36,10 @@ type sendEventResponse struct { EventID string `json:"event_id"` } +var ( + userRoomSendMutexes sync.Map // (roomID+userID) -> mutex. mutexes to ensure correct ordering of sendEvents +) + // SendEvent implements: // /rooms/{roomID}/send/{eventType} // /rooms/{roomID}/send/{eventType}/{txnID} @@ -63,6 +68,13 @@ func SendEvent( } } + // create a mutex for the specific user in the specific room + // this avoids a situation where events that are received in quick succession are sent to the roomserver in a jumbled order + userID := device.UserID + mutex, _ := userRoomSendMutexes.LoadOrStore(roomID+userID, &sync.Mutex{}) + mutex.(*sync.Mutex).Lock() + defer mutex.(*sync.Mutex).Unlock() + e, resErr := generateSendEvent(req, device, roomID, eventType, stateKey, cfg, rsAPI) if resErr != nil { return *resErr From 5ad47d3b3dc79fc8c7c9255728710906b118e2d8 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 20 Aug 2020 09:24:52 +0100 Subject: [PATCH 5/6] Fix more roomserver transactions/locks (#1287) * Fix transaction to InsertTransaction * Remove unnecessary txn, add txns around setting up updaters --- roomserver/storage/shared/storage.go | 45 ++++++++++++++++------------ 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 45020d551..766d4f205 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -179,22 +179,19 @@ func (d *Database) RoomNIDExcludingStubs(ctx context.Context, roomID string) (ro func (d *Database) LatestEventIDs( ctx context.Context, roomNID types.RoomNID, ) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) { - err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { - var eventNIDs []types.EventNID - eventNIDs, currentStateSnapshotNID, err = d.RoomsTable.SelectLatestEventNIDs(ctx, txn, roomNID) - if err != nil { - return err - } - references, err = d.EventsTable.BulkSelectEventReference(ctx, txn, eventNIDs) - if err != nil { - return err - } - depth, err = d.EventsTable.SelectMaxEventDepth(ctx, txn, eventNIDs) - if err != nil { - return err - } - return nil - }) + var eventNIDs []types.EventNID + eventNIDs, currentStateSnapshotNID, err = d.RoomsTable.SelectLatestEventNIDs(ctx, nil, roomNID) + if err != nil { + return + } + references, err = d.EventsTable.BulkSelectEventReference(ctx, nil, eventNIDs) + if err != nil { + return + } + depth, err = d.EventsTable.SelectMaxEventDepth(ctx, nil, eventNIDs) + if err != nil { + return + } return } @@ -351,7 +348,12 @@ func (d *Database) MembershipUpdater( if err != nil { return nil, err } - return NewMembershipUpdater(ctx, d, txn, roomID, targetUserID, targetLocal, roomVersion) + var updater *MembershipUpdater + _ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { + updater, err = NewMembershipUpdater(ctx, d, txn, roomID, targetUserID, targetLocal, roomVersion) + return nil + }) + return updater, err } func (d *Database) GetLatestEventsForUpdate( @@ -361,7 +363,12 @@ func (d *Database) GetLatestEventsForUpdate( if err != nil { return nil, err } - return NewLatestEventsUpdater(ctx, d, txn, roomNID) + var updater *LatestEventsUpdater + _ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { + updater, err = NewLatestEventsUpdater(ctx, d, txn, roomNID) + return nil + }) + return updater, err } // nolint:gocyclo @@ -383,7 +390,7 @@ func (d *Database) StoreEvent( err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if txnAndSessionID != nil { if err = d.TransactionsTable.InsertTransaction( - ctx, nil, txnAndSessionID.TransactionID, + ctx, txn, txnAndSessionID.TransactionID, txnAndSessionID.SessionID, event.Sender(), event.EventID(), ); err != nil { return fmt.Errorf("d.TransactionsTable.InsertTransaction: %w", err) From 0fea056db43c11c5de97fd96bcc60703ca1b4c08 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 20 Aug 2020 14:58:53 +0100 Subject: [PATCH 6/6] Change backoff behaviour so that Failure returns planned end time (#1288) --- federationsender/statistics/statistics.go | 47 ++++++++++++++++--- .../statistics/statistics_test.go | 14 +++++- 2 files changed, 52 insertions(+), 9 deletions(-) diff --git a/federationsender/statistics/statistics.go b/federationsender/statistics/statistics.go index 0dd8da200..a574ceffb 100644 --- a/federationsender/statistics/statistics.go +++ b/federationsender/statistics/statistics.go @@ -66,10 +66,16 @@ type ServerStatistics struct { serverName gomatrixserverlib.ServerName // blacklisted atomic.Bool // is the node blacklisted backoffStarted atomic.Bool // is the backoff started + backoffUntil atomic.Value // time.Time until this backoff interval ends backoffCount atomic.Uint32 // number of times BackoffDuration has been called successCounter atomic.Uint32 // how many times have we succeeded? } +// duration returns how long the next backoff interval should be. +func (s *ServerStatistics) duration(count uint32) time.Duration { + return time.Second * time.Duration(math.Exp2(float64(count))) +} + // Success updates the server statistics with a new successful // attempt, which increases the sent counter and resets the idle and // failure counters. If a host was blacklisted at this point then @@ -88,11 +94,36 @@ func (s *ServerStatistics) Success() { // Failure marks a failure and starts backing off if needed. // The next call to BackoffIfRequired will do the right thing -// after this. -func (s *ServerStatistics) Failure() { +// after this. It will return the time that the current failure +// will result in backoff waiting until, and a bool signalling +// whether we have blacklisted and therefore to give up. +func (s *ServerStatistics) Failure() (time.Time, bool) { + // If we aren't already backing off, this call will start + // a new backoff period. Reset the counter to 0 so that + // we backoff only for short periods of time to start with. if s.backoffStarted.CAS(false, true) { s.backoffCount.Store(0) } + + // Check if we have blacklisted this node. + if s.blacklisted.Load() { + return time.Now(), true + } + + // If we're already backing off and we haven't yet surpassed + // the deadline then return that. Repeated calls to Failure + // within a single backoff interval will have no side effects. + if until, ok := s.backoffUntil.Load().(time.Time); ok && !time.Now().After(until) { + return until, false + } + + // We're either backing off and have passed the deadline, or + // we aren't backing off, so work out what the next interval + // will be. + count := s.backoffCount.Load() + until := time.Now().Add(s.duration(count)) + s.backoffUntil.Store(until) + return until, false } // BackoffIfRequired will block for as long as the current @@ -102,11 +133,8 @@ func (s *ServerStatistics) BackoffIfRequired(backingOff atomic.Bool, interrupt < return 0, false } - // Work out how many times we've backed off so far. - count := s.backoffCount.Inc() - duration := time.Second * time.Duration(math.Exp2(float64(count))) - // Work out if we should be blacklisting at this point. + count := s.backoffCount.Inc() if count >= s.statistics.FailuresUntilBlacklist { // We've exceeded the maximum amount of times we're willing // to back off, which is probably in the region of hours by @@ -118,9 +146,14 @@ func (s *ServerStatistics) BackoffIfRequired(backingOff atomic.Bool, interrupt < logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName) } } - return duration, true + return 0, true } + // Work out when we should wait until. + duration := s.duration(count) + until := time.Now().Add(duration) + s.backoffUntil.Store(until) + // Notify the destination queue that we're backing off now. backingOff.Store(true) defer backingOff.Store(false) diff --git a/federationsender/statistics/statistics_test.go b/federationsender/statistics/statistics_test.go index 9050662ec..7e083de68 100644 --- a/federationsender/statistics/statistics_test.go +++ b/federationsender/statistics/statistics_test.go @@ -10,7 +10,7 @@ import ( func TestBackoff(t *testing.T) { stats := Statistics{ - FailuresUntilBlacklist: 5, + FailuresUntilBlacklist: 7, } server := ServerStatistics{ statistics: &stats, @@ -41,10 +41,20 @@ func TestBackoff(t *testing.T) { // Get the duration. duration, blacklist := server.BackoffIfRequired(backingOff, interrupt) + // Register another failure for good measure. This should have no + // side effects since a backoff is already in progress. If it does + // then we'll fail. + until, blacklisted := server.Failure() + if time.Until(until) > duration { + t.Fatal("Failure produced unexpected side effect when it shouldn't have") + } + // Check if we should be blacklisted by now. - if i > stats.FailuresUntilBlacklist { + if i >= stats.FailuresUntilBlacklist { if !blacklist { t.Fatalf("Backoff %d should have resulted in blacklist but didn't", i) + } else if blacklist != blacklisted { + t.Fatalf("BackoffIfRequired and Failure returned different blacklist values") } else { t.Logf("Backoff %d is blacklisted as expected", i) continue