Add transaction to all database tables in roomserver, rename latest events updater to room updater, use room updater for all RS input

This commit is contained in:
Neil Alexander 2022-02-01 12:52:37 +00:00
parent 893aa3b141
commit f2c0bb165e
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
28 changed files with 490 additions and 294 deletions

View file

@ -19,12 +19,14 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"sync" "sync"
"time" "time"
"github.com/Arceliar/phony" "github.com/Arceliar/phony"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
fedapi "github.com/matrix-org/dendrite/federationapi/api" fedapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/acls" "github.com/matrix-org/dendrite/roomserver/acls"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/internal/query"
@ -101,7 +103,7 @@ func (r *Inputer) Start() error {
_ = msg.InProgress() // resets the acknowledgement wait timer _ = msg.InProgress() // resets the acknowledgement wait timer
defer eventsInProgress.Delete(index) defer eventsInProgress.Delete(index)
defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec()
if err := r.processRoomEvent(context.Background(), &inputRoomEvent); err != nil { if err := r.processRoomEventUsingUpdater(context.Background(), roomID, &inputRoomEvent); err != nil {
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
sentry.CaptureException(err) sentry.CaptureException(err)
} }
@ -131,6 +133,28 @@ func (r *Inputer) Start() error {
return err return err
} }
func (r *Inputer) processRoomEventUsingUpdater(
ctx context.Context,
roomID string,
inputRoomEvent *api.InputRoomEvent,
) error {
roomInfo, err := r.DB.RoomInfo(ctx, roomID)
if err != nil {
return fmt.Errorf("r.DB.RoomInfo: %w", err)
}
updater, err := r.DB.GetRoomUpdater(ctx, *roomInfo)
if err != nil {
return fmt.Errorf("r.DB.GetRoomUpdater: %w", err)
}
succeeded := false
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
if err = r.processRoomEvent(ctx, updater, inputRoomEvent); err != nil {
return fmt.Errorf("r.processRoomEvent: %w", err)
}
succeeded = true
return nil
}
// InputRoomEvents implements api.RoomserverInternalAPI // InputRoomEvents implements api.RoomserverInternalAPI
func (r *Inputer) InputRoomEvents( func (r *Inputer) InputRoomEvents(
ctx context.Context, ctx context.Context,
@ -178,7 +202,7 @@ func (r *Inputer) InputRoomEvents(
worker.Act(nil, func() { worker.Act(nil, func() {
defer eventsInProgress.Delete(index) defer eventsInProgress.Delete(index)
defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec()
err := r.processRoomEvent(ctx, &inputRoomEvent) err := r.processRoomEventUsingUpdater(ctx, roomID, &inputRoomEvent)
if err != nil { if err != nil {
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
sentry.CaptureException(err) sentry.CaptureException(err)

View file

@ -29,6 +29,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -67,6 +68,7 @@ var processRoomEventDuration = prometheus.NewHistogramVec(
// nolint:gocyclo // nolint:gocyclo
func (r *Inputer) processRoomEvent( func (r *Inputer) processRoomEvent(
ctx context.Context, ctx context.Context,
updater *shared.RoomUpdater,
input *api.InputRoomEvent, input *api.InputRoomEvent,
) (err error) { ) (err error) {
select { select {
@ -107,7 +109,7 @@ func (r *Inputer) processRoomEvent(
// if we have already got this event then do not process it again, if the input kind is an outlier. // if we have already got this event then do not process it again, if the input kind is an outlier.
// Outliers contain no extra information which may warrant a re-processing. // Outliers contain no extra information which may warrant a re-processing.
if input.Kind == api.KindOutlier { if input.Kind == api.KindOutlier {
evs, err2 := r.DB.EventsFromIDs(ctx, []string{event.EventID()}) evs, err2 := updater.EventsFromIDs(ctx, []string{event.EventID()})
if err2 == nil && len(evs) == 1 { if err2 == nil && len(evs) == 1 {
// check hash matches if we're on early room versions where the event ID was a random string // check hash matches if we're on early room versions where the event ID was a random string
idFormat, err2 := headered.RoomVersion.EventIDFormat() idFormat, err2 := headered.RoomVersion.EventIDFormat()
@ -176,7 +178,7 @@ func (r *Inputer) processRoomEvent(
isRejected := false isRejected := false
authEvents := gomatrixserverlib.NewAuthEvents(nil) authEvents := gomatrixserverlib.NewAuthEvents(nil)
knownEvents := map[string]*types.Event{} knownEvents := map[string]*types.Event{}
if err = r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { if err = r.fetchAuthEvents(ctx, updater, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil {
return fmt.Errorf("r.fetchAuthEvents: %w", err) return fmt.Errorf("r.fetchAuthEvents: %w", err)
} }
@ -227,7 +229,7 @@ func (r *Inputer) processRoomEvent(
origin: input.Origin, origin: input.Origin,
inputer: r, inputer: r,
queryer: r.Queryer, queryer: r.Queryer,
db: r.DB, db: updater,
federation: r.FSAPI, federation: r.FSAPI,
keys: r.KeyRing, keys: r.KeyRing,
roomsMu: internal.NewMutexByRoom(), roomsMu: internal.NewMutexByRoom(),
@ -248,9 +250,9 @@ func (r *Inputer) processRoomEvent(
} }
// Store the event. // Store the event.
_, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected) _, _, stateAtEvent, redactionEvent, redactedEventID, err := updater.StoreEvent(ctx, event, authEventNIDs, isRejected)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.StoreEvent: %w", err) return fmt.Errorf("updater.StoreEvent: %w", err)
} }
// if storing this event results in it being redacted then do so. // if storing this event results in it being redacted then do so.
@ -271,18 +273,18 @@ func (r *Inputer) processRoomEvent(
return nil return nil
} }
roomInfo, err := r.DB.RoomInfo(ctx, event.RoomID()) roomInfo, err := updater.RoomInfo(ctx, event.RoomID())
if err != nil { if err != nil {
return fmt.Errorf("r.DB.RoomInfo: %w", err) return fmt.Errorf("updater.RoomInfo: %w", err)
} }
if roomInfo == nil { if roomInfo == nil {
return fmt.Errorf("r.DB.RoomInfo missing for room %s", event.RoomID()) return fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID())
} }
if !missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0 { if !missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0 {
// We haven't calculated a state for this event yet. // We haven't calculated a state for this event yet.
// Lets calculate one. // Lets calculate one.
err = r.calculateAndSetState(ctx, input, roomInfo, &stateAtEvent, event, isRejected) err = r.calculateAndSetState(ctx, updater, input, roomInfo, &stateAtEvent, event, isRejected)
if err != nil { if err != nil {
return fmt.Errorf("r.calculateAndSetState: %w", err) return fmt.Errorf("r.calculateAndSetState: %w", err)
} }
@ -301,6 +303,7 @@ func (r *Inputer) processRoomEvent(
case api.KindNew: case api.KindNew:
if err = r.updateLatestEvents( if err = r.updateLatestEvents(
ctx, // context ctx, // context
updater, // room updater
roomInfo, // room info for the room being updated roomInfo, // room info for the room being updated
stateAtEvent, // state at event (below) stateAtEvent, // state at event (below)
event, // event event, // event
@ -358,6 +361,7 @@ func (r *Inputer) processRoomEvent(
// they are now in the database. // they are now in the database.
func (r *Inputer) fetchAuthEvents( func (r *Inputer) fetchAuthEvents(
ctx context.Context, ctx context.Context,
updater *shared.RoomUpdater,
logger *logrus.Entry, logger *logrus.Entry,
event *gomatrixserverlib.HeaderedEvent, event *gomatrixserverlib.HeaderedEvent,
auth *gomatrixserverlib.AuthEvents, auth *gomatrixserverlib.AuthEvents,
@ -375,7 +379,7 @@ func (r *Inputer) fetchAuthEvents(
} }
for _, authEventID := range authEventIDs { for _, authEventID := range authEventIDs {
authEvents, err := r.DB.EventsFromIDs(ctx, []string{authEventID}) authEvents, err := updater.EventsFromIDs(ctx, []string{authEventID})
if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil { if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil {
unknown[authEventID] = struct{}{} unknown[authEventID] = struct{}{}
continue continue
@ -454,9 +458,9 @@ func (r *Inputer) fetchAuthEvents(
} }
// Finally, store the event in the database. // Finally, store the event in the database.
eventNID, _, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, authEventNIDs, isRejected) eventNID, _, _, _, _, err := updater.StoreEvent(ctx, authEvent, authEventNIDs, isRejected)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.StoreEvent: %w", err) return fmt.Errorf("updater.StoreEvent: %w", err)
} }
// Now we know about this event, it was stored and the signatures were OK. // Now we know about this event, it was stored and the signatures were OK.
@ -471,6 +475,7 @@ func (r *Inputer) fetchAuthEvents(
func (r *Inputer) calculateAndSetState( func (r *Inputer) calculateAndSetState(
ctx context.Context, ctx context.Context,
updater *shared.RoomUpdater,
input *api.InputRoomEvent, input *api.InputRoomEvent,
roomInfo *types.RoomInfo, roomInfo *types.RoomInfo,
stateAtEvent *types.StateAtEvent, stateAtEvent *types.StateAtEvent,
@ -485,7 +490,7 @@ func (r *Inputer) calculateAndSetState(
stateAtEvent.Overwrite = true stateAtEvent.Overwrite = true
var joinEventNIDs []types.EventNID var joinEventNIDs []types.EventNID
// Request join memberships only for local users only. // Request join memberships only for local users only.
if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil { if joinEventNIDs, err = updater.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil {
// If we have no local users that are joined to the room then any state about // If we have no local users that are joined to the room then any state about
// the room that we have is quite possibly out of date. Therefore in that case // the room that we have is quite possibly out of date. Therefore in that case
// we should overwrite it rather than merge it. // we should overwrite it rather than merge it.
@ -495,13 +500,13 @@ func (r *Inputer) calculateAndSetState(
// We've been told what the state at the event is so we don't need to calculate it. // We've been told what the state at the event is so we don't need to calculate it.
// Check that those state events are in the database and store the state. // Check that those state events are in the database and store the state.
var entries []types.StateEntry var entries []types.StateEntry
if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { if entries, err = updater.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
return fmt.Errorf("r.DB.StateEntriesForEventIDs: %w", err) return fmt.Errorf("updater.StateEntriesForEventIDs: %w", err)
} }
entries = types.DeduplicateStateEntries(entries) entries = types.DeduplicateStateEntries(entries)
if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { if stateAtEvent.BeforeStateSnapshotNID, err = updater.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil {
return fmt.Errorf("r.DB.AddState: %w", err) return fmt.Errorf("updater.AddState: %w", err)
} }
} else { } else {
stateAtEvent.Overwrite = false stateAtEvent.Overwrite = false
@ -512,7 +517,7 @@ func (r *Inputer) calculateAndSetState(
} }
} }
err = r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) err = updater.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.SetState: %w", err) return fmt.Errorf("r.DB.SetState: %w", err)
} }

View file

@ -20,7 +20,6 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
@ -48,6 +47,7 @@ import (
// Can only be called once at a time // Can only be called once at a time
func (r *Inputer) updateLatestEvents( func (r *Inputer) updateLatestEvents(
ctx context.Context, ctx context.Context,
updater *shared.RoomUpdater,
roomInfo *types.RoomInfo, roomInfo *types.RoomInfo,
stateAtEvent types.StateAtEvent, stateAtEvent types.StateAtEvent,
event *gomatrixserverlib.Event, event *gomatrixserverlib.Event,
@ -55,13 +55,6 @@ func (r *Inputer) updateLatestEvents(
transactionID *api.TransactionID, transactionID *api.TransactionID,
rewritesState bool, rewritesState bool,
) (err error) { ) (err error) {
updater, err := r.DB.GetLatestEventsForUpdate(ctx, *roomInfo)
if err != nil {
return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err)
}
succeeded := false
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
u := latestEventsUpdater{ u := latestEventsUpdater{
ctx: ctx, ctx: ctx,
api: r, api: r,
@ -78,7 +71,6 @@ func (r *Inputer) updateLatestEvents(
return fmt.Errorf("u.doUpdateLatestEvents: %w", err) return fmt.Errorf("u.doUpdateLatestEvents: %w", err)
} }
succeeded = true
return return
} }
@ -89,7 +81,7 @@ func (r *Inputer) updateLatestEvents(
type latestEventsUpdater struct { type latestEventsUpdater struct {
ctx context.Context ctx context.Context
api *Inputer api *Inputer
updater *shared.LatestEventsUpdater updater *shared.RoomUpdater
roomInfo *types.RoomInfo roomInfo *types.RoomInfo
stateAtEvent types.StateAtEvent stateAtEvent types.StateAtEvent
event *gomatrixserverlib.Event event *gomatrixserverlib.Event

View file

@ -31,7 +31,7 @@ import (
// consumers about the invites added or retired by the change in current state. // consumers about the invites added or retired by the change in current state.
func (r *Inputer) updateMemberships( func (r *Inputer) updateMemberships(
ctx context.Context, ctx context.Context,
updater *shared.LatestEventsUpdater, updater *shared.RoomUpdater,
removed, added []types.StateEntry, removed, added []types.StateEntry,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
changes := membershipChanges(removed, added) changes := membershipChanges(removed, added)
@ -79,7 +79,7 @@ func (r *Inputer) updateMemberships(
} }
func (r *Inputer) updateMembership( func (r *Inputer) updateMembership(
updater *shared.LatestEventsUpdater, updater *shared.RoomUpdater,
targetUserNID types.EventStateKeyNID, targetUserNID types.EventStateKeyNID,
remove, add *gomatrixserverlib.Event, remove, add *gomatrixserverlib.Event,
updates []api.OutputEvent, updates []api.OutputEvent,

View file

@ -11,15 +11,16 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/internal/query"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
type missingStateReq struct { type missingStateReq struct {
origin gomatrixserverlib.ServerName origin gomatrixserverlib.ServerName
db storage.Database //db storage.Database
db *shared.RoomUpdater
inputer *Inputer inputer *Inputer
queryer *query.Queryer queryer *query.Queryer
keys gomatrixserverlib.JSONVerifier keys gomatrixserverlib.JSONVerifier
@ -78,7 +79,7 @@ func (t *missingStateReq) processEventWithMissingState(
// we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled // we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled
// in the gap in the DAG // in the gap in the DAG
for _, newEvent := range newEvents { for _, newEvent := range newEvents {
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
Kind: api.KindNew, Kind: api.KindNew,
Event: newEvent.Headered(roomVersion), Event: newEvent.Headered(roomVersion),
Origin: t.origin, Origin: t.origin,
@ -187,7 +188,7 @@ func (t *missingStateReq) processEventWithMissingState(
} }
// TODO: we could do this concurrently? // TODO: we could do this concurrently?
for _, ire := range outlierRoomEvents { for _, ire := range outlierRoomEvents {
if err = t.inputer.processRoomEvent(ctx, &ire); err != nil { if err = t.inputer.processRoomEvent(ctx, t.db, &ire); err != nil {
return fmt.Errorf("t.inputer.processRoomEvent[outlier]: %w", err) return fmt.Errorf("t.inputer.processRoomEvent[outlier]: %w", err)
} }
} }
@ -200,7 +201,7 @@ func (t *missingStateReq) processEventWithMissingState(
stateIDs = append(stateIDs, event.EventID()) stateIDs = append(stateIDs, event.EventID())
} }
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
Kind: api.KindOld, Kind: api.KindOld,
Event: backwardsExtremity.Headered(roomVersion), Event: backwardsExtremity.Headered(roomVersion),
Origin: t.origin, Origin: t.origin,
@ -217,7 +218,7 @@ func (t *missingStateReq) processEventWithMissingState(
// they will automatically fast-forward based on the room state at the // they will automatically fast-forward based on the room state at the
// extremity in the last step. // extremity in the last step.
for _, newEvent := range newEvents { for _, newEvent := range newEvents {
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
Kind: api.KindOld, Kind: api.KindOld,
Event: newEvent.Headered(roomVersion), Event: newEvent.Headered(roomVersion),
Origin: t.origin, Origin: t.origin,

View file

@ -86,11 +86,10 @@ type Database interface {
// Lookup the event IDs for a batch of event numeric IDs. // Lookup the event IDs for a batch of event numeric IDs.
// Returns an error if the retrieval went wrong. // Returns an error if the retrieval went wrong.
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
// Look up the latest events in a room in preparation for an update. // Opens and returns a room updater, which locks the room and opens a transaction.
// The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error. // The GetRoomUpdater must have Commit or Rollback called on it if this doesn't return an error.
// Returns the latest events in the room and the last eventID sent to the log along with an updater.
// If this returns an error then no further action is required. // If this returns an error then no further action is required.
GetLatestEventsForUpdate(ctx context.Context, roomInfo types.RoomInfo) (*shared.LatestEventsUpdater, error) GetRoomUpdater(ctx context.Context, roomInfo types.RoomInfo) (*shared.RoomUpdater, error)
// Look up event references for the latest events in the room and the current state snapshot. // Look up event references for the latest events in the room and the current state snapshot.
// Returns the latest events, the current state and the maximum depth of the latest events plus 1. // Returns the latest events, the current state and the maximum depth of the latest events plus 1.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.

View file

@ -81,9 +81,10 @@ func (s *eventJSONStatements) InsertEventJSON(
} }
func (s *eventJSONStatements) BulkSelectEventJSON( func (s *eventJSONStatements) BulkSelectEventJSON(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]tables.EventJSONPair, error) { ) ([]tables.EventJSONPair, error) {
rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) stmt := sqlutil.TxStmt(txn, s.bulkSelectEventJSONStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -212,9 +212,10 @@ func (s *eventStatements) SelectEvent(
// bulkSelectStateEventByID lookups a list of state events by event ID. // bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError // If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByID( func (s *eventStatements) BulkSelectStateEventByID(
ctx context.Context, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByIDStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -254,13 +255,14 @@ func (s *eventStatements) BulkSelectStateEventByID(
// bulkSelectStateEventByNID lookups a list of state events by event NID. // bulkSelectStateEventByNID lookups a list of state events by event NID.
// If any of the requested events are missing from the database it returns a types.MissingEventError // If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByNID( func (s *eventStatements) BulkSelectStateEventByNID(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
stateKeyTuples []types.StateKeyTuple, stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
tuples := stateKeyTupleSorter(stateKeyTuples) tuples := stateKeyTupleSorter(stateKeyTuples)
sort.Sort(tuples) sort.Sort(tuples)
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
rows, err := s.bulkSelectStateEventByNIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray) stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByNIDStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -291,9 +293,10 @@ func (s *eventStatements) BulkSelectStateEventByNID(
// If any of the requested events are missing from the database it returns a types.MissingEventError. // If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError.
func (s *eventStatements) BulkSelectStateAtEventByID( func (s *eventStatements) BulkSelectStateAtEventByID(
ctx context.Context, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateAtEvent, error) { ) ([]types.StateAtEvent, error) {
rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) stmt := sqlutil.TxStmt(txn, s.bulkSelectStateAtEventByIDStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -428,8 +431,9 @@ func (s *eventStatements) BulkSelectEventReference(
} }
// bulkSelectEventID returns a map from numeric event ID to string event ID. // bulkSelectEventID returns a map from numeric event ID to string event ID.
func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) stmt := sqlutil.TxStmt(txn, s.bulkSelectEventIDStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -455,8 +459,9 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map. // If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) stmt := sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -484,9 +489,10 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
} }
func (s *eventStatements) SelectRoomNIDsForEventNIDs( func (s *eventStatements) SelectRoomNIDsForEventNIDs(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) (map[types.EventNID]types.RoomNID, error) { ) (map[types.EventNID]types.RoomNID, error) {
rows, err := s.selectRoomNIDsForEventNIDsStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) stmt := sqlutil.TxStmt(txn, s.selectRoomNIDsForEventNIDsStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -97,8 +97,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) {
} }
func (s *inviteStatements) InsertInviteEvent( func (s *inviteStatements) InsertInviteEvent(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, inviteEventID string, roomNID types.RoomNID,
targetUserNID, senderUserNID types.EventStateKeyNID, targetUserNID, senderUserNID types.EventStateKeyNID,
inviteEventJSON []byte, inviteEventJSON []byte,
) (bool, error) { ) (bool, error) {
@ -116,8 +116,8 @@ func (s *inviteStatements) InsertInviteEvent(
} }
func (s *inviteStatements) UpdateInviteRetired( func (s *inviteStatements) UpdateInviteRetired(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) ([]string, error) { ) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.updateInviteRetiredStmt) stmt := sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
@ -139,10 +139,11 @@ func (s *inviteStatements) UpdateInviteRetired(
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs // SelectInviteActiveForUserInRoom returns a list of sender state key NIDs
func (s *inviteStatements) SelectInviteActiveForUserInRoom( func (s *inviteStatements) SelectInviteActiveForUserInRoom(
ctx context.Context, ctx context.Context, txn *sql.Tx,
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
) ([]types.EventStateKeyNID, []string, error) { ) ([]types.EventStateKeyNID, []string, error) {
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext( stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt)
rows, err := stmt.QueryContext(
ctx, targetUserNID, roomNID, ctx, targetUserNID, roomNID,
) )
if err != nil { if err != nil {

View file

@ -186,8 +186,8 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) {
} }
func (s *membershipStatements) InsertMembership( func (s *membershipStatements) InsertMembership(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
localTarget bool, localTarget bool,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt) stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
@ -196,8 +196,8 @@ func (s *membershipStatements) InsertMembership(
} }
func (s *membershipStatements) SelectMembershipForUpdate( func (s *membershipStatements) SelectMembershipForUpdate(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (membership tables.MembershipState, err error) { ) (membership tables.MembershipState, err error) {
err = sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext( err = sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext(
ctx, roomNID, targetUserNID, ctx, roomNID, targetUserNID,
@ -206,17 +206,19 @@ func (s *membershipStatements) SelectMembershipForUpdate(
} }
func (s *membershipStatements) SelectMembershipFromRoomAndTarget( func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
ctx context.Context, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) { ) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt)
err = stmt.QueryRowContext(
ctx, roomNID, targetUserNID, ctx, roomNID, targetUserNID,
).Scan(&membership, &eventNID, &forgotten) ).Scan(&membership, &eventNID, &forgotten)
return return
} }
func (s *membershipStatements) SelectMembershipsFromRoom( func (s *membershipStatements) SelectMembershipsFromRoom(
ctx context.Context, roomNID types.RoomNID, localOnly bool, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
var stmt *sql.Stmt var stmt *sql.Stmt
if localOnly { if localOnly {
@ -224,6 +226,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
} else { } else {
stmt = s.selectMembershipsFromRoomStmt stmt = s.selectMembershipsFromRoomStmt
} }
stmt = sqlutil.TxStmt(txn, stmt)
rows, err := stmt.QueryContext(ctx, roomNID) rows, err := stmt.QueryContext(ctx, roomNID)
if err != nil { if err != nil {
return return
@ -241,7 +244,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
} }
func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
ctx context.Context, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
var rows *sql.Rows var rows *sql.Rows
@ -251,6 +254,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
} else { } else {
stmt = s.selectMembershipsFromRoomAndMembershipStmt stmt = s.selectMembershipsFromRoomAndMembershipStmt
} }
stmt = sqlutil.TxStmt(txn, stmt)
rows, err = stmt.QueryContext(ctx, roomNID, membership) rows, err = stmt.QueryContext(ctx, roomNID, membership)
if err != nil { if err != nil {
return return
@ -268,8 +272,8 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
} }
func (s *membershipStatements) UpdateMembership( func (s *membershipStatements) UpdateMembership(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
eventNID types.EventNID, forgotten bool, eventNID types.EventNID, forgotten bool,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext( _, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext(
@ -279,9 +283,11 @@ func (s *membershipStatements) UpdateMembership(
} }
func (s *membershipStatements) SelectRoomsWithMembership( func (s *membershipStatements) SelectRoomsWithMembership(
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, ctx context.Context, txn *sql.Tx,
userID types.EventStateKeyNID, membershipState tables.MembershipState,
) ([]types.RoomNID, error) { ) ([]types.RoomNID, error) {
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
rows, err := stmt.QueryContext(ctx, membershipState, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -297,12 +303,16 @@ func (s *membershipStatements) SelectRoomsWithMembership(
return roomNIDs, nil return roomNIDs, nil
} }
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { func (s *membershipStatements) SelectJoinedUsersSetForRooms(
ctx context.Context, txn *sql.Tx,
roomNIDs []types.RoomNID,
) (map[types.EventStateKeyNID]int, error) {
roomIDarray := make([]int64, len(roomNIDs)) roomIDarray := make([]int64, len(roomNIDs))
for i := range roomNIDs { for i := range roomNIDs {
roomIDarray[i] = int64(roomNIDs[i]) roomIDarray[i] = int64(roomNIDs[i])
} }
rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.Int64Array(roomIDarray)) stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -319,8 +329,12 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
return result, rows.Err() return result, rows.Err()
} }
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { func (s *membershipStatements) SelectKnownUsers(
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) ctx context.Context, txn *sql.Tx,
userID types.EventStateKeyNID, searchString string, limit int,
) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt)
rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -337,9 +351,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type
} }
func (s *membershipStatements) UpdateForgetMembership( func (s *membershipStatements) UpdateForgetMembership(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool,
forget bool,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext( _, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
ctx, roomNID, targetUserNID, forget, ctx, roomNID, targetUserNID, forget,
@ -347,9 +360,13 @@ func (s *membershipStatements) UpdateForgetMembership(
return err return err
} }
func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { func (s *membershipStatements) SelectLocalServerInRoom(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID,
) (bool, error) {
var nid types.RoomNID var nid types.RoomNID
err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt)
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return false, nil return false, nil
@ -360,9 +377,13 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room
return found, nil return found, nil
} }
func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { func (s *membershipStatements) SelectServerInRoom(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, serverName gomatrixserverlib.ServerName,
) (bool, error) {
var nid types.RoomNID var nid types.RoomNID
err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt)
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return false, nil return false, nil

View file

@ -73,9 +73,10 @@ func (s *publishedStatements) UpsertRoomPublished(
} }
func (s *publishedStatements) SelectPublishedFromRoomID( func (s *publishedStatements) SelectPublishedFromRoomID(
ctx context.Context, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (published bool, err error) { ) (published bool, err error) {
err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published) stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt)
err = stmt.QueryRowContext(ctx, roomID).Scan(&published)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return false, nil return false, nil
} }
@ -83,9 +84,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID(
} }
func (s *publishedStatements) SelectAllPublishedRooms( func (s *publishedStatements) SelectAllPublishedRooms(
ctx context.Context, published bool, ctx context.Context, txn *sql.Tx, published bool,
) ([]string, error) { ) ([]string, error) {
rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published) stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt)
rows, err := stmt.QueryContext(ctx, published)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -87,9 +87,10 @@ func (s *roomAliasesStatements) InsertRoomAlias(
} }
func (s *roomAliasesStatements) SelectRoomIDFromAlias( func (s *roomAliasesStatements) SelectRoomIDFromAlias(
ctx context.Context, alias string, ctx context.Context, txn *sql.Tx, alias string,
) (roomID string, err error) { ) (roomID string, err error) {
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt)
err = stmt.QueryRowContext(ctx, alias).Scan(&roomID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", nil return "", nil
} }
@ -97,9 +98,10 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias(
} }
func (s *roomAliasesStatements) SelectAliasesFromRoomID( func (s *roomAliasesStatements) SelectAliasesFromRoomID(
ctx context.Context, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) ([]string, error) { ) ([]string, error) {
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt)
rows, err := stmt.QueryContext(ctx, roomID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -118,9 +120,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
} }
func (s *roomAliasesStatements) SelectCreatorIDFromAlias( func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
ctx context.Context, alias string, ctx context.Context, txn *sql.Tx, alias string,
) (creatorID string, err error) { ) (creatorID string, err error) {
err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt)
err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", nil return "", nil
} }

View file

@ -117,8 +117,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
}.Prepare(db) }.Prepare(db)
} }
func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
rows, err := s.selectRoomIDsStmt.QueryContext(ctx) stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -143,10 +144,11 @@ func (s *roomStatements) InsertRoomNID(
return types.RoomNID(roomNID), err return types.RoomNID(roomNID), err
} }
func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
var info types.RoomInfo var info types.RoomInfo
var latestNIDs pq.Int64Array var latestNIDs pq.Int64Array
err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan( stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(
&info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDs, &info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDs,
) )
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -170,7 +172,7 @@ func (s *roomStatements) SelectLatestEventNIDs(
) ([]types.EventNID, types.StateSnapshotNID, error) { ) ([]types.EventNID, types.StateSnapshotNID, error) {
var nids pq.Int64Array var nids pq.Int64Array
var stateSnapshotNID int64 var stateSnapshotNID int64
stmt := s.selectLatestEventNIDsStmt stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsStmt)
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID) err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
@ -220,9 +222,10 @@ func (s *roomStatements) UpdateLatestEventNIDs(
} }
func (s *roomStatements) SelectRoomVersionsForRoomNIDs( func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
ctx context.Context, roomNIDs []types.RoomNID, ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID,
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) { ) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
rows, err := s.selectRoomVersionsForRoomNIDsStmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs)) stmt := sqlutil.TxStmt(txn, s.selectRoomVersionsForRoomNIDsStmt)
rows, err := stmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -239,12 +242,13 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
return result, nil return result, nil
} }
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) {
var array pq.Int64Array var array pq.Int64Array
for _, nid := range roomNIDs { for _, nid := range roomNIDs {
array = append(array, int64(nid)) array = append(array, int64(nid))
} }
rows, err := s.bulkSelectRoomIDsStmt.QueryContext(ctx, array) stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx, array)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -260,12 +264,13 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types
return roomIDs, nil return roomIDs, nil
} }
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) { func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) {
var array pq.StringArray var array pq.StringArray
for _, roomID := range roomIDs { for _, roomID := range roomIDs {
array = append(array, roomID) array = append(array, roomID)
} }
rows, err := s.bulkSelectRoomNIDsStmt.QueryContext(ctx, array) stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomNIDsStmt)
rows, err := stmt.QueryContext(ctx, array)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -86,8 +86,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
} }
func (s *stateBlockStatements) BulkInsertStateData( func (s *stateBlockStatements) BulkInsertStateData(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx,
entries types.StateEntries, entries types.StateEntries,
) (id types.StateBlockNID, err error) { ) (id types.StateBlockNID, err error) {
entries = entries[:util.SortAndUnique(entries)] entries = entries[:util.SortAndUnique(entries)]
@ -95,16 +94,18 @@ func (s *stateBlockStatements) BulkInsertStateData(
for _, e := range entries { for _, e := range entries {
nids = append(nids, e.EventNID) nids = append(nids, e.EventNID)
} }
err = s.insertStateDataStmt.QueryRowContext( stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt)
err = stmt.QueryRowContext(
ctx, nids.Hash(), eventNIDsAsArray(nids), ctx, nids.Hash(), eventNIDsAsArray(nids),
).Scan(&id) ).Scan(&id)
return return
} }
func (s *stateBlockStatements) BulkSelectStateBlockEntries( func (s *stateBlockStatements) BulkSelectStateBlockEntries(
ctx context.Context, stateBlockNIDs types.StateBlockNIDs, ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs,
) ([][]types.EventNID, error) { ) ([][]types.EventNID, error) {
rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs)) stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockEntriesStmt)
rows, err := stmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -105,13 +105,14 @@ func (s *stateSnapshotStatements) InsertState(
} }
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID, ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) { ) ([]types.StateBlockNIDList, error) {
nids := make([]int64, len(stateNIDs)) nids := make([]int64, len(stateNIDs))
for i := range stateNIDs { for i := range stateNIDs {
nids[i] = int64(stateNIDs[i]) nids[i] = int64(stateNIDs[i])
} }
rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, pq.Int64Array(nids)) stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockNIDsStmt)
rows, err := stmt.QueryContext(ctx, pq.Int64Array(nids))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -9,7 +9,7 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
type LatestEventsUpdater struct { type RoomUpdater struct {
transaction transaction
d *Database d *Database
roomInfo types.RoomInfo roomInfo types.RoomInfo
@ -25,7 +25,7 @@ func rollback(txn *sql.Tx) {
txn.Rollback() // nolint: errcheck txn.Rollback() // nolint: errcheck
} }
func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) { func NewRoomUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo types.RoomInfo) (*RoomUpdater, error) {
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID) d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID)
if err != nil { if err != nil {
@ -45,33 +45,33 @@ func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomI
return nil, err return nil, err
} }
} }
return &LatestEventsUpdater{ return &RoomUpdater{
transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
}, nil }, nil
} }
// RoomVersion implements types.RoomRecentEventsUpdater // RoomVersion implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) { func (u *RoomUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) {
return u.roomInfo.RoomVersion return u.roomInfo.RoomVersion
} }
// LatestEvents implements types.RoomRecentEventsUpdater // LatestEvents implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) LatestEvents() []types.StateAtEventAndReference { func (u *RoomUpdater) LatestEvents() []types.StateAtEventAndReference {
return u.latestEvents return u.latestEvents
} }
// LastEventIDSent implements types.RoomRecentEventsUpdater // LastEventIDSent implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) LastEventIDSent() string { func (u *RoomUpdater) LastEventIDSent() string {
return u.lastEventIDSent return u.lastEventIDSent
} }
// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater // CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
return u.currentStateSnapshotNID return u.currentStateSnapshotNID
} }
// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer // StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer
func (u *LatestEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
for _, ref := range previousEventReferences { for _, ref := range previousEventReferences {
if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err)
@ -80,8 +80,58 @@ func (u *LatestEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previ
return nil return nil
} }
func (u *RoomUpdater) StoreEvent(
ctx context.Context, event *gomatrixserverlib.Event,
authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
return u.d.storeEvent(ctx, u.txn, event, authEventNIDs, isRejected)
}
func (u *RoomUpdater) AddState(
ctx context.Context,
roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) {
return u.d.addState(ctx, u.txn, roomNID, stateBlockNIDs, state)
}
func (u *RoomUpdater) SetState(
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error {
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
return u.d.EventsTable.UpdateEventState(ctx, txn, eventNID, stateNID)
})
}
func (u *RoomUpdater) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
return u.d.roomInfo(ctx, u.txn, roomID)
}
func (u *RoomUpdater) StateAtEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) {
return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs)
}
func (u *RoomUpdater) StateEntriesForEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateEntry, error) {
return u.d.EventsTable.BulkSelectStateEventByID(ctx, u.txn, eventIDs)
}
func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, eventIDs)
}
func (u *RoomUpdater) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) {
return u.d.getMembershipEventNIDsForRoom(ctx, u.txn, roomNID, joinOnly, localOnly)
}
// IsReferenced implements types.RoomRecentEventsUpdater // IsReferenced implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { func (u *RoomUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
if err == nil { if err == nil {
return true, nil return true, nil
@ -93,7 +143,7 @@ func (u *LatestEventsUpdater) IsReferenced(eventReference gomatrixserverlib.Even
} }
// SetLatestEvents implements types.RoomRecentEventsUpdater // SetLatestEvents implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) SetLatestEvents( func (u *RoomUpdater) SetLatestEvents(
roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
currentStateSnapshotNID types.StateSnapshotNID, currentStateSnapshotNID types.StateSnapshotNID,
) error { ) error {
@ -117,17 +167,17 @@ func (u *LatestEventsUpdater) SetLatestEvents(
} }
// HasEventBeenSent implements types.RoomRecentEventsUpdater // HasEventBeenSent implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { func (u *RoomUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID) return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID)
} }
// MarkEventAsSent implements types.RoomRecentEventsUpdater // MarkEventAsSent implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error {
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID) return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID)
}) })
} }
func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
} }

View file

@ -26,23 +26,23 @@ import (
const redactionsArePermanent = true const redactionsArePermanent = true
type Database struct { type Database struct {
DB *sql.DB DB *sql.DB
Cache caching.RoomServerCaches Cache caching.RoomServerCaches
Writer sqlutil.Writer Writer sqlutil.Writer
EventsTable tables.Events EventsTable tables.Events
EventJSONTable tables.EventJSON EventJSONTable tables.EventJSON
EventTypesTable tables.EventTypes EventTypesTable tables.EventTypes
EventStateKeysTable tables.EventStateKeys EventStateKeysTable tables.EventStateKeys
RoomsTable tables.Rooms RoomsTable tables.Rooms
StateSnapshotTable tables.StateSnapshot StateSnapshotTable tables.StateSnapshot
StateBlockTable tables.StateBlock StateBlockTable tables.StateBlock
RoomAliasesTable tables.RoomAliases RoomAliasesTable tables.RoomAliases
PrevEventsTable tables.PreviousEvents PrevEventsTable tables.PreviousEvents
InvitesTable tables.Invites InvitesTable tables.Invites
MembershipTable tables.Membership MembershipTable tables.Membership
PublishedTable tables.Published PublishedTable tables.Published
RedactionsTable tables.Redactions RedactionsTable tables.Redactions
GetLatestEventsForUpdateFn func(ctx context.Context, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) GetRoomUpdaterFn func(ctx context.Context, roomInfo types.RoomInfo) (*RoomUpdater, error)
} }
func (d *Database) SupportsConcurrentRoomInputs() bool { func (d *Database) SupportsConcurrentRoomInputs() bool {
@ -108,7 +108,7 @@ func (d *Database) EventStateKeyNIDs(
func (d *Database) StateEntriesForEventIDs( func (d *Database) StateEntriesForEventIDs(
ctx context.Context, eventIDs []string, ctx context.Context, eventIDs []string,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs) return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs)
} }
func (d *Database) StateEntriesForTuples( func (d *Database) StateEntriesForTuples(
@ -117,14 +117,14 @@ func (d *Database) StateEntriesForTuples(
stateKeyTuples []types.StateKeyTuple, stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) { ) ([]types.StateEntryList, error) {
entries, err := d.StateBlockTable.BulkSelectStateBlockEntries( entries, err := d.StateBlockTable.BulkSelectStateBlockEntries(
ctx, stateBlockNIDs, ctx, nil, stateBlockNIDs,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err) return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
} }
lists := []types.StateEntryList{} lists := []types.StateEntryList{}
for i, entry := range entries { for i, entry := range entries {
entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, stateKeyTuples) entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, nil, entry, stateKeyTuples)
if err != nil { if err != nil {
return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
} }
@ -137,10 +137,14 @@ func (d *Database) StateEntriesForTuples(
} }
func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
return d.roomInfo(ctx, nil, roomID)
}
func (d *Database) roomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok { if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
return &roomInfo, nil return &roomInfo, nil
} }
roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, roomID) roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, txn, roomID)
if err == nil && roomInfo != nil { if err == nil && roomInfo != nil {
d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID) d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID)
d.Cache.StoreRoomInfo(roomID, *roomInfo) d.Cache.StoreRoomInfo(roomID, *roomInfo)
@ -153,13 +157,22 @@ func (d *Database) AddState(
roomNID types.RoomNID, roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID, stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry, state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) {
return d.addState(ctx, nil, roomNID, stateBlockNIDs, state)
}
func (d *Database) addState(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) { ) (stateNID types.StateSnapshotNID, err error) {
if len(stateBlockNIDs) > 0 && len(state) > 0 { if len(stateBlockNIDs) > 0 && len(state) > 0 {
// Check to see if the event already appears in any of the existing state // Check to see if the event already appears in any of the existing state
// blocks. If it does then we should not add it again, as this will just // blocks. If it does then we should not add it again, as this will just
// result in excess state blocks and snapshots. // result in excess state blocks and snapshots.
// TODO: Investigate why this is happening - probably input_events.go! // TODO: Investigate why this is happening - probably input_events.go!
blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs) blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs)
if berr != nil { if berr != nil {
return 0, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", berr) return 0, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", berr)
} }
@ -180,7 +193,7 @@ func (d *Database) AddState(
} }
} }
} }
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
if len(state) > 0 { if len(state) > 0 {
// If there's any state left to add then let's add new blocks. // If there's any state left to add then let's add new blocks.
var stateBlockNID types.StateBlockNID var stateBlockNID types.StateBlockNID
@ -205,7 +218,13 @@ func (d *Database) AddState(
func (d *Database) EventNIDs( func (d *Database) EventNIDs(
ctx context.Context, eventIDs []string, ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) { ) (map[string]types.EventNID, error) {
return d.EventsTable.BulkSelectEventNID(ctx, eventIDs) return d.eventNIDs(ctx, nil, eventIDs)
}
func (d *Database) eventNIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) (map[string]types.EventNID, error) {
return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs)
} }
func (d *Database) SetState( func (d *Database) SetState(
@ -219,7 +238,7 @@ func (d *Database) SetState(
func (d *Database) StateAtEventIDs( func (d *Database) StateAtEventIDs(
ctx context.Context, eventIDs []string, ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) { ) ([]types.StateAtEvent, error) {
return d.EventsTable.BulkSelectStateAtEventByID(ctx, eventIDs) return d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, eventIDs)
} }
func (d *Database) SnapshotNIDFromEventID( func (d *Database) SnapshotNIDFromEventID(
@ -232,11 +251,15 @@ func (d *Database) SnapshotNIDFromEventID(
func (d *Database) EventIDs( func (d *Database) EventIDs(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]string, error) { ) (map[types.EventNID]string, error) {
return d.EventsTable.BulkSelectEventID(ctx, eventNIDs) return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
} }
func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
nidMap, err := d.EventNIDs(ctx, eventIDs) return d.eventsFromIDs(ctx, nil, eventIDs)
}
func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.Event, error) {
nidMap, err := d.eventNIDs(ctx, txn, eventIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -246,7 +269,7 @@ func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]type
nids = append(nids, nid) nids = append(nids, nid)
} }
return d.Events(ctx, nids) return d.events(ctx, txn, nids)
} }
func (d *Database) LatestEventIDs( func (d *Database) LatestEventIDs(
@ -271,21 +294,21 @@ func (d *Database) LatestEventIDs(
func (d *Database) StateBlockNIDs( func (d *Database) StateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID, ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) { ) ([]types.StateBlockNIDList, error) {
return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, stateNIDs) return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, nil, stateNIDs)
} }
func (d *Database) StateEntries( func (d *Database) StateEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID, ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) { ) ([]types.StateEntryList, error) {
entries, err := d.StateBlockTable.BulkSelectStateBlockEntries( entries, err := d.StateBlockTable.BulkSelectStateBlockEntries(
ctx, stateBlockNIDs, ctx, nil, stateBlockNIDs,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err) return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
} }
lists := make([]types.StateEntryList, 0, len(entries)) lists := make([]types.StateEntryList, 0, len(entries))
for i, entry := range entries { for i, entry := range entries {
eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, nil) eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, nil, entry, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
} }
@ -304,17 +327,17 @@ func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string
} }
func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) {
return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, alias) return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, nil, alias)
} }
func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) {
return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, roomID) return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, nil, roomID)
} }
func (d *Database) GetCreatorIDForAlias( func (d *Database) GetCreatorIDForAlias(
ctx context.Context, alias string, ctx context.Context, alias string,
) (string, error) { ) (string, error) {
return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, alias) return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, nil, alias)
} }
func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
@ -335,7 +358,7 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req
senderMembershipEventNID, senderMembership, isRoomforgotten, err := senderMembershipEventNID, senderMembership, isRoomforgotten, err :=
d.MembershipTable.SelectMembershipFromRoomAndTarget( d.MembershipTable.SelectMembershipFromRoomAndTarget(
ctx, roomNID, requestSenderUserNID, ctx, nil, roomNID, requestSenderUserNID,
) )
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// The user has never been a member of that room // The user has never been a member of that room
@ -349,14 +372,20 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req
func (d *Database) GetMembershipEventNIDsForRoom( func (d *Database) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) {
return d.getMembershipEventNIDsForRoom(ctx, nil, roomNID, joinOnly, localOnly)
}
func (d *Database) getMembershipEventNIDsForRoom(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) { ) ([]types.EventNID, error) {
if joinOnly { if joinOnly {
return d.MembershipTable.SelectMembershipsFromRoomAndMembership( return d.MembershipTable.SelectMembershipsFromRoomAndMembership(
ctx, roomNID, tables.MembershipStateJoin, localOnly, ctx, txn, roomNID, tables.MembershipStateJoin, localOnly,
) )
} }
return d.MembershipTable.SelectMembershipsFromRoom(ctx, roomNID, localOnly) return d.MembershipTable.SelectMembershipsFromRoom(ctx, txn, roomNID, localOnly)
} }
func (d *Database) GetInvitesForUser( func (d *Database) GetInvitesForUser(
@ -364,22 +393,28 @@ func (d *Database) GetInvitesForUser(
roomNID types.RoomNID, roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID, targetUserNID types.EventStateKeyNID,
) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error) { ) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error) {
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
} }
func (d *Database) Events( func (d *Database) Events(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, eventNIDs []types.EventNID,
) ([]types.Event, error) { ) ([]types.Event, error) {
eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs) return d.events(ctx, nil, eventNIDs)
}
func (d *Database) events(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]types.Event, error) {
eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, txn, eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs) eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
if err != nil { if err != nil {
eventIDs = map[types.EventNID]string{} eventIDs = map[types.EventNID]string{}
} }
var roomNIDs map[types.EventNID]types.RoomNID var roomNIDs map[types.EventNID]types.RoomNID
roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, eventNIDs) roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, txn, eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -398,7 +433,7 @@ func (d *Database) Events(
} }
fetchNIDList = append(fetchNIDList, n) fetchNIDList = append(fetchNIDList, n)
} }
dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, fetchNIDList) dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, fetchNIDList)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -440,19 +475,19 @@ func (d *Database) MembershipUpdater(
return updater, err return updater, err
} }
func (d *Database) GetLatestEventsForUpdate( func (d *Database) GetRoomUpdater(
ctx context.Context, roomInfo types.RoomInfo, ctx context.Context, roomInfo types.RoomInfo,
) (*LatestEventsUpdater, error) { ) (*RoomUpdater, error) {
if d.GetLatestEventsForUpdateFn != nil { if d.GetRoomUpdaterFn != nil {
return d.GetLatestEventsForUpdateFn(ctx, roomInfo) return d.GetRoomUpdaterFn(ctx, roomInfo)
} }
txn, err := d.DB.Begin() txn, err := d.DB.Begin()
if err != nil { if err != nil {
return nil, err return nil, err
} }
var updater *LatestEventsUpdater var updater *RoomUpdater
_ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { _ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
updater, err = NewLatestEventsUpdater(ctx, d, txn, roomInfo) updater, err = NewRoomUpdater(ctx, d, txn, roomInfo)
return err return err
}) })
return updater, err return updater, err
@ -461,6 +496,13 @@ func (d *Database) GetLatestEventsForUpdate(
func (d *Database) StoreEvent( func (d *Database) StoreEvent(
ctx context.Context, event *gomatrixserverlib.Event, ctx context.Context, event *gomatrixserverlib.Event,
authEventNIDs []types.EventNID, isRejected bool, authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
return d.storeEvent(ctx, nil, event, authEventNIDs, isRejected)
}
func (d *Database) storeEvent(
ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.Event,
authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { ) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
var ( var (
roomNID types.RoomNID roomNID types.RoomNID
@ -473,7 +515,7 @@ func (d *Database) StoreEvent(
err error err error
) )
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
// TODO: Here we should aim to have two different code paths for new rooms // TODO: Here we should aim to have two different code paths for new rooms
// vs existing ones. // vs existing ones.
@ -547,7 +589,7 @@ func (d *Database) StoreEvent(
// that there's a row-level lock on the latest room events (well, // that there's a row-level lock on the latest room events (well,
// on Postgres at least). // on Postgres at least).
var roomInfo *types.RoomInfo var roomInfo *types.RoomInfo
var updater *LatestEventsUpdater var updater *RoomUpdater
if prevEvents := event.PrevEvents(); len(prevEvents) > 0 { if prevEvents := event.PrevEvents(); len(prevEvents) > 0 {
roomInfo, err = d.RoomInfo(ctx, event.RoomID()) roomInfo, err = d.RoomInfo(ctx, event.RoomID())
if err != nil { if err != nil {
@ -561,7 +603,7 @@ func (d *Database) StoreEvent(
// function only does SELECTs though so the created txn (at this point) is just a read txn like // function only does SELECTs though so the created txn (at this point) is just a read txn like
// any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater // any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater
// to do writes however then this will need to go inside `Writer.Do`. // to do writes however then this will need to go inside `Writer.Do`.
updater, err = d.GetLatestEventsForUpdate(ctx, *roomInfo) updater, err = d.GetRoomUpdater(ctx, *roomInfo)
if err != nil { if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("NewLatestEventsUpdater: %w", err) return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("NewLatestEventsUpdater: %w", err)
} }
@ -603,7 +645,7 @@ func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool)
} }
func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) { func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) {
return d.PublishedTable.SelectAllPublishedRooms(ctx, true) return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true)
} }
func (d *Database) assignRoomNID( func (d *Database) assignRoomNID(
@ -875,14 +917,14 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
eventNIDs = append(eventNIDs, e.EventNID) eventNIDs = append(eventNIDs, e.EventNID)
} }
} }
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs) eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
if err != nil { if err != nil {
eventIDs = map[types.EventNID]string{} eventIDs = map[types.EventNID]string{}
} }
// return the event requested // return the event requested
for _, e := range entries { for _, e := range entries {
if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID { if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID {
data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, []types.EventNID{e.EventNID}) data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, []types.EventNID{e.EventNID})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -922,11 +964,11 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
} }
return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err) return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err)
} }
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, stateKeyNID, membershipState) roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNID, membershipState)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err) return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err)
} }
roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, roomNIDs) roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, roomNIDs)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to lookup room nids: %w", err) return nil, fmt.Errorf("GetRoomsByMembership: failed to lookup room nids: %w", err)
} }
@ -999,11 +1041,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
} }
} }
} }
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs) eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
if err != nil { if err != nil {
eventIDs = map[types.EventNID]string{} eventIDs = map[types.EventNID]string{}
} }
events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs) events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event nids: %w", err) return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event nids: %w", err)
} }
@ -1027,11 +1069,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, roomIDs) roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, roomNIDs) userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1057,12 +1099,12 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
return d.MembershipTable.SelectLocalServerInRoom(ctx, roomNID) return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID)
} }
// GetServerInRoom returns true if we think a server is in a given room or false otherwise. // GetServerInRoom returns true if we think a server is in a given room or false otherwise.
func (d *Database) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { func (d *Database) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
return d.MembershipTable.SelectServerInRoom(ctx, roomNID, serverName) return d.MembershipTable.SelectServerInRoom(ctx, nil, roomNID, serverName)
} }
// GetKnownUsers searches all users that userID knows about. // GetKnownUsers searches all users that userID knows about.
@ -1071,17 +1113,17 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin
if err != nil { if err != nil {
return nil, err return nil, err
} }
return d.MembershipTable.SelectKnownUsers(ctx, stateKeyNID, searchString, limit) return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit)
} }
// GetKnownRooms returns a list of all rooms we know about. // GetKnownRooms returns a list of all rooms we know about.
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
return d.RoomsTable.SelectRoomIDs(ctx) return d.RoomsTable.SelectRoomIDs(ctx, nil)
} }
// ForgetRoom sets a users room to forgotten // ForgetRoom sets a users room to forgotten
func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error { func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error {
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, []string{roomID}) roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, []string{roomID})
if err != nil { if err != nil {
return err return err
} }

View file

@ -76,15 +76,19 @@ func (s *eventJSONStatements) InsertEventJSON(
} }
func (s *eventJSONStatements) BulkSelectEventJSON( func (s *eventJSONStatements) BulkSelectEventJSON(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]tables.EventJSONPair, error) { ) ([]tables.EventJSONPair, error) {
iEventNIDs := make([]interface{}, len(eventNIDs)) iEventNIDs := make([]interface{}, len(eventNIDs))
for k, v := range eventNIDs { for k, v := range eventNIDs {
iEventNIDs[k] = v iEventNIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
selectStmt, err := s.db.Prepare(selectOrig)
rows, err := s.db.QueryContext(ctx, selectOrig, iEventNIDs...) if err != nil {
return nil, err
}
selectStmt = sqlutil.TxStmt(txn, selectStmt)
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -184,7 +184,7 @@ func (s *eventStatements) SelectEvent(
// bulkSelectStateEventByID lookups a list of state events by event ID. // bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError // If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByID( func (s *eventStatements) BulkSelectStateEventByID(
ctx context.Context, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
/////////////// ///////////////
iEventIDs := make([]interface{}, len(eventIDs)) iEventIDs := make([]interface{}, len(eventIDs))
@ -196,6 +196,7 @@ func (s *eventStatements) BulkSelectStateEventByID(
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
/////////////// ///////////////
rows, err := selectStmt.QueryContext(ctx, iEventIDs...) rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
@ -235,7 +236,7 @@ func (s *eventStatements) BulkSelectStateEventByID(
// bulkSelectStateEventByID lookups a list of state events by event ID. // bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError // If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByNID( func (s *eventStatements) BulkSelectStateEventByNID(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
stateKeyTuples []types.StateKeyTuple, stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
tuples := stateKeyTupleSorter(stateKeyTuples) tuples := stateKeyTupleSorter(stateKeyTuples)
@ -263,6 +264,7 @@ func (s *eventStatements) BulkSelectStateEventByNID(
if err != nil { if err != nil {
return nil, fmt.Errorf("s.db.Prepare: %w", err) return nil, fmt.Errorf("s.db.Prepare: %w", err)
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
rows, err := selectStmt.QueryContext(ctx, params...) rows, err := selectStmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
return nil, fmt.Errorf("selectStmt.QueryContext: %w", err) return nil, fmt.Errorf("selectStmt.QueryContext: %w", err)
@ -291,7 +293,7 @@ func (s *eventStatements) BulkSelectStateEventByNID(
// If any of the requested events are missing from the database it returns a types.MissingEventError. // If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError.
func (s *eventStatements) BulkSelectStateAtEventByID( func (s *eventStatements) BulkSelectStateAtEventByID(
ctx context.Context, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateAtEvent, error) { ) ([]types.StateAtEvent, error) {
/////////////// ///////////////
iEventIDs := make([]interface{}, len(eventIDs)) iEventIDs := make([]interface{}, len(eventIDs))
@ -303,6 +305,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
/////////////// ///////////////
rows, err := selectStmt.QueryContext(ctx, iEventIDs...) rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
if err != nil { if err != nil {
@ -381,6 +384,7 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference(
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectPrep = sqlutil.TxStmt(txn, selectPrep)
////////////// //////////////
rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...) rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...)
@ -454,7 +458,7 @@ func (s *eventStatements) BulkSelectEventReference(
} }
// bulkSelectEventID returns a map from numeric event ID to string event ID. // bulkSelectEventID returns a map from numeric event ID to string event ID.
func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
/////////////// ///////////////
iEventNIDs := make([]interface{}, len(eventNIDs)) iEventNIDs := make([]interface{}, len(eventNIDs))
for k, v := range eventNIDs { for k, v := range eventNIDs {
@ -465,6 +469,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
/////////////// ///////////////
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
@ -490,7 +495,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map. // If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
/////////////// ///////////////
iEventIDs := make([]interface{}, len(eventIDs)) iEventIDs := make([]interface{}, len(eventIDs))
for k, v := range eventIDs { for k, v := range eventIDs {
@ -501,6 +506,7 @@ func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []str
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
/////////////// ///////////////
rows, err := selectStmt.QueryContext(ctx, iEventIDs...) rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
if err != nil { if err != nil {
@ -538,13 +544,14 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
} }
func (s *eventStatements) SelectRoomNIDsForEventNIDs( func (s *eventStatements) SelectRoomNIDsForEventNIDs(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) (map[types.EventNID]types.RoomNID, error) { ) (map[types.EventNID]types.RoomNID, error) {
sqlStr := strings.Replace(selectRoomNIDsForEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1) sqlStr := strings.Replace(selectRoomNIDsForEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
sqlPrep, err := s.db.Prepare(sqlStr) sqlPrep, err := s.db.Prepare(sqlStr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
sqlPrep = sqlutil.TxStmt(txn, sqlPrep)
iEventNIDs := make([]interface{}, len(eventNIDs)) iEventNIDs := make([]interface{}, len(eventNIDs))
for i, v := range eventNIDs { for i, v := range eventNIDs {
iEventNIDs[i] = v iEventNIDs[i] = v

View file

@ -88,8 +88,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) {
} }
func (s *inviteStatements) InsertInviteEvent( func (s *inviteStatements) InsertInviteEvent(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, inviteEventID string, roomNID types.RoomNID,
targetUserNID, senderUserNID types.EventStateKeyNID, targetUserNID, senderUserNID types.EventStateKeyNID,
inviteEventJSON []byte, inviteEventJSON []byte,
) (bool, error) { ) (bool, error) {
@ -109,8 +109,8 @@ func (s *inviteStatements) InsertInviteEvent(
} }
func (s *inviteStatements) UpdateInviteRetired( func (s *inviteStatements) UpdateInviteRetired(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventIDs []string, err error) { ) (eventIDs []string, err error) {
// gather all the event IDs we will retire // gather all the event IDs we will retire
stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt) stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
@ -134,10 +134,11 @@ func (s *inviteStatements) UpdateInviteRetired(
// selectInviteActiveForUserInRoom returns a list of sender state key NIDs // selectInviteActiveForUserInRoom returns a list of sender state key NIDs
func (s *inviteStatements) SelectInviteActiveForUserInRoom( func (s *inviteStatements) SelectInviteActiveForUserInRoom(
ctx context.Context, ctx context.Context, txn *sql.Tx,
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
) ([]types.EventStateKeyNID, []string, error) { ) ([]types.EventStateKeyNID, []string, error) {
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext( stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt)
rows, err := stmt.QueryContext(
ctx, targetUserNID, roomNID, ctx, targetUserNID, roomNID,
) )
if err != nil { if err != nil {

View file

@ -184,17 +184,18 @@ func (s *membershipStatements) SelectMembershipForUpdate(
} }
func (s *membershipStatements) SelectMembershipFromRoomAndTarget( func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
ctx context.Context, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) { ) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt)
err = stmt.QueryRowContext(
ctx, roomNID, targetUserNID, ctx, roomNID, targetUserNID,
).Scan(&membership, &eventNID, &forgotten) ).Scan(&membership, &eventNID, &forgotten)
return return
} }
func (s *membershipStatements) SelectMembershipsFromRoom( func (s *membershipStatements) SelectMembershipsFromRoom(
ctx context.Context, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, localOnly bool, roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
var selectStmt *sql.Stmt var selectStmt *sql.Stmt
@ -203,6 +204,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
} else { } else {
selectStmt = s.selectMembershipsFromRoomStmt selectStmt = s.selectMembershipsFromRoomStmt
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
rows, err := selectStmt.QueryContext(ctx, roomNID) rows, err := selectStmt.QueryContext(ctx, roomNID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -220,7 +222,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
} }
func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
ctx context.Context, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
var stmt *sql.Stmt var stmt *sql.Stmt
@ -229,6 +231,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
} else { } else {
stmt = s.selectMembershipsFromRoomAndMembershipStmt stmt = s.selectMembershipsFromRoomAndMembershipStmt
} }
stmt = sqlutil.TxStmt(txn, stmt)
rows, err := stmt.QueryContext(ctx, roomNID, membership) rows, err := stmt.QueryContext(ctx, roomNID, membership)
if err != nil { if err != nil {
return return
@ -258,9 +261,10 @@ func (s *membershipStatements) UpdateMembership(
} }
func (s *membershipStatements) SelectRoomsWithMembership( func (s *membershipStatements) SelectRoomsWithMembership(
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState tables.MembershipState,
) ([]types.RoomNID, error) { ) ([]types.RoomNID, error) {
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
rows, err := stmt.QueryContext(ctx, membershipState, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -276,13 +280,17 @@ func (s *membershipStatements) SelectRoomsWithMembership(
return roomNIDs, nil return roomNIDs, nil
} }
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
iRoomNIDs := make([]interface{}, len(roomNIDs)) iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs { for i, v := range roomNIDs {
iRoomNIDs[i] = v iRoomNIDs[i] = v
} }
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1) query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1)
rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...) stmt, err := s.db.Prepare(query)
if err != nil {
return nil, err
}
rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, iRoomNIDs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -299,8 +307,9 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
return result, rows.Err() return result, rows.Err()
} }
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { func (s *membershipStatements) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt)
rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -317,8 +326,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type
} }
func (s *membershipStatements) UpdateForgetMembership( func (s *membershipStatements) UpdateForgetMembership(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
forget bool, forget bool,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext( _, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
@ -327,9 +336,10 @@ func (s *membershipStatements) UpdateForgetMembership(
return err return err
} }
func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) {
var nid types.RoomNID var nid types.RoomNID
err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt)
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return false, nil return false, nil
@ -340,9 +350,10 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room
return found, nil return found, nil
} }
func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { func (s *membershipStatements) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
var nid types.RoomNID var nid types.RoomNID
err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt)
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return false, nil return false, nil

View file

@ -75,9 +75,10 @@ func (s *publishedStatements) UpsertRoomPublished(
} }
func (s *publishedStatements) SelectPublishedFromRoomID( func (s *publishedStatements) SelectPublishedFromRoomID(
ctx context.Context, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (published bool, err error) { ) (published bool, err error) {
err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published) stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt)
err = stmt.QueryRowContext(ctx, roomID).Scan(&published)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return false, nil return false, nil
} }
@ -85,9 +86,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID(
} }
func (s *publishedStatements) SelectAllPublishedRooms( func (s *publishedStatements) SelectAllPublishedRooms(
ctx context.Context, published bool, ctx context.Context, txn *sql.Tx, published bool,
) ([]string, error) { ) ([]string, error) {
rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published) stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt)
rows, err := stmt.QueryContext(ctx, published)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -91,9 +91,10 @@ func (s *roomAliasesStatements) InsertRoomAlias(
} }
func (s *roomAliasesStatements) SelectRoomIDFromAlias( func (s *roomAliasesStatements) SelectRoomIDFromAlias(
ctx context.Context, alias string, ctx context.Context, txn *sql.Tx, alias string,
) (roomID string, err error) { ) (roomID string, err error) {
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt)
err = stmt.QueryRowContext(ctx, alias).Scan(&roomID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", nil return "", nil
} }
@ -101,10 +102,11 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias(
} }
func (s *roomAliasesStatements) SelectAliasesFromRoomID( func (s *roomAliasesStatements) SelectAliasesFromRoomID(
ctx context.Context, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (aliases []string, err error) { ) (aliases []string, err error) {
aliases = []string{} aliases = []string{}
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt)
rows, err := stmt.QueryContext(ctx, roomID)
if err != nil { if err != nil {
return return
} }
@ -124,9 +126,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
} }
func (s *roomAliasesStatements) SelectCreatorIDFromAlias( func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
ctx context.Context, alias string, ctx context.Context, txn *sql.Tx, alias string,
) (creatorID string, err error) { ) (creatorID string, err error) {
err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt)
err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", nil return "", nil
} }

View file

@ -107,8 +107,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
}.Prepare(db) }.Prepare(db)
} }
func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
rows, err := s.selectRoomIDsStmt.QueryContext(ctx) stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -124,10 +125,11 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
return roomIDs, nil return roomIDs, nil
} }
func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
var info types.RoomInfo var info types.RoomInfo
var latestNIDsJSON string var latestNIDsJSON string
err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan( stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(
&info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDsJSON, &info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDsJSON,
) )
if err != nil { if err != nil {
@ -224,13 +226,14 @@ func (s *roomStatements) UpdateLatestEventNIDs(
} }
func (s *roomStatements) SelectRoomVersionsForRoomNIDs( func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
ctx context.Context, roomNIDs []types.RoomNID, ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID,
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) { ) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
sqlPrep, err := s.db.Prepare(sqlStr) sqlPrep, err := s.db.Prepare(sqlStr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
sqlPrep = sqlutil.TxStmt(txn, sqlPrep)
iRoomNIDs := make([]interface{}, len(roomNIDs)) iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs { for i, v := range roomNIDs {
iRoomNIDs[i] = v iRoomNIDs[i] = v
@ -252,13 +255,18 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
return result, nil return result, nil
} }
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) {
iRoomNIDs := make([]interface{}, len(roomNIDs)) iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs { for i, v := range roomNIDs {
iRoomNIDs[i] = v iRoomNIDs[i] = v
} }
sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...) sqlPrep, err := s.db.Prepare(sqlQuery)
if err != nil {
return nil, err
}
stmt := sqlutil.TxStmt(txn, sqlPrep)
rows, err := stmt.QueryContext(ctx, iRoomNIDs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -274,13 +282,18 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types
return roomIDs, nil return roomIDs, nil
} }
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) { func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) {
iRoomIDs := make([]interface{}, len(roomIDs)) iRoomIDs := make([]interface{}, len(roomIDs))
for i, v := range roomIDs { for i, v := range roomIDs {
iRoomIDs[i] = v iRoomIDs[i] = v
} }
sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1) sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1)
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...) sqlPrep, err := s.db.Prepare(sqlQuery)
if err != nil {
return nil, err
}
stmt := sqlutil.TxStmt(txn, sqlPrep)
rows, err := stmt.QueryContext(ctx, iRoomIDs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -81,8 +81,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
} }
func (s *stateBlockStatements) BulkInsertStateData( func (s *stateBlockStatements) BulkInsertStateData(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx,
entries types.StateEntries, entries types.StateEntries,
) (id types.StateBlockNID, err error) { ) (id types.StateBlockNID, err error) {
entries = entries[:util.SortAndUnique(entries)] entries = entries[:util.SortAndUnique(entries)]
@ -101,7 +100,7 @@ func (s *stateBlockStatements) BulkInsertStateData(
} }
func (s *stateBlockStatements) BulkSelectStateBlockEntries( func (s *stateBlockStatements) BulkSelectStateBlockEntries(
ctx context.Context, stateBlockNIDs types.StateBlockNIDs, ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs,
) ([][]types.EventNID, error) { ) ([][]types.EventNID, error) {
intfs := make([]interface{}, len(stateBlockNIDs)) intfs := make([]interface{}, len(stateBlockNIDs))
for i := range stateBlockNIDs { for i := range stateBlockNIDs {
@ -112,6 +111,7 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
rows, err := selectStmt.QueryContext(ctx, intfs...) rows, err := selectStmt.QueryContext(ctx, intfs...)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -106,7 +106,7 @@ func (s *stateSnapshotStatements) InsertState(
} }
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID, ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) { ) ([]types.StateBlockNIDList, error) {
nids := make([]interface{}, len(stateNIDs)) nids := make([]interface{}, len(stateNIDs))
for k, v := range stateNIDs { for k, v := range stateNIDs {
@ -117,6 +117,7 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
rows, err := selectStmt.QueryContext(ctx, nids...) rows, err := selectStmt.QueryContext(ctx, nids...)
if err != nil { if err != nil {

View file

@ -172,23 +172,23 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error {
return err return err
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: db, DB: db,
Cache: cache, Cache: cache,
Writer: sqlutil.NewExclusiveWriter(), Writer: sqlutil.NewExclusiveWriter(),
EventsTable: events, EventsTable: events,
EventTypesTable: eventTypes, EventTypesTable: eventTypes,
EventStateKeysTable: eventStateKeys, EventStateKeysTable: eventStateKeys,
EventJSONTable: eventJSON, EventJSONTable: eventJSON,
RoomsTable: rooms, RoomsTable: rooms,
StateBlockTable: stateBlock, StateBlockTable: stateBlock,
StateSnapshotTable: stateSnapshot, StateSnapshotTable: stateSnapshot,
PrevEventsTable: prevEvents, PrevEventsTable: prevEvents,
RoomAliasesTable: roomAliases, RoomAliasesTable: roomAliases,
InvitesTable: invites, InvitesTable: invites,
MembershipTable: membership, MembershipTable: membership,
PublishedTable: published, PublishedTable: published,
RedactionsTable: redactions, RedactionsTable: redactions,
GetLatestEventsForUpdateFn: d.GetLatestEventsForUpdate, GetRoomUpdaterFn: d.GetRoomUpdater,
} }
return nil return nil
} }
@ -201,16 +201,16 @@ func (d *Database) SupportsConcurrentRoomInputs() bool {
return false return false
} }
func (d *Database) GetLatestEventsForUpdate( func (d *Database) GetRoomUpdater(
ctx context.Context, roomInfo types.RoomInfo, ctx context.Context, roomInfo types.RoomInfo,
) (*shared.LatestEventsUpdater, error) { ) (*shared.RoomUpdater, error) {
// TODO: Do not use transactions. We should be holding open this transaction but we cannot have // TODO: Do not use transactions. We should be holding open this transaction but we cannot have
// multiple write transactions on sqlite. The code will perform additional // multiple write transactions on sqlite. The code will perform additional
// write transactions independent of this one which will consistently cause // write transactions independent of this one which will consistently cause
// 'database is locked' errors. As sqlite doesn't support multi-process on the // 'database is locked' errors. As sqlite doesn't support multi-process on the
// same DB anyway, and we only execute updates sequentially, the only worries // same DB anyway, and we only execute updates sequentially, the only worries
// are for rolling back when things go wrong. (atomicity) // are for rolling back when things go wrong. (atomicity)
return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomInfo) return shared.NewRoomUpdater(ctx, &d.Database, nil, roomInfo)
} }
func (d *Database) MembershipUpdater( func (d *Database) MembershipUpdater(

View file

@ -18,7 +18,7 @@ type EventJSONPair struct {
type EventJSON interface { type EventJSON interface {
// Insert the event JSON. On conflict, replace the event JSON with the new value (for redactions). // Insert the event JSON. On conflict, replace the event JSON with the new value (for redactions).
InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error
BulkSelectEventJSON(ctx context.Context, eventNIDs []types.EventNID) ([]EventJSONPair, error) BulkSelectEventJSON(ctx context.Context, tx *sql.Tx, eventNIDs []types.EventNID) ([]EventJSONPair, error)
} }
type EventTypes interface { type EventTypes interface {
@ -42,12 +42,12 @@ type Events interface {
SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error) SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error)
// bulkSelectStateEventByID lookups a list of state events by event ID. // bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError // If any of the requested events are missing from the database it returns a types.MissingEventError
BulkSelectStateEventByID(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) BulkSelectStateEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateEntry, error)
BulkSelectStateEventByNID(ctx context.Context, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error) BulkSelectStateEventByNID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error)
// BulkSelectStateAtEventByID lookups the state at a list of events by event ID. // BulkSelectStateAtEventByID lookups the state at a list of events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError. // If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError.
BulkSelectStateAtEventByID(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) BulkSelectStateAtEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateAtEvent, error)
UpdateEventState(ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID) error UpdateEventState(ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID) error
SelectEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) SelectEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error)
UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error
@ -55,12 +55,12 @@ type Events interface {
BulkSelectStateAtEventAndReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) BulkSelectStateAtEventAndReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error)
BulkSelectEventReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error) BulkSelectEventReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error)
// BulkSelectEventID returns a map from numeric event ID to string event ID. // BulkSelectEventID returns a map from numeric event ID to string event ID.
BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map. // If an event ID is not in the database then it is omitted from the map.
BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error)
SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error)
SelectRoomNIDsForEventNIDs(ctx context.Context, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error) SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
} }
type Rooms interface { type Rooms interface {
@ -69,29 +69,29 @@ type Rooms interface {
SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error)
SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error)
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
SelectRoomVersionsForRoomNIDs(ctx context.Context, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) SelectRoomVersionsForRoomNIDs(ctx context.Context, txn *sql.Tx, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error)
SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error)
SelectRoomIDs(ctx context.Context) ([]string, error) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error)
BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error)
BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error)
} }
type StateSnapshot interface { type StateSnapshot interface {
InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error) InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error)
BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) BulkSelectStateBlockNIDs(ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
} }
type StateBlock interface { type StateBlock interface {
BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries types.StateEntries) (types.StateBlockNID, error) BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries types.StateEntries) (types.StateBlockNID, error)
BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error) BulkSelectStateBlockEntries(ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error)
//BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) //BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
} }
type RoomAliases interface { type RoomAliases interface {
InsertRoomAlias(ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string) (err error) InsertRoomAlias(ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string) (err error)
SelectRoomIDFromAlias(ctx context.Context, alias string) (roomID string, err error) SelectRoomIDFromAlias(ctx context.Context, txn *sql.Tx, alias string) (roomID string, err error)
SelectAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error) SelectAliasesFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) ([]string, error)
SelectCreatorIDFromAlias(ctx context.Context, alias string) (creatorID string, err error) SelectCreatorIDFromAlias(ctx context.Context, txn *sql.Tx, alias string) (creatorID string, err error)
DeleteRoomAlias(ctx context.Context, txn *sql.Tx, alias string) (err error) DeleteRoomAlias(ctx context.Context, txn *sql.Tx, alias string) (err error)
} }
@ -106,7 +106,7 @@ type Invites interface {
InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error) InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error)
UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error) UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error)
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs and invite event IDs matching those nids. // SelectInviteActiveForUserInRoom returns a list of sender state key NIDs and invite event IDs matching those nids.
SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error) SelectInviteActiveForUserInRoom(ctx context.Context, txn *sql.Tx, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error)
} }
type MembershipState int64 type MembershipState int64
@ -121,24 +121,24 @@ const (
type Membership interface { type Membership interface {
InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error
SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error) SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error)
SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error) SelectMembershipFromRoomAndTarget(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error)
SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) SelectMembershipsFromRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error
SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
// counts of how many rooms they are joined. // counts of how many rooms they are joined.
SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error)
} }
type Published interface { type Published interface {
UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID string, published bool) (err error) UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID string, published bool) (err error)
SelectPublishedFromRoomID(ctx context.Context, roomID string) (published bool, err error) SelectPublishedFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) (published bool, err error)
SelectAllPublishedRooms(ctx context.Context, published bool) ([]string, error) SelectAllPublishedRooms(ctx context.Context, txn *sql.Tx, published bool) ([]string, error)
} }
type RedactionInfo struct { type RedactionInfo struct {