Add transaction to all database tables in roomserver, rename latest events updater to room updater, use room updater for all RS input

This commit is contained in:
Neil Alexander 2022-02-01 12:52:37 +00:00
parent 893aa3b141
commit f2c0bb165e
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
28 changed files with 490 additions and 294 deletions

View file

@ -19,12 +19,14 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"sync" "sync"
"time" "time"
"github.com/Arceliar/phony" "github.com/Arceliar/phony"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
fedapi "github.com/matrix-org/dendrite/federationapi/api" fedapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/acls" "github.com/matrix-org/dendrite/roomserver/acls"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/internal/query"
@ -101,7 +103,7 @@ func (r *Inputer) Start() error {
_ = msg.InProgress() // resets the acknowledgement wait timer _ = msg.InProgress() // resets the acknowledgement wait timer
defer eventsInProgress.Delete(index) defer eventsInProgress.Delete(index)
defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec()
if err := r.processRoomEvent(context.Background(), &inputRoomEvent); err != nil { if err := r.processRoomEventUsingUpdater(context.Background(), roomID, &inputRoomEvent); err != nil {
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
sentry.CaptureException(err) sentry.CaptureException(err)
} }
@ -131,6 +133,28 @@ func (r *Inputer) Start() error {
return err return err
} }
func (r *Inputer) processRoomEventUsingUpdater(
ctx context.Context,
roomID string,
inputRoomEvent *api.InputRoomEvent,
) error {
roomInfo, err := r.DB.RoomInfo(ctx, roomID)
if err != nil {
return fmt.Errorf("r.DB.RoomInfo: %w", err)
}
updater, err := r.DB.GetRoomUpdater(ctx, *roomInfo)
if err != nil {
return fmt.Errorf("r.DB.GetRoomUpdater: %w", err)
}
succeeded := false
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
if err = r.processRoomEvent(ctx, updater, inputRoomEvent); err != nil {
return fmt.Errorf("r.processRoomEvent: %w", err)
}
succeeded = true
return nil
}
// InputRoomEvents implements api.RoomserverInternalAPI // InputRoomEvents implements api.RoomserverInternalAPI
func (r *Inputer) InputRoomEvents( func (r *Inputer) InputRoomEvents(
ctx context.Context, ctx context.Context,
@ -178,7 +202,7 @@ func (r *Inputer) InputRoomEvents(
worker.Act(nil, func() { worker.Act(nil, func() {
defer eventsInProgress.Delete(index) defer eventsInProgress.Delete(index)
defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec()
err := r.processRoomEvent(ctx, &inputRoomEvent) err := r.processRoomEventUsingUpdater(ctx, roomID, &inputRoomEvent)
if err != nil { if err != nil {
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
sentry.CaptureException(err) sentry.CaptureException(err)

View file

@ -29,6 +29,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -67,6 +68,7 @@ var processRoomEventDuration = prometheus.NewHistogramVec(
// nolint:gocyclo // nolint:gocyclo
func (r *Inputer) processRoomEvent( func (r *Inputer) processRoomEvent(
ctx context.Context, ctx context.Context,
updater *shared.RoomUpdater,
input *api.InputRoomEvent, input *api.InputRoomEvent,
) (err error) { ) (err error) {
select { select {
@ -107,7 +109,7 @@ func (r *Inputer) processRoomEvent(
// if we have already got this event then do not process it again, if the input kind is an outlier. // if we have already got this event then do not process it again, if the input kind is an outlier.
// Outliers contain no extra information which may warrant a re-processing. // Outliers contain no extra information which may warrant a re-processing.
if input.Kind == api.KindOutlier { if input.Kind == api.KindOutlier {
evs, err2 := r.DB.EventsFromIDs(ctx, []string{event.EventID()}) evs, err2 := updater.EventsFromIDs(ctx, []string{event.EventID()})
if err2 == nil && len(evs) == 1 { if err2 == nil && len(evs) == 1 {
// check hash matches if we're on early room versions where the event ID was a random string // check hash matches if we're on early room versions where the event ID was a random string
idFormat, err2 := headered.RoomVersion.EventIDFormat() idFormat, err2 := headered.RoomVersion.EventIDFormat()
@ -176,7 +178,7 @@ func (r *Inputer) processRoomEvent(
isRejected := false isRejected := false
authEvents := gomatrixserverlib.NewAuthEvents(nil) authEvents := gomatrixserverlib.NewAuthEvents(nil)
knownEvents := map[string]*types.Event{} knownEvents := map[string]*types.Event{}
if err = r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { if err = r.fetchAuthEvents(ctx, updater, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil {
return fmt.Errorf("r.fetchAuthEvents: %w", err) return fmt.Errorf("r.fetchAuthEvents: %w", err)
} }
@ -227,7 +229,7 @@ func (r *Inputer) processRoomEvent(
origin: input.Origin, origin: input.Origin,
inputer: r, inputer: r,
queryer: r.Queryer, queryer: r.Queryer,
db: r.DB, db: updater,
federation: r.FSAPI, federation: r.FSAPI,
keys: r.KeyRing, keys: r.KeyRing,
roomsMu: internal.NewMutexByRoom(), roomsMu: internal.NewMutexByRoom(),
@ -248,9 +250,9 @@ func (r *Inputer) processRoomEvent(
} }
// Store the event. // Store the event.
_, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected) _, _, stateAtEvent, redactionEvent, redactedEventID, err := updater.StoreEvent(ctx, event, authEventNIDs, isRejected)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.StoreEvent: %w", err) return fmt.Errorf("updater.StoreEvent: %w", err)
} }
// if storing this event results in it being redacted then do so. // if storing this event results in it being redacted then do so.
@ -271,18 +273,18 @@ func (r *Inputer) processRoomEvent(
return nil return nil
} }
roomInfo, err := r.DB.RoomInfo(ctx, event.RoomID()) roomInfo, err := updater.RoomInfo(ctx, event.RoomID())
if err != nil { if err != nil {
return fmt.Errorf("r.DB.RoomInfo: %w", err) return fmt.Errorf("updater.RoomInfo: %w", err)
} }
if roomInfo == nil { if roomInfo == nil {
return fmt.Errorf("r.DB.RoomInfo missing for room %s", event.RoomID()) return fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID())
} }
if !missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0 { if !missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0 {
// We haven't calculated a state for this event yet. // We haven't calculated a state for this event yet.
// Lets calculate one. // Lets calculate one.
err = r.calculateAndSetState(ctx, input, roomInfo, &stateAtEvent, event, isRejected) err = r.calculateAndSetState(ctx, updater, input, roomInfo, &stateAtEvent, event, isRejected)
if err != nil { if err != nil {
return fmt.Errorf("r.calculateAndSetState: %w", err) return fmt.Errorf("r.calculateAndSetState: %w", err)
} }
@ -301,6 +303,7 @@ func (r *Inputer) processRoomEvent(
case api.KindNew: case api.KindNew:
if err = r.updateLatestEvents( if err = r.updateLatestEvents(
ctx, // context ctx, // context
updater, // room updater
roomInfo, // room info for the room being updated roomInfo, // room info for the room being updated
stateAtEvent, // state at event (below) stateAtEvent, // state at event (below)
event, // event event, // event
@ -358,6 +361,7 @@ func (r *Inputer) processRoomEvent(
// they are now in the database. // they are now in the database.
func (r *Inputer) fetchAuthEvents( func (r *Inputer) fetchAuthEvents(
ctx context.Context, ctx context.Context,
updater *shared.RoomUpdater,
logger *logrus.Entry, logger *logrus.Entry,
event *gomatrixserverlib.HeaderedEvent, event *gomatrixserverlib.HeaderedEvent,
auth *gomatrixserverlib.AuthEvents, auth *gomatrixserverlib.AuthEvents,
@ -375,7 +379,7 @@ func (r *Inputer) fetchAuthEvents(
} }
for _, authEventID := range authEventIDs { for _, authEventID := range authEventIDs {
authEvents, err := r.DB.EventsFromIDs(ctx, []string{authEventID}) authEvents, err := updater.EventsFromIDs(ctx, []string{authEventID})
if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil { if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil {
unknown[authEventID] = struct{}{} unknown[authEventID] = struct{}{}
continue continue
@ -454,9 +458,9 @@ func (r *Inputer) fetchAuthEvents(
} }
// Finally, store the event in the database. // Finally, store the event in the database.
eventNID, _, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, authEventNIDs, isRejected) eventNID, _, _, _, _, err := updater.StoreEvent(ctx, authEvent, authEventNIDs, isRejected)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.StoreEvent: %w", err) return fmt.Errorf("updater.StoreEvent: %w", err)
} }
// Now we know about this event, it was stored and the signatures were OK. // Now we know about this event, it was stored and the signatures were OK.
@ -471,6 +475,7 @@ func (r *Inputer) fetchAuthEvents(
func (r *Inputer) calculateAndSetState( func (r *Inputer) calculateAndSetState(
ctx context.Context, ctx context.Context,
updater *shared.RoomUpdater,
input *api.InputRoomEvent, input *api.InputRoomEvent,
roomInfo *types.RoomInfo, roomInfo *types.RoomInfo,
stateAtEvent *types.StateAtEvent, stateAtEvent *types.StateAtEvent,
@ -485,7 +490,7 @@ func (r *Inputer) calculateAndSetState(
stateAtEvent.Overwrite = true stateAtEvent.Overwrite = true
var joinEventNIDs []types.EventNID var joinEventNIDs []types.EventNID
// Request join memberships only for local users only. // Request join memberships only for local users only.
if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil { if joinEventNIDs, err = updater.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil {
// If we have no local users that are joined to the room then any state about // If we have no local users that are joined to the room then any state about
// the room that we have is quite possibly out of date. Therefore in that case // the room that we have is quite possibly out of date. Therefore in that case
// we should overwrite it rather than merge it. // we should overwrite it rather than merge it.
@ -495,13 +500,13 @@ func (r *Inputer) calculateAndSetState(
// We've been told what the state at the event is so we don't need to calculate it. // We've been told what the state at the event is so we don't need to calculate it.
// Check that those state events are in the database and store the state. // Check that those state events are in the database and store the state.
var entries []types.StateEntry var entries []types.StateEntry
if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { if entries, err = updater.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
return fmt.Errorf("r.DB.StateEntriesForEventIDs: %w", err) return fmt.Errorf("updater.StateEntriesForEventIDs: %w", err)
} }
entries = types.DeduplicateStateEntries(entries) entries = types.DeduplicateStateEntries(entries)
if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { if stateAtEvent.BeforeStateSnapshotNID, err = updater.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil {
return fmt.Errorf("r.DB.AddState: %w", err) return fmt.Errorf("updater.AddState: %w", err)
} }
} else { } else {
stateAtEvent.Overwrite = false stateAtEvent.Overwrite = false
@ -512,7 +517,7 @@ func (r *Inputer) calculateAndSetState(
} }
} }
err = r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) err = updater.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.SetState: %w", err) return fmt.Errorf("r.DB.SetState: %w", err)
} }

View file

@ -20,7 +20,6 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
@ -48,6 +47,7 @@ import (
// Can only be called once at a time // Can only be called once at a time
func (r *Inputer) updateLatestEvents( func (r *Inputer) updateLatestEvents(
ctx context.Context, ctx context.Context,
updater *shared.RoomUpdater,
roomInfo *types.RoomInfo, roomInfo *types.RoomInfo,
stateAtEvent types.StateAtEvent, stateAtEvent types.StateAtEvent,
event *gomatrixserverlib.Event, event *gomatrixserverlib.Event,
@ -55,13 +55,6 @@ func (r *Inputer) updateLatestEvents(
transactionID *api.TransactionID, transactionID *api.TransactionID,
rewritesState bool, rewritesState bool,
) (err error) { ) (err error) {
updater, err := r.DB.GetLatestEventsForUpdate(ctx, *roomInfo)
if err != nil {
return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err)
}
succeeded := false
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
u := latestEventsUpdater{ u := latestEventsUpdater{
ctx: ctx, ctx: ctx,
api: r, api: r,
@ -78,7 +71,6 @@ func (r *Inputer) updateLatestEvents(
return fmt.Errorf("u.doUpdateLatestEvents: %w", err) return fmt.Errorf("u.doUpdateLatestEvents: %w", err)
} }
succeeded = true
return return
} }
@ -89,7 +81,7 @@ func (r *Inputer) updateLatestEvents(
type latestEventsUpdater struct { type latestEventsUpdater struct {
ctx context.Context ctx context.Context
api *Inputer api *Inputer
updater *shared.LatestEventsUpdater updater *shared.RoomUpdater
roomInfo *types.RoomInfo roomInfo *types.RoomInfo
stateAtEvent types.StateAtEvent stateAtEvent types.StateAtEvent
event *gomatrixserverlib.Event event *gomatrixserverlib.Event

View file

@ -31,7 +31,7 @@ import (
// consumers about the invites added or retired by the change in current state. // consumers about the invites added or retired by the change in current state.
func (r *Inputer) updateMemberships( func (r *Inputer) updateMemberships(
ctx context.Context, ctx context.Context,
updater *shared.LatestEventsUpdater, updater *shared.RoomUpdater,
removed, added []types.StateEntry, removed, added []types.StateEntry,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
changes := membershipChanges(removed, added) changes := membershipChanges(removed, added)
@ -79,7 +79,7 @@ func (r *Inputer) updateMemberships(
} }
func (r *Inputer) updateMembership( func (r *Inputer) updateMembership(
updater *shared.LatestEventsUpdater, updater *shared.RoomUpdater,
targetUserNID types.EventStateKeyNID, targetUserNID types.EventStateKeyNID,
remove, add *gomatrixserverlib.Event, remove, add *gomatrixserverlib.Event,
updates []api.OutputEvent, updates []api.OutputEvent,

View file

@ -11,7 +11,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/internal/query"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -19,7 +19,8 @@ import (
type missingStateReq struct { type missingStateReq struct {
origin gomatrixserverlib.ServerName origin gomatrixserverlib.ServerName
db storage.Database //db storage.Database
db *shared.RoomUpdater
inputer *Inputer inputer *Inputer
queryer *query.Queryer queryer *query.Queryer
keys gomatrixserverlib.JSONVerifier keys gomatrixserverlib.JSONVerifier
@ -78,7 +79,7 @@ func (t *missingStateReq) processEventWithMissingState(
// we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled // we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled
// in the gap in the DAG // in the gap in the DAG
for _, newEvent := range newEvents { for _, newEvent := range newEvents {
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
Kind: api.KindNew, Kind: api.KindNew,
Event: newEvent.Headered(roomVersion), Event: newEvent.Headered(roomVersion),
Origin: t.origin, Origin: t.origin,
@ -187,7 +188,7 @@ func (t *missingStateReq) processEventWithMissingState(
} }
// TODO: we could do this concurrently? // TODO: we could do this concurrently?
for _, ire := range outlierRoomEvents { for _, ire := range outlierRoomEvents {
if err = t.inputer.processRoomEvent(ctx, &ire); err != nil { if err = t.inputer.processRoomEvent(ctx, t.db, &ire); err != nil {
return fmt.Errorf("t.inputer.processRoomEvent[outlier]: %w", err) return fmt.Errorf("t.inputer.processRoomEvent[outlier]: %w", err)
} }
} }
@ -200,7 +201,7 @@ func (t *missingStateReq) processEventWithMissingState(
stateIDs = append(stateIDs, event.EventID()) stateIDs = append(stateIDs, event.EventID())
} }
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
Kind: api.KindOld, Kind: api.KindOld,
Event: backwardsExtremity.Headered(roomVersion), Event: backwardsExtremity.Headered(roomVersion),
Origin: t.origin, Origin: t.origin,
@ -217,7 +218,7 @@ func (t *missingStateReq) processEventWithMissingState(
// they will automatically fast-forward based on the room state at the // they will automatically fast-forward based on the room state at the
// extremity in the last step. // extremity in the last step.
for _, newEvent := range newEvents { for _, newEvent := range newEvents {
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
Kind: api.KindOld, Kind: api.KindOld,
Event: newEvent.Headered(roomVersion), Event: newEvent.Headered(roomVersion),
Origin: t.origin, Origin: t.origin,

View file

@ -86,11 +86,10 @@ type Database interface {
// Lookup the event IDs for a batch of event numeric IDs. // Lookup the event IDs for a batch of event numeric IDs.
// Returns an error if the retrieval went wrong. // Returns an error if the retrieval went wrong.
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
// Look up the latest events in a room in preparation for an update. // Opens and returns a room updater, which locks the room and opens a transaction.
// The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error. // The GetRoomUpdater must have Commit or Rollback called on it if this doesn't return an error.
// Returns the latest events in the room and the last eventID sent to the log along with an updater.
// If this returns an error then no further action is required. // If this returns an error then no further action is required.
GetLatestEventsForUpdate(ctx context.Context, roomInfo types.RoomInfo) (*shared.LatestEventsUpdater, error) GetRoomUpdater(ctx context.Context, roomInfo types.RoomInfo) (*shared.RoomUpdater, error)
// Look up event references for the latest events in the room and the current state snapshot. // Look up event references for the latest events in the room and the current state snapshot.
// Returns the latest events, the current state and the maximum depth of the latest events plus 1. // Returns the latest events, the current state and the maximum depth of the latest events plus 1.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.

View file

@ -81,9 +81,10 @@ func (s *eventJSONStatements) InsertEventJSON(
} }
func (s *eventJSONStatements) BulkSelectEventJSON( func (s *eventJSONStatements) BulkSelectEventJSON(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]tables.EventJSONPair, error) { ) ([]tables.EventJSONPair, error) {
rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) stmt := sqlutil.TxStmt(txn, s.bulkSelectEventJSONStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -212,9 +212,10 @@ func (s *eventStatements) SelectEvent(
// bulkSelectStateEventByID lookups a list of state events by event ID. // bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError // If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByID( func (s *eventStatements) BulkSelectStateEventByID(
ctx context.Context, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByIDStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -254,13 +255,14 @@ func (s *eventStatements) BulkSelectStateEventByID(
// bulkSelectStateEventByNID lookups a list of state events by event NID. // bulkSelectStateEventByNID lookups a list of state events by event NID.
// If any of the requested events are missing from the database it returns a types.MissingEventError // If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByNID( func (s *eventStatements) BulkSelectStateEventByNID(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
stateKeyTuples []types.StateKeyTuple, stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
tuples := stateKeyTupleSorter(stateKeyTuples) tuples := stateKeyTupleSorter(stateKeyTuples)
sort.Sort(tuples) sort.Sort(tuples)
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
rows, err := s.bulkSelectStateEventByNIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray) stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByNIDStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -291,9 +293,10 @@ func (s *eventStatements) BulkSelectStateEventByNID(
// If any of the requested events are missing from the database it returns a types.MissingEventError. // If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError.
func (s *eventStatements) BulkSelectStateAtEventByID( func (s *eventStatements) BulkSelectStateAtEventByID(
ctx context.Context, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateAtEvent, error) { ) ([]types.StateAtEvent, error) {
rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) stmt := sqlutil.TxStmt(txn, s.bulkSelectStateAtEventByIDStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -428,8 +431,9 @@ func (s *eventStatements) BulkSelectEventReference(
} }
// bulkSelectEventID returns a map from numeric event ID to string event ID. // bulkSelectEventID returns a map from numeric event ID to string event ID.
func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) stmt := sqlutil.TxStmt(txn, s.bulkSelectEventIDStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -455,8 +459,9 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map. // If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) stmt := sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -484,9 +489,10 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
} }
func (s *eventStatements) SelectRoomNIDsForEventNIDs( func (s *eventStatements) SelectRoomNIDsForEventNIDs(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) (map[types.EventNID]types.RoomNID, error) { ) (map[types.EventNID]types.RoomNID, error) {
rows, err := s.selectRoomNIDsForEventNIDsStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) stmt := sqlutil.TxStmt(txn, s.selectRoomNIDsForEventNIDsStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -97,8 +97,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) {
} }
func (s *inviteStatements) InsertInviteEvent( func (s *inviteStatements) InsertInviteEvent(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, inviteEventID string, roomNID types.RoomNID,
targetUserNID, senderUserNID types.EventStateKeyNID, targetUserNID, senderUserNID types.EventStateKeyNID,
inviteEventJSON []byte, inviteEventJSON []byte,
) (bool, error) { ) (bool, error) {
@ -116,8 +116,8 @@ func (s *inviteStatements) InsertInviteEvent(
} }
func (s *inviteStatements) UpdateInviteRetired( func (s *inviteStatements) UpdateInviteRetired(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) ([]string, error) { ) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.updateInviteRetiredStmt) stmt := sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
@ -139,10 +139,11 @@ func (s *inviteStatements) UpdateInviteRetired(
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs // SelectInviteActiveForUserInRoom returns a list of sender state key NIDs
func (s *inviteStatements) SelectInviteActiveForUserInRoom( func (s *inviteStatements) SelectInviteActiveForUserInRoom(
ctx context.Context, ctx context.Context, txn *sql.Tx,
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
) ([]types.EventStateKeyNID, []string, error) { ) ([]types.EventStateKeyNID, []string, error) {
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext( stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt)
rows, err := stmt.QueryContext(
ctx, targetUserNID, roomNID, ctx, targetUserNID, roomNID,
) )
if err != nil { if err != nil {

View file

@ -186,8 +186,8 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) {
} }
func (s *membershipStatements) InsertMembership( func (s *membershipStatements) InsertMembership(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
localTarget bool, localTarget bool,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt) stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
@ -196,8 +196,8 @@ func (s *membershipStatements) InsertMembership(
} }
func (s *membershipStatements) SelectMembershipForUpdate( func (s *membershipStatements) SelectMembershipForUpdate(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (membership tables.MembershipState, err error) { ) (membership tables.MembershipState, err error) {
err = sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext( err = sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext(
ctx, roomNID, targetUserNID, ctx, roomNID, targetUserNID,
@ -206,17 +206,19 @@ func (s *membershipStatements) SelectMembershipForUpdate(
} }
func (s *membershipStatements) SelectMembershipFromRoomAndTarget( func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
ctx context.Context, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) { ) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt)
err = stmt.QueryRowContext(
ctx, roomNID, targetUserNID, ctx, roomNID, targetUserNID,
).Scan(&membership, &eventNID, &forgotten) ).Scan(&membership, &eventNID, &forgotten)
return return
} }
func (s *membershipStatements) SelectMembershipsFromRoom( func (s *membershipStatements) SelectMembershipsFromRoom(
ctx context.Context, roomNID types.RoomNID, localOnly bool, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
var stmt *sql.Stmt var stmt *sql.Stmt
if localOnly { if localOnly {
@ -224,6 +226,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
} else { } else {
stmt = s.selectMembershipsFromRoomStmt stmt = s.selectMembershipsFromRoomStmt
} }
stmt = sqlutil.TxStmt(txn, stmt)
rows, err := stmt.QueryContext(ctx, roomNID) rows, err := stmt.QueryContext(ctx, roomNID)
if err != nil { if err != nil {
return return
@ -241,7 +244,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
} }
func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
ctx context.Context, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
var rows *sql.Rows var rows *sql.Rows
@ -251,6 +254,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
} else { } else {
stmt = s.selectMembershipsFromRoomAndMembershipStmt stmt = s.selectMembershipsFromRoomAndMembershipStmt
} }
stmt = sqlutil.TxStmt(txn, stmt)
rows, err = stmt.QueryContext(ctx, roomNID, membership) rows, err = stmt.QueryContext(ctx, roomNID, membership)
if err != nil { if err != nil {
return return
@ -268,8 +272,8 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
} }
func (s *membershipStatements) UpdateMembership( func (s *membershipStatements) UpdateMembership(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
eventNID types.EventNID, forgotten bool, eventNID types.EventNID, forgotten bool,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext( _, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext(
@ -279,9 +283,11 @@ func (s *membershipStatements) UpdateMembership(
} }
func (s *membershipStatements) SelectRoomsWithMembership( func (s *membershipStatements) SelectRoomsWithMembership(
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, ctx context.Context, txn *sql.Tx,
userID types.EventStateKeyNID, membershipState tables.MembershipState,
) ([]types.RoomNID, error) { ) ([]types.RoomNID, error) {
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
rows, err := stmt.QueryContext(ctx, membershipState, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -297,12 +303,16 @@ func (s *membershipStatements) SelectRoomsWithMembership(
return roomNIDs, nil return roomNIDs, nil
} }
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { func (s *membershipStatements) SelectJoinedUsersSetForRooms(
ctx context.Context, txn *sql.Tx,
roomNIDs []types.RoomNID,
) (map[types.EventStateKeyNID]int, error) {
roomIDarray := make([]int64, len(roomNIDs)) roomIDarray := make([]int64, len(roomNIDs))
for i := range roomNIDs { for i := range roomNIDs {
roomIDarray[i] = int64(roomNIDs[i]) roomIDarray[i] = int64(roomNIDs[i])
} }
rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.Int64Array(roomIDarray)) stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -319,8 +329,12 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
return result, rows.Err() return result, rows.Err()
} }
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { func (s *membershipStatements) SelectKnownUsers(
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) ctx context.Context, txn *sql.Tx,
userID types.EventStateKeyNID, searchString string, limit int,
) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt)
rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -337,9 +351,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type
} }
func (s *membershipStatements) UpdateForgetMembership( func (s *membershipStatements) UpdateForgetMembership(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool,
forget bool,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext( _, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
ctx, roomNID, targetUserNID, forget, ctx, roomNID, targetUserNID, forget,
@ -347,9 +360,13 @@ func (s *membershipStatements) UpdateForgetMembership(
return err return err
} }
func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { func (s *membershipStatements) SelectLocalServerInRoom(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID,
) (bool, error) {
var nid types.RoomNID var nid types.RoomNID
err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt)
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return false, nil return false, nil
@ -360,9 +377,13 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room
return found, nil return found, nil
} }
func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { func (s *membershipStatements) SelectServerInRoom(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, serverName gomatrixserverlib.ServerName,
) (bool, error) {
var nid types.RoomNID var nid types.RoomNID
err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt)
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return false, nil return false, nil

View file

@ -73,9 +73,10 @@ func (s *publishedStatements) UpsertRoomPublished(
} }
func (s *publishedStatements) SelectPublishedFromRoomID( func (s *publishedStatements) SelectPublishedFromRoomID(
ctx context.Context, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (published bool, err error) { ) (published bool, err error) {
err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published) stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt)
err = stmt.QueryRowContext(ctx, roomID).Scan(&published)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return false, nil return false, nil
} }
@ -83,9 +84,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID(
} }
func (s *publishedStatements) SelectAllPublishedRooms( func (s *publishedStatements) SelectAllPublishedRooms(
ctx context.Context, published bool, ctx context.Context, txn *sql.Tx, published bool,
) ([]string, error) { ) ([]string, error) {
rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published) stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt)
rows, err := stmt.QueryContext(ctx, published)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -87,9 +87,10 @@ func (s *roomAliasesStatements) InsertRoomAlias(
} }
func (s *roomAliasesStatements) SelectRoomIDFromAlias( func (s *roomAliasesStatements) SelectRoomIDFromAlias(
ctx context.Context, alias string, ctx context.Context, txn *sql.Tx, alias string,
) (roomID string, err error) { ) (roomID string, err error) {
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt)
err = stmt.QueryRowContext(ctx, alias).Scan(&roomID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", nil return "", nil
} }
@ -97,9 +98,10 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias(
} }
func (s *roomAliasesStatements) SelectAliasesFromRoomID( func (s *roomAliasesStatements) SelectAliasesFromRoomID(
ctx context.Context, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) ([]string, error) { ) ([]string, error) {
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt)
rows, err := stmt.QueryContext(ctx, roomID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -118,9 +120,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
} }
func (s *roomAliasesStatements) SelectCreatorIDFromAlias( func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
ctx context.Context, alias string, ctx context.Context, txn *sql.Tx, alias string,
) (creatorID string, err error) { ) (creatorID string, err error) {
err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt)
err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", nil return "", nil
} }

View file

@ -117,8 +117,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
}.Prepare(db) }.Prepare(db)
} }
func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
rows, err := s.selectRoomIDsStmt.QueryContext(ctx) stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -143,10 +144,11 @@ func (s *roomStatements) InsertRoomNID(
return types.RoomNID(roomNID), err return types.RoomNID(roomNID), err
} }
func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
var info types.RoomInfo var info types.RoomInfo
var latestNIDs pq.Int64Array var latestNIDs pq.Int64Array
err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan( stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(
&info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDs, &info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDs,
) )
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -170,7 +172,7 @@ func (s *roomStatements) SelectLatestEventNIDs(
) ([]types.EventNID, types.StateSnapshotNID, error) { ) ([]types.EventNID, types.StateSnapshotNID, error) {
var nids pq.Int64Array var nids pq.Int64Array
var stateSnapshotNID int64 var stateSnapshotNID int64
stmt := s.selectLatestEventNIDsStmt stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsStmt)
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID) err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
@ -220,9 +222,10 @@ func (s *roomStatements) UpdateLatestEventNIDs(
} }
func (s *roomStatements) SelectRoomVersionsForRoomNIDs( func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
ctx context.Context, roomNIDs []types.RoomNID, ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID,
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) { ) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
rows, err := s.selectRoomVersionsForRoomNIDsStmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs)) stmt := sqlutil.TxStmt(txn, s.selectRoomVersionsForRoomNIDsStmt)
rows, err := stmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -239,12 +242,13 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
return result, nil return result, nil
} }
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) {
var array pq.Int64Array var array pq.Int64Array
for _, nid := range roomNIDs { for _, nid := range roomNIDs {
array = append(array, int64(nid)) array = append(array, int64(nid))
} }
rows, err := s.bulkSelectRoomIDsStmt.QueryContext(ctx, array) stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx, array)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -260,12 +264,13 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types
return roomIDs, nil return roomIDs, nil
} }
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) { func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) {
var array pq.StringArray var array pq.StringArray
for _, roomID := range roomIDs { for _, roomID := range roomIDs {
array = append(array, roomID) array = append(array, roomID)
} }
rows, err := s.bulkSelectRoomNIDsStmt.QueryContext(ctx, array) stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomNIDsStmt)
rows, err := stmt.QueryContext(ctx, array)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -86,8 +86,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
} }
func (s *stateBlockStatements) BulkInsertStateData( func (s *stateBlockStatements) BulkInsertStateData(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx,
entries types.StateEntries, entries types.StateEntries,
) (id types.StateBlockNID, err error) { ) (id types.StateBlockNID, err error) {
entries = entries[:util.SortAndUnique(entries)] entries = entries[:util.SortAndUnique(entries)]
@ -95,16 +94,18 @@ func (s *stateBlockStatements) BulkInsertStateData(
for _, e := range entries { for _, e := range entries {
nids = append(nids, e.EventNID) nids = append(nids, e.EventNID)
} }
err = s.insertStateDataStmt.QueryRowContext( stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt)
err = stmt.QueryRowContext(
ctx, nids.Hash(), eventNIDsAsArray(nids), ctx, nids.Hash(), eventNIDsAsArray(nids),
).Scan(&id) ).Scan(&id)
return return
} }
func (s *stateBlockStatements) BulkSelectStateBlockEntries( func (s *stateBlockStatements) BulkSelectStateBlockEntries(
ctx context.Context, stateBlockNIDs types.StateBlockNIDs, ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs,
) ([][]types.EventNID, error) { ) ([][]types.EventNID, error) {
rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs)) stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockEntriesStmt)
rows, err := stmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -105,13 +105,14 @@ func (s *stateSnapshotStatements) InsertState(
} }
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID, ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) { ) ([]types.StateBlockNIDList, error) {
nids := make([]int64, len(stateNIDs)) nids := make([]int64, len(stateNIDs))
for i := range stateNIDs { for i := range stateNIDs {
nids[i] = int64(stateNIDs[i]) nids[i] = int64(stateNIDs[i])
} }
rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, pq.Int64Array(nids)) stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockNIDsStmt)
rows, err := stmt.QueryContext(ctx, pq.Int64Array(nids))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -9,7 +9,7 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
type LatestEventsUpdater struct { type RoomUpdater struct {
transaction transaction
d *Database d *Database
roomInfo types.RoomInfo roomInfo types.RoomInfo
@ -25,7 +25,7 @@ func rollback(txn *sql.Tx) {
txn.Rollback() // nolint: errcheck txn.Rollback() // nolint: errcheck
} }
func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) { func NewRoomUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo types.RoomInfo) (*RoomUpdater, error) {
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID) d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID)
if err != nil { if err != nil {
@ -45,33 +45,33 @@ func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomI
return nil, err return nil, err
} }
} }
return &LatestEventsUpdater{ return &RoomUpdater{
transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
}, nil }, nil
} }
// RoomVersion implements types.RoomRecentEventsUpdater // RoomVersion implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) { func (u *RoomUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) {
return u.roomInfo.RoomVersion return u.roomInfo.RoomVersion
} }
// LatestEvents implements types.RoomRecentEventsUpdater // LatestEvents implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) LatestEvents() []types.StateAtEventAndReference { func (u *RoomUpdater) LatestEvents() []types.StateAtEventAndReference {
return u.latestEvents return u.latestEvents
} }
// LastEventIDSent implements types.RoomRecentEventsUpdater // LastEventIDSent implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) LastEventIDSent() string { func (u *RoomUpdater) LastEventIDSent() string {
return u.lastEventIDSent return u.lastEventIDSent
} }
// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater // CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
return u.currentStateSnapshotNID return u.currentStateSnapshotNID
} }
// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer // StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer
func (u *LatestEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
for _, ref := range previousEventReferences { for _, ref := range previousEventReferences {
if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err)
@ -80,8 +80,58 @@ func (u *LatestEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previ
return nil return nil
} }
func (u *RoomUpdater) StoreEvent(
ctx context.Context, event *gomatrixserverlib.Event,
authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
return u.d.storeEvent(ctx, u.txn, event, authEventNIDs, isRejected)
}
func (u *RoomUpdater) AddState(
ctx context.Context,
roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) {
return u.d.addState(ctx, u.txn, roomNID, stateBlockNIDs, state)
}
func (u *RoomUpdater) SetState(
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error {
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
return u.d.EventsTable.UpdateEventState(ctx, txn, eventNID, stateNID)
})
}
func (u *RoomUpdater) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
return u.d.roomInfo(ctx, u.txn, roomID)
}
func (u *RoomUpdater) StateAtEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) {
return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs)
}
func (u *RoomUpdater) StateEntriesForEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateEntry, error) {
return u.d.EventsTable.BulkSelectStateEventByID(ctx, u.txn, eventIDs)
}
func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, eventIDs)
}
func (u *RoomUpdater) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) {
return u.d.getMembershipEventNIDsForRoom(ctx, u.txn, roomNID, joinOnly, localOnly)
}
// IsReferenced implements types.RoomRecentEventsUpdater // IsReferenced implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { func (u *RoomUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
if err == nil { if err == nil {
return true, nil return true, nil
@ -93,7 +143,7 @@ func (u *LatestEventsUpdater) IsReferenced(eventReference gomatrixserverlib.Even
} }
// SetLatestEvents implements types.RoomRecentEventsUpdater // SetLatestEvents implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) SetLatestEvents( func (u *RoomUpdater) SetLatestEvents(
roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
currentStateSnapshotNID types.StateSnapshotNID, currentStateSnapshotNID types.StateSnapshotNID,
) error { ) error {
@ -117,17 +167,17 @@ func (u *LatestEventsUpdater) SetLatestEvents(
} }
// HasEventBeenSent implements types.RoomRecentEventsUpdater // HasEventBeenSent implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { func (u *RoomUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID) return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID)
} }
// MarkEventAsSent implements types.RoomRecentEventsUpdater // MarkEventAsSent implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error {
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID) return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID)
}) })
} }
func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
} }

View file

@ -42,7 +42,7 @@ type Database struct {
MembershipTable tables.Membership MembershipTable tables.Membership
PublishedTable tables.Published PublishedTable tables.Published
RedactionsTable tables.Redactions RedactionsTable tables.Redactions
GetLatestEventsForUpdateFn func(ctx context.Context, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) GetRoomUpdaterFn func(ctx context.Context, roomInfo types.RoomInfo) (*RoomUpdater, error)
} }
func (d *Database) SupportsConcurrentRoomInputs() bool { func (d *Database) SupportsConcurrentRoomInputs() bool {
@ -108,7 +108,7 @@ func (d *Database) EventStateKeyNIDs(
func (d *Database) StateEntriesForEventIDs( func (d *Database) StateEntriesForEventIDs(
ctx context.Context, eventIDs []string, ctx context.Context, eventIDs []string,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs) return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs)
} }
func (d *Database) StateEntriesForTuples( func (d *Database) StateEntriesForTuples(
@ -117,14 +117,14 @@ func (d *Database) StateEntriesForTuples(
stateKeyTuples []types.StateKeyTuple, stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) { ) ([]types.StateEntryList, error) {
entries, err := d.StateBlockTable.BulkSelectStateBlockEntries( entries, err := d.StateBlockTable.BulkSelectStateBlockEntries(
ctx, stateBlockNIDs, ctx, nil, stateBlockNIDs,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err) return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
} }
lists := []types.StateEntryList{} lists := []types.StateEntryList{}
for i, entry := range entries { for i, entry := range entries {
entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, stateKeyTuples) entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, nil, entry, stateKeyTuples)
if err != nil { if err != nil {
return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
} }
@ -137,10 +137,14 @@ func (d *Database) StateEntriesForTuples(
} }
func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
return d.roomInfo(ctx, nil, roomID)
}
func (d *Database) roomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok { if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
return &roomInfo, nil return &roomInfo, nil
} }
roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, roomID) roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, txn, roomID)
if err == nil && roomInfo != nil { if err == nil && roomInfo != nil {
d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID) d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID)
d.Cache.StoreRoomInfo(roomID, *roomInfo) d.Cache.StoreRoomInfo(roomID, *roomInfo)
@ -153,13 +157,22 @@ func (d *Database) AddState(
roomNID types.RoomNID, roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID, stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry, state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) {
return d.addState(ctx, nil, roomNID, stateBlockNIDs, state)
}
func (d *Database) addState(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) { ) (stateNID types.StateSnapshotNID, err error) {
if len(stateBlockNIDs) > 0 && len(state) > 0 { if len(stateBlockNIDs) > 0 && len(state) > 0 {
// Check to see if the event already appears in any of the existing state // Check to see if the event already appears in any of the existing state
// blocks. If it does then we should not add it again, as this will just // blocks. If it does then we should not add it again, as this will just
// result in excess state blocks and snapshots. // result in excess state blocks and snapshots.
// TODO: Investigate why this is happening - probably input_events.go! // TODO: Investigate why this is happening - probably input_events.go!
blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs) blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs)
if berr != nil { if berr != nil {
return 0, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", berr) return 0, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", berr)
} }
@ -180,7 +193,7 @@ func (d *Database) AddState(
} }
} }
} }
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
if len(state) > 0 { if len(state) > 0 {
// If there's any state left to add then let's add new blocks. // If there's any state left to add then let's add new blocks.
var stateBlockNID types.StateBlockNID var stateBlockNID types.StateBlockNID
@ -205,7 +218,13 @@ func (d *Database) AddState(
func (d *Database) EventNIDs( func (d *Database) EventNIDs(
ctx context.Context, eventIDs []string, ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) { ) (map[string]types.EventNID, error) {
return d.EventsTable.BulkSelectEventNID(ctx, eventIDs) return d.eventNIDs(ctx, nil, eventIDs)
}
func (d *Database) eventNIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) (map[string]types.EventNID, error) {
return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs)
} }
func (d *Database) SetState( func (d *Database) SetState(
@ -219,7 +238,7 @@ func (d *Database) SetState(
func (d *Database) StateAtEventIDs( func (d *Database) StateAtEventIDs(
ctx context.Context, eventIDs []string, ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) { ) ([]types.StateAtEvent, error) {
return d.EventsTable.BulkSelectStateAtEventByID(ctx, eventIDs) return d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, eventIDs)
} }
func (d *Database) SnapshotNIDFromEventID( func (d *Database) SnapshotNIDFromEventID(
@ -232,11 +251,15 @@ func (d *Database) SnapshotNIDFromEventID(
func (d *Database) EventIDs( func (d *Database) EventIDs(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]string, error) { ) (map[types.EventNID]string, error) {
return d.EventsTable.BulkSelectEventID(ctx, eventNIDs) return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
} }
func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
nidMap, err := d.EventNIDs(ctx, eventIDs) return d.eventsFromIDs(ctx, nil, eventIDs)
}
func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.Event, error) {
nidMap, err := d.eventNIDs(ctx, txn, eventIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -246,7 +269,7 @@ func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]type
nids = append(nids, nid) nids = append(nids, nid)
} }
return d.Events(ctx, nids) return d.events(ctx, txn, nids)
} }
func (d *Database) LatestEventIDs( func (d *Database) LatestEventIDs(
@ -271,21 +294,21 @@ func (d *Database) LatestEventIDs(
func (d *Database) StateBlockNIDs( func (d *Database) StateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID, ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) { ) ([]types.StateBlockNIDList, error) {
return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, stateNIDs) return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, nil, stateNIDs)
} }
func (d *Database) StateEntries( func (d *Database) StateEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID, ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) { ) ([]types.StateEntryList, error) {
entries, err := d.StateBlockTable.BulkSelectStateBlockEntries( entries, err := d.StateBlockTable.BulkSelectStateBlockEntries(
ctx, stateBlockNIDs, ctx, nil, stateBlockNIDs,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err) return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
} }
lists := make([]types.StateEntryList, 0, len(entries)) lists := make([]types.StateEntryList, 0, len(entries))
for i, entry := range entries { for i, entry := range entries {
eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, nil) eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, nil, entry, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
} }
@ -304,17 +327,17 @@ func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string
} }
func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) {
return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, alias) return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, nil, alias)
} }
func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) {
return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, roomID) return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, nil, roomID)
} }
func (d *Database) GetCreatorIDForAlias( func (d *Database) GetCreatorIDForAlias(
ctx context.Context, alias string, ctx context.Context, alias string,
) (string, error) { ) (string, error) {
return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, alias) return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, nil, alias)
} }
func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
@ -335,7 +358,7 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req
senderMembershipEventNID, senderMembership, isRoomforgotten, err := senderMembershipEventNID, senderMembership, isRoomforgotten, err :=
d.MembershipTable.SelectMembershipFromRoomAndTarget( d.MembershipTable.SelectMembershipFromRoomAndTarget(
ctx, roomNID, requestSenderUserNID, ctx, nil, roomNID, requestSenderUserNID,
) )
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// The user has never been a member of that room // The user has never been a member of that room
@ -349,14 +372,20 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req
func (d *Database) GetMembershipEventNIDsForRoom( func (d *Database) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) {
return d.getMembershipEventNIDsForRoom(ctx, nil, roomNID, joinOnly, localOnly)
}
func (d *Database) getMembershipEventNIDsForRoom(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) { ) ([]types.EventNID, error) {
if joinOnly { if joinOnly {
return d.MembershipTable.SelectMembershipsFromRoomAndMembership( return d.MembershipTable.SelectMembershipsFromRoomAndMembership(
ctx, roomNID, tables.MembershipStateJoin, localOnly, ctx, txn, roomNID, tables.MembershipStateJoin, localOnly,
) )
} }
return d.MembershipTable.SelectMembershipsFromRoom(ctx, roomNID, localOnly) return d.MembershipTable.SelectMembershipsFromRoom(ctx, txn, roomNID, localOnly)
} }
func (d *Database) GetInvitesForUser( func (d *Database) GetInvitesForUser(
@ -364,22 +393,28 @@ func (d *Database) GetInvitesForUser(
roomNID types.RoomNID, roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID, targetUserNID types.EventStateKeyNID,
) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error) { ) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error) {
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
} }
func (d *Database) Events( func (d *Database) Events(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, eventNIDs []types.EventNID,
) ([]types.Event, error) { ) ([]types.Event, error) {
eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs) return d.events(ctx, nil, eventNIDs)
}
func (d *Database) events(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]types.Event, error) {
eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, txn, eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs) eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
if err != nil { if err != nil {
eventIDs = map[types.EventNID]string{} eventIDs = map[types.EventNID]string{}
} }
var roomNIDs map[types.EventNID]types.RoomNID var roomNIDs map[types.EventNID]types.RoomNID
roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, eventNIDs) roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, txn, eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -398,7 +433,7 @@ func (d *Database) Events(
} }
fetchNIDList = append(fetchNIDList, n) fetchNIDList = append(fetchNIDList, n)
} }
dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, fetchNIDList) dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, fetchNIDList)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -440,19 +475,19 @@ func (d *Database) MembershipUpdater(
return updater, err return updater, err
} }
func (d *Database) GetLatestEventsForUpdate( func (d *Database) GetRoomUpdater(
ctx context.Context, roomInfo types.RoomInfo, ctx context.Context, roomInfo types.RoomInfo,
) (*LatestEventsUpdater, error) { ) (*RoomUpdater, error) {
if d.GetLatestEventsForUpdateFn != nil { if d.GetRoomUpdaterFn != nil {
return d.GetLatestEventsForUpdateFn(ctx, roomInfo) return d.GetRoomUpdaterFn(ctx, roomInfo)
} }
txn, err := d.DB.Begin() txn, err := d.DB.Begin()
if err != nil { if err != nil {
return nil, err return nil, err
} }
var updater *LatestEventsUpdater var updater *RoomUpdater
_ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { _ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
updater, err = NewLatestEventsUpdater(ctx, d, txn, roomInfo) updater, err = NewRoomUpdater(ctx, d, txn, roomInfo)
return err return err
}) })
return updater, err return updater, err
@ -461,6 +496,13 @@ func (d *Database) GetLatestEventsForUpdate(
func (d *Database) StoreEvent( func (d *Database) StoreEvent(
ctx context.Context, event *gomatrixserverlib.Event, ctx context.Context, event *gomatrixserverlib.Event,
authEventNIDs []types.EventNID, isRejected bool, authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
return d.storeEvent(ctx, nil, event, authEventNIDs, isRejected)
}
func (d *Database) storeEvent(
ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.Event,
authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { ) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
var ( var (
roomNID types.RoomNID roomNID types.RoomNID
@ -473,7 +515,7 @@ func (d *Database) StoreEvent(
err error err error
) )
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
// TODO: Here we should aim to have two different code paths for new rooms // TODO: Here we should aim to have two different code paths for new rooms
// vs existing ones. // vs existing ones.
@ -547,7 +589,7 @@ func (d *Database) StoreEvent(
// that there's a row-level lock on the latest room events (well, // that there's a row-level lock on the latest room events (well,
// on Postgres at least). // on Postgres at least).
var roomInfo *types.RoomInfo var roomInfo *types.RoomInfo
var updater *LatestEventsUpdater var updater *RoomUpdater
if prevEvents := event.PrevEvents(); len(prevEvents) > 0 { if prevEvents := event.PrevEvents(); len(prevEvents) > 0 {
roomInfo, err = d.RoomInfo(ctx, event.RoomID()) roomInfo, err = d.RoomInfo(ctx, event.RoomID())
if err != nil { if err != nil {
@ -561,7 +603,7 @@ func (d *Database) StoreEvent(
// function only does SELECTs though so the created txn (at this point) is just a read txn like // function only does SELECTs though so the created txn (at this point) is just a read txn like
// any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater // any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater
// to do writes however then this will need to go inside `Writer.Do`. // to do writes however then this will need to go inside `Writer.Do`.
updater, err = d.GetLatestEventsForUpdate(ctx, *roomInfo) updater, err = d.GetRoomUpdater(ctx, *roomInfo)
if err != nil { if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("NewLatestEventsUpdater: %w", err) return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("NewLatestEventsUpdater: %w", err)
} }
@ -603,7 +645,7 @@ func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool)
} }
func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) { func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) {
return d.PublishedTable.SelectAllPublishedRooms(ctx, true) return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true)
} }
func (d *Database) assignRoomNID( func (d *Database) assignRoomNID(
@ -875,14 +917,14 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
eventNIDs = append(eventNIDs, e.EventNID) eventNIDs = append(eventNIDs, e.EventNID)
} }
} }
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs) eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
if err != nil { if err != nil {
eventIDs = map[types.EventNID]string{} eventIDs = map[types.EventNID]string{}
} }
// return the event requested // return the event requested
for _, e := range entries { for _, e := range entries {
if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID { if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID {
data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, []types.EventNID{e.EventNID}) data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, []types.EventNID{e.EventNID})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -922,11 +964,11 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
} }
return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err) return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err)
} }
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, stateKeyNID, membershipState) roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNID, membershipState)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err) return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err)
} }
roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, roomNIDs) roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, roomNIDs)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to lookup room nids: %w", err) return nil, fmt.Errorf("GetRoomsByMembership: failed to lookup room nids: %w", err)
} }
@ -999,11 +1041,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
} }
} }
} }
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs) eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
if err != nil { if err != nil {
eventIDs = map[types.EventNID]string{} eventIDs = map[types.EventNID]string{}
} }
events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs) events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event nids: %w", err) return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event nids: %w", err)
} }
@ -1027,11 +1069,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, roomIDs) roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, roomNIDs) userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1057,12 +1099,12 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
return d.MembershipTable.SelectLocalServerInRoom(ctx, roomNID) return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID)
} }
// GetServerInRoom returns true if we think a server is in a given room or false otherwise. // GetServerInRoom returns true if we think a server is in a given room or false otherwise.
func (d *Database) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { func (d *Database) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
return d.MembershipTable.SelectServerInRoom(ctx, roomNID, serverName) return d.MembershipTable.SelectServerInRoom(ctx, nil, roomNID, serverName)
} }
// GetKnownUsers searches all users that userID knows about. // GetKnownUsers searches all users that userID knows about.
@ -1071,17 +1113,17 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin
if err != nil { if err != nil {
return nil, err return nil, err
} }
return d.MembershipTable.SelectKnownUsers(ctx, stateKeyNID, searchString, limit) return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit)
} }
// GetKnownRooms returns a list of all rooms we know about. // GetKnownRooms returns a list of all rooms we know about.
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
return d.RoomsTable.SelectRoomIDs(ctx) return d.RoomsTable.SelectRoomIDs(ctx, nil)
} }
// ForgetRoom sets a users room to forgotten // ForgetRoom sets a users room to forgotten
func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error { func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error {
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, []string{roomID}) roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, []string{roomID})
if err != nil { if err != nil {
return err return err
} }

View file

@ -76,15 +76,19 @@ func (s *eventJSONStatements) InsertEventJSON(
} }
func (s *eventJSONStatements) BulkSelectEventJSON( func (s *eventJSONStatements) BulkSelectEventJSON(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]tables.EventJSONPair, error) { ) ([]tables.EventJSONPair, error) {
iEventNIDs := make([]interface{}, len(eventNIDs)) iEventNIDs := make([]interface{}, len(eventNIDs))
for k, v := range eventNIDs { for k, v := range eventNIDs {
iEventNIDs[k] = v iEventNIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
selectStmt, err := s.db.Prepare(selectOrig)
rows, err := s.db.QueryContext(ctx, selectOrig, iEventNIDs...) if err != nil {
return nil, err
}
selectStmt = sqlutil.TxStmt(txn, selectStmt)
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -184,7 +184,7 @@ func (s *eventStatements) SelectEvent(
// bulkSelectStateEventByID lookups a list of state events by event ID. // bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError // If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByID( func (s *eventStatements) BulkSelectStateEventByID(
ctx context.Context, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
/////////////// ///////////////
iEventIDs := make([]interface{}, len(eventIDs)) iEventIDs := make([]interface{}, len(eventIDs))
@ -196,6 +196,7 @@ func (s *eventStatements) BulkSelectStateEventByID(
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
/////////////// ///////////////
rows, err := selectStmt.QueryContext(ctx, iEventIDs...) rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
@ -235,7 +236,7 @@ func (s *eventStatements) BulkSelectStateEventByID(
// bulkSelectStateEventByID lookups a list of state events by event ID. // bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError // If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByNID( func (s *eventStatements) BulkSelectStateEventByNID(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
stateKeyTuples []types.StateKeyTuple, stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
tuples := stateKeyTupleSorter(stateKeyTuples) tuples := stateKeyTupleSorter(stateKeyTuples)
@ -263,6 +264,7 @@ func (s *eventStatements) BulkSelectStateEventByNID(
if err != nil { if err != nil {
return nil, fmt.Errorf("s.db.Prepare: %w", err) return nil, fmt.Errorf("s.db.Prepare: %w", err)
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
rows, err := selectStmt.QueryContext(ctx, params...) rows, err := selectStmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
return nil, fmt.Errorf("selectStmt.QueryContext: %w", err) return nil, fmt.Errorf("selectStmt.QueryContext: %w", err)
@ -291,7 +293,7 @@ func (s *eventStatements) BulkSelectStateEventByNID(
// If any of the requested events are missing from the database it returns a types.MissingEventError. // If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError.
func (s *eventStatements) BulkSelectStateAtEventByID( func (s *eventStatements) BulkSelectStateAtEventByID(
ctx context.Context, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateAtEvent, error) { ) ([]types.StateAtEvent, error) {
/////////////// ///////////////
iEventIDs := make([]interface{}, len(eventIDs)) iEventIDs := make([]interface{}, len(eventIDs))
@ -303,6 +305,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
/////////////// ///////////////
rows, err := selectStmt.QueryContext(ctx, iEventIDs...) rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
if err != nil { if err != nil {
@ -381,6 +384,7 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference(
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectPrep = sqlutil.TxStmt(txn, selectPrep)
////////////// //////////////
rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...) rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...)
@ -454,7 +458,7 @@ func (s *eventStatements) BulkSelectEventReference(
} }
// bulkSelectEventID returns a map from numeric event ID to string event ID. // bulkSelectEventID returns a map from numeric event ID to string event ID.
func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
/////////////// ///////////////
iEventNIDs := make([]interface{}, len(eventNIDs)) iEventNIDs := make([]interface{}, len(eventNIDs))
for k, v := range eventNIDs { for k, v := range eventNIDs {
@ -465,6 +469,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
/////////////// ///////////////
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
@ -490,7 +495,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map. // If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
/////////////// ///////////////
iEventIDs := make([]interface{}, len(eventIDs)) iEventIDs := make([]interface{}, len(eventIDs))
for k, v := range eventIDs { for k, v := range eventIDs {
@ -501,6 +506,7 @@ func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []str
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
/////////////// ///////////////
rows, err := selectStmt.QueryContext(ctx, iEventIDs...) rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
if err != nil { if err != nil {
@ -538,13 +544,14 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
} }
func (s *eventStatements) SelectRoomNIDsForEventNIDs( func (s *eventStatements) SelectRoomNIDsForEventNIDs(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) (map[types.EventNID]types.RoomNID, error) { ) (map[types.EventNID]types.RoomNID, error) {
sqlStr := strings.Replace(selectRoomNIDsForEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1) sqlStr := strings.Replace(selectRoomNIDsForEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
sqlPrep, err := s.db.Prepare(sqlStr) sqlPrep, err := s.db.Prepare(sqlStr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
sqlPrep = sqlutil.TxStmt(txn, sqlPrep)
iEventNIDs := make([]interface{}, len(eventNIDs)) iEventNIDs := make([]interface{}, len(eventNIDs))
for i, v := range eventNIDs { for i, v := range eventNIDs {
iEventNIDs[i] = v iEventNIDs[i] = v

View file

@ -88,8 +88,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) {
} }
func (s *inviteStatements) InsertInviteEvent( func (s *inviteStatements) InsertInviteEvent(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, inviteEventID string, roomNID types.RoomNID,
targetUserNID, senderUserNID types.EventStateKeyNID, targetUserNID, senderUserNID types.EventStateKeyNID,
inviteEventJSON []byte, inviteEventJSON []byte,
) (bool, error) { ) (bool, error) {
@ -109,8 +109,8 @@ func (s *inviteStatements) InsertInviteEvent(
} }
func (s *inviteStatements) UpdateInviteRetired( func (s *inviteStatements) UpdateInviteRetired(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventIDs []string, err error) { ) (eventIDs []string, err error) {
// gather all the event IDs we will retire // gather all the event IDs we will retire
stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt) stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
@ -134,10 +134,11 @@ func (s *inviteStatements) UpdateInviteRetired(
// selectInviteActiveForUserInRoom returns a list of sender state key NIDs // selectInviteActiveForUserInRoom returns a list of sender state key NIDs
func (s *inviteStatements) SelectInviteActiveForUserInRoom( func (s *inviteStatements) SelectInviteActiveForUserInRoom(
ctx context.Context, ctx context.Context, txn *sql.Tx,
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
) ([]types.EventStateKeyNID, []string, error) { ) ([]types.EventStateKeyNID, []string, error) {
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext( stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt)
rows, err := stmt.QueryContext(
ctx, targetUserNID, roomNID, ctx, targetUserNID, roomNID,
) )
if err != nil { if err != nil {

View file

@ -184,17 +184,18 @@ func (s *membershipStatements) SelectMembershipForUpdate(
} }
func (s *membershipStatements) SelectMembershipFromRoomAndTarget( func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
ctx context.Context, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) { ) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt)
err = stmt.QueryRowContext(
ctx, roomNID, targetUserNID, ctx, roomNID, targetUserNID,
).Scan(&membership, &eventNID, &forgotten) ).Scan(&membership, &eventNID, &forgotten)
return return
} }
func (s *membershipStatements) SelectMembershipsFromRoom( func (s *membershipStatements) SelectMembershipsFromRoom(
ctx context.Context, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, localOnly bool, roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
var selectStmt *sql.Stmt var selectStmt *sql.Stmt
@ -203,6 +204,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
} else { } else {
selectStmt = s.selectMembershipsFromRoomStmt selectStmt = s.selectMembershipsFromRoomStmt
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
rows, err := selectStmt.QueryContext(ctx, roomNID) rows, err := selectStmt.QueryContext(ctx, roomNID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -220,7 +222,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
} }
func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
ctx context.Context, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
var stmt *sql.Stmt var stmt *sql.Stmt
@ -229,6 +231,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
} else { } else {
stmt = s.selectMembershipsFromRoomAndMembershipStmt stmt = s.selectMembershipsFromRoomAndMembershipStmt
} }
stmt = sqlutil.TxStmt(txn, stmt)
rows, err := stmt.QueryContext(ctx, roomNID, membership) rows, err := stmt.QueryContext(ctx, roomNID, membership)
if err != nil { if err != nil {
return return
@ -258,9 +261,10 @@ func (s *membershipStatements) UpdateMembership(
} }
func (s *membershipStatements) SelectRoomsWithMembership( func (s *membershipStatements) SelectRoomsWithMembership(
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState tables.MembershipState,
) ([]types.RoomNID, error) { ) ([]types.RoomNID, error) {
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
rows, err := stmt.QueryContext(ctx, membershipState, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -276,13 +280,17 @@ func (s *membershipStatements) SelectRoomsWithMembership(
return roomNIDs, nil return roomNIDs, nil
} }
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
iRoomNIDs := make([]interface{}, len(roomNIDs)) iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs { for i, v := range roomNIDs {
iRoomNIDs[i] = v iRoomNIDs[i] = v
} }
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1) query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1)
rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...) stmt, err := s.db.Prepare(query)
if err != nil {
return nil, err
}
rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, iRoomNIDs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -299,8 +307,9 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
return result, rows.Err() return result, rows.Err()
} }
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { func (s *membershipStatements) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt)
rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -317,8 +326,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type
} }
func (s *membershipStatements) UpdateForgetMembership( func (s *membershipStatements) UpdateForgetMembership(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
forget bool, forget bool,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext( _, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
@ -327,9 +336,10 @@ func (s *membershipStatements) UpdateForgetMembership(
return err return err
} }
func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) {
var nid types.RoomNID var nid types.RoomNID
err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt)
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return false, nil return false, nil
@ -340,9 +350,10 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room
return found, nil return found, nil
} }
func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { func (s *membershipStatements) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
var nid types.RoomNID var nid types.RoomNID
err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt)
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return false, nil return false, nil

View file

@ -75,9 +75,10 @@ func (s *publishedStatements) UpsertRoomPublished(
} }
func (s *publishedStatements) SelectPublishedFromRoomID( func (s *publishedStatements) SelectPublishedFromRoomID(
ctx context.Context, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (published bool, err error) { ) (published bool, err error) {
err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published) stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt)
err = stmt.QueryRowContext(ctx, roomID).Scan(&published)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return false, nil return false, nil
} }
@ -85,9 +86,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID(
} }
func (s *publishedStatements) SelectAllPublishedRooms( func (s *publishedStatements) SelectAllPublishedRooms(
ctx context.Context, published bool, ctx context.Context, txn *sql.Tx, published bool,
) ([]string, error) { ) ([]string, error) {
rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published) stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt)
rows, err := stmt.QueryContext(ctx, published)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -91,9 +91,10 @@ func (s *roomAliasesStatements) InsertRoomAlias(
} }
func (s *roomAliasesStatements) SelectRoomIDFromAlias( func (s *roomAliasesStatements) SelectRoomIDFromAlias(
ctx context.Context, alias string, ctx context.Context, txn *sql.Tx, alias string,
) (roomID string, err error) { ) (roomID string, err error) {
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt)
err = stmt.QueryRowContext(ctx, alias).Scan(&roomID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", nil return "", nil
} }
@ -101,10 +102,11 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias(
} }
func (s *roomAliasesStatements) SelectAliasesFromRoomID( func (s *roomAliasesStatements) SelectAliasesFromRoomID(
ctx context.Context, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (aliases []string, err error) { ) (aliases []string, err error) {
aliases = []string{} aliases = []string{}
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt)
rows, err := stmt.QueryContext(ctx, roomID)
if err != nil { if err != nil {
return return
} }
@ -124,9 +126,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
} }
func (s *roomAliasesStatements) SelectCreatorIDFromAlias( func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
ctx context.Context, alias string, ctx context.Context, txn *sql.Tx, alias string,
) (creatorID string, err error) { ) (creatorID string, err error) {
err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt)
err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", nil return "", nil
} }

View file

@ -107,8 +107,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
}.Prepare(db) }.Prepare(db)
} }
func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
rows, err := s.selectRoomIDsStmt.QueryContext(ctx) stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -124,10 +125,11 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
return roomIDs, nil return roomIDs, nil
} }
func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
var info types.RoomInfo var info types.RoomInfo
var latestNIDsJSON string var latestNIDsJSON string
err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan( stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(
&info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDsJSON, &info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDsJSON,
) )
if err != nil { if err != nil {
@ -224,13 +226,14 @@ func (s *roomStatements) UpdateLatestEventNIDs(
} }
func (s *roomStatements) SelectRoomVersionsForRoomNIDs( func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
ctx context.Context, roomNIDs []types.RoomNID, ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID,
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) { ) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
sqlPrep, err := s.db.Prepare(sqlStr) sqlPrep, err := s.db.Prepare(sqlStr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
sqlPrep = sqlutil.TxStmt(txn, sqlPrep)
iRoomNIDs := make([]interface{}, len(roomNIDs)) iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs { for i, v := range roomNIDs {
iRoomNIDs[i] = v iRoomNIDs[i] = v
@ -252,13 +255,18 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
return result, nil return result, nil
} }
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) {
iRoomNIDs := make([]interface{}, len(roomNIDs)) iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs { for i, v := range roomNIDs {
iRoomNIDs[i] = v iRoomNIDs[i] = v
} }
sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...) sqlPrep, err := s.db.Prepare(sqlQuery)
if err != nil {
return nil, err
}
stmt := sqlutil.TxStmt(txn, sqlPrep)
rows, err := stmt.QueryContext(ctx, iRoomNIDs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -274,13 +282,18 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types
return roomIDs, nil return roomIDs, nil
} }
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) { func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) {
iRoomIDs := make([]interface{}, len(roomIDs)) iRoomIDs := make([]interface{}, len(roomIDs))
for i, v := range roomIDs { for i, v := range roomIDs {
iRoomIDs[i] = v iRoomIDs[i] = v
} }
sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1) sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1)
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...) sqlPrep, err := s.db.Prepare(sqlQuery)
if err != nil {
return nil, err
}
stmt := sqlutil.TxStmt(txn, sqlPrep)
rows, err := stmt.QueryContext(ctx, iRoomIDs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -81,8 +81,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
} }
func (s *stateBlockStatements) BulkInsertStateData( func (s *stateBlockStatements) BulkInsertStateData(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx,
entries types.StateEntries, entries types.StateEntries,
) (id types.StateBlockNID, err error) { ) (id types.StateBlockNID, err error) {
entries = entries[:util.SortAndUnique(entries)] entries = entries[:util.SortAndUnique(entries)]
@ -101,7 +100,7 @@ func (s *stateBlockStatements) BulkInsertStateData(
} }
func (s *stateBlockStatements) BulkSelectStateBlockEntries( func (s *stateBlockStatements) BulkSelectStateBlockEntries(
ctx context.Context, stateBlockNIDs types.StateBlockNIDs, ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs,
) ([][]types.EventNID, error) { ) ([][]types.EventNID, error) {
intfs := make([]interface{}, len(stateBlockNIDs)) intfs := make([]interface{}, len(stateBlockNIDs))
for i := range stateBlockNIDs { for i := range stateBlockNIDs {
@ -112,6 +111,7 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
rows, err := selectStmt.QueryContext(ctx, intfs...) rows, err := selectStmt.QueryContext(ctx, intfs...)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -106,7 +106,7 @@ func (s *stateSnapshotStatements) InsertState(
} }
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID, ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) { ) ([]types.StateBlockNIDList, error) {
nids := make([]interface{}, len(stateNIDs)) nids := make([]interface{}, len(stateNIDs))
for k, v := range stateNIDs { for k, v := range stateNIDs {
@ -117,6 +117,7 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
rows, err := selectStmt.QueryContext(ctx, nids...) rows, err := selectStmt.QueryContext(ctx, nids...)
if err != nil { if err != nil {

View file

@ -188,7 +188,7 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error {
MembershipTable: membership, MembershipTable: membership,
PublishedTable: published, PublishedTable: published,
RedactionsTable: redactions, RedactionsTable: redactions,
GetLatestEventsForUpdateFn: d.GetLatestEventsForUpdate, GetRoomUpdaterFn: d.GetRoomUpdater,
} }
return nil return nil
} }
@ -201,16 +201,16 @@ func (d *Database) SupportsConcurrentRoomInputs() bool {
return false return false
} }
func (d *Database) GetLatestEventsForUpdate( func (d *Database) GetRoomUpdater(
ctx context.Context, roomInfo types.RoomInfo, ctx context.Context, roomInfo types.RoomInfo,
) (*shared.LatestEventsUpdater, error) { ) (*shared.RoomUpdater, error) {
// TODO: Do not use transactions. We should be holding open this transaction but we cannot have // TODO: Do not use transactions. We should be holding open this transaction but we cannot have
// multiple write transactions on sqlite. The code will perform additional // multiple write transactions on sqlite. The code will perform additional
// write transactions independent of this one which will consistently cause // write transactions independent of this one which will consistently cause
// 'database is locked' errors. As sqlite doesn't support multi-process on the // 'database is locked' errors. As sqlite doesn't support multi-process on the
// same DB anyway, and we only execute updates sequentially, the only worries // same DB anyway, and we only execute updates sequentially, the only worries
// are for rolling back when things go wrong. (atomicity) // are for rolling back when things go wrong. (atomicity)
return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomInfo) return shared.NewRoomUpdater(ctx, &d.Database, nil, roomInfo)
} }
func (d *Database) MembershipUpdater( func (d *Database) MembershipUpdater(

View file

@ -18,7 +18,7 @@ type EventJSONPair struct {
type EventJSON interface { type EventJSON interface {
// Insert the event JSON. On conflict, replace the event JSON with the new value (for redactions). // Insert the event JSON. On conflict, replace the event JSON with the new value (for redactions).
InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error
BulkSelectEventJSON(ctx context.Context, eventNIDs []types.EventNID) ([]EventJSONPair, error) BulkSelectEventJSON(ctx context.Context, tx *sql.Tx, eventNIDs []types.EventNID) ([]EventJSONPair, error)
} }
type EventTypes interface { type EventTypes interface {
@ -42,12 +42,12 @@ type Events interface {
SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error) SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error)
// bulkSelectStateEventByID lookups a list of state events by event ID. // bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError // If any of the requested events are missing from the database it returns a types.MissingEventError
BulkSelectStateEventByID(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) BulkSelectStateEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateEntry, error)
BulkSelectStateEventByNID(ctx context.Context, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error) BulkSelectStateEventByNID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error)
// BulkSelectStateAtEventByID lookups the state at a list of events by event ID. // BulkSelectStateAtEventByID lookups the state at a list of events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError. // If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError.
BulkSelectStateAtEventByID(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) BulkSelectStateAtEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateAtEvent, error)
UpdateEventState(ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID) error UpdateEventState(ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID) error
SelectEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) SelectEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error)
UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error
@ -55,12 +55,12 @@ type Events interface {
BulkSelectStateAtEventAndReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) BulkSelectStateAtEventAndReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error)
BulkSelectEventReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error) BulkSelectEventReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error)
// BulkSelectEventID returns a map from numeric event ID to string event ID. // BulkSelectEventID returns a map from numeric event ID to string event ID.
BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map. // If an event ID is not in the database then it is omitted from the map.
BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error)
SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error)
SelectRoomNIDsForEventNIDs(ctx context.Context, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error) SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
} }
type Rooms interface { type Rooms interface {
@ -69,29 +69,29 @@ type Rooms interface {
SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error)
SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error)
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
SelectRoomVersionsForRoomNIDs(ctx context.Context, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) SelectRoomVersionsForRoomNIDs(ctx context.Context, txn *sql.Tx, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error)
SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error)
SelectRoomIDs(ctx context.Context) ([]string, error) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error)
BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error)
BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error)
} }
type StateSnapshot interface { type StateSnapshot interface {
InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error) InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error)
BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) BulkSelectStateBlockNIDs(ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
} }
type StateBlock interface { type StateBlock interface {
BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries types.StateEntries) (types.StateBlockNID, error) BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries types.StateEntries) (types.StateBlockNID, error)
BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error) BulkSelectStateBlockEntries(ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error)
//BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) //BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
} }
type RoomAliases interface { type RoomAliases interface {
InsertRoomAlias(ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string) (err error) InsertRoomAlias(ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string) (err error)
SelectRoomIDFromAlias(ctx context.Context, alias string) (roomID string, err error) SelectRoomIDFromAlias(ctx context.Context, txn *sql.Tx, alias string) (roomID string, err error)
SelectAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error) SelectAliasesFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) ([]string, error)
SelectCreatorIDFromAlias(ctx context.Context, alias string) (creatorID string, err error) SelectCreatorIDFromAlias(ctx context.Context, txn *sql.Tx, alias string) (creatorID string, err error)
DeleteRoomAlias(ctx context.Context, txn *sql.Tx, alias string) (err error) DeleteRoomAlias(ctx context.Context, txn *sql.Tx, alias string) (err error)
} }
@ -106,7 +106,7 @@ type Invites interface {
InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error) InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error)
UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error) UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error)
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs and invite event IDs matching those nids. // SelectInviteActiveForUserInRoom returns a list of sender state key NIDs and invite event IDs matching those nids.
SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error) SelectInviteActiveForUserInRoom(ctx context.Context, txn *sql.Tx, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error)
} }
type MembershipState int64 type MembershipState int64
@ -121,24 +121,24 @@ const (
type Membership interface { type Membership interface {
InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error
SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error) SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error)
SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error) SelectMembershipFromRoomAndTarget(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error)
SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) SelectMembershipsFromRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error
SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
// counts of how many rooms they are joined. // counts of how many rooms they are joined.
SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error)
} }
type Published interface { type Published interface {
UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID string, published bool) (err error) UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID string, published bool) (err error)
SelectPublishedFromRoomID(ctx context.Context, roomID string) (published bool, err error) SelectPublishedFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) (published bool, err error)
SelectAllPublishedRooms(ctx context.Context, published bool) ([]string, error) SelectAllPublishedRooms(ctx context.Context, txn *sql.Tx, published bool) ([]string, error)
} }
type RedactionInfo struct { type RedactionInfo struct {