mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-11-29 17:51:56 -06:00
Full roomserver input transactional isolation (#2141)
* Add transaction to all database tables in roomserver, rename latest events updater to room updater, use room updater for all RS input
* Better transaction management
* Tweak order
* Handle cases where the room does not exist
* Other fixes
* More tweaks
* Fill some gaps
* Fill in the gaps
* good lord it gets worse
* Don't roll back transactions when events rejected
* Pass through errors properly
* Fix bugs
* Fix incorrect error check
* Don't panic on nil txns
* Tweaks
* Hopefully fix panics for good in SQLite this time
* Fix rollback
* Minor bug fixes with latest event updater
* Some review comments
* Revert "Some review comments"
This reverts commit 0caf8cf53e
.
* Fix a couple of bugs
* Clearer commit and rollback results
* Remove unnecessary prepares
This commit is contained in:
parent
4d9f5b2e57
commit
eb352a5f6b
|
@ -20,17 +20,22 @@ import (
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/state"
|
"github.com/matrix-org/dendrite/roomserver/state"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type checkForAuthAndSoftFailStorage interface {
|
||||||
|
state.StateResolutionStorage
|
||||||
|
StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error)
|
||||||
|
RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
|
||||||
|
}
|
||||||
|
|
||||||
// CheckForSoftFail returns true if the event should be soft-failed
|
// CheckForSoftFail returns true if the event should be soft-failed
|
||||||
// and false otherwise. The return error value should be checked before
|
// and false otherwise. The return error value should be checked before
|
||||||
// the soft-fail bool.
|
// the soft-fail bool.
|
||||||
func CheckForSoftFail(
|
func CheckForSoftFail(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
db storage.Database,
|
db checkForAuthAndSoftFailStorage,
|
||||||
event *gomatrixserverlib.HeaderedEvent,
|
event *gomatrixserverlib.HeaderedEvent,
|
||||||
stateEventIDs []string,
|
stateEventIDs []string,
|
||||||
) (bool, error) {
|
) (bool, error) {
|
||||||
|
@ -92,7 +97,7 @@ func CheckForSoftFail(
|
||||||
// Returns the numeric IDs for the auth events.
|
// Returns the numeric IDs for the auth events.
|
||||||
func CheckAuthEvents(
|
func CheckAuthEvents(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
db storage.Database,
|
db checkForAuthAndSoftFailStorage,
|
||||||
event *gomatrixserverlib.HeaderedEvent,
|
event *gomatrixserverlib.HeaderedEvent,
|
||||||
authEventIDs []string,
|
authEventIDs []string,
|
||||||
) ([]types.EventNID, error) {
|
) ([]types.EventNID, error) {
|
||||||
|
@ -193,7 +198,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *
|
||||||
// loadAuthEvents loads the events needed for authentication from the supplied room state.
|
// loadAuthEvents loads the events needed for authentication from the supplied room state.
|
||||||
func loadAuthEvents(
|
func loadAuthEvents(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
db storage.Database,
|
db state.StateResolutionStorage,
|
||||||
needed gomatrixserverlib.StateNeeded,
|
needed gomatrixserverlib.StateNeeded,
|
||||||
state []types.StateEntry,
|
state []types.StateEntry,
|
||||||
) (result authEvents, err error) {
|
) (result authEvents, err error) {
|
||||||
|
|
|
@ -19,6 +19,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -38,6 +39,19 @@ import (
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type retryAction int
|
||||||
|
type commitAction int
|
||||||
|
|
||||||
|
const (
|
||||||
|
doNotRetry retryAction = iota
|
||||||
|
retryLater
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
commitTransaction commitAction = iota
|
||||||
|
rollbackTransaction
|
||||||
|
)
|
||||||
|
|
||||||
var keyContentFields = map[string]string{
|
var keyContentFields = map[string]string{
|
||||||
"m.room.join_rules": "join_rule",
|
"m.room.join_rules": "join_rule",
|
||||||
"m.room.history_visibility": "history_visibility",
|
"m.room.history_visibility": "history_visibility",
|
||||||
|
@ -101,7 +115,8 @@ func (r *Inputer) Start() error {
|
||||||
_ = msg.InProgress() // resets the acknowledgement wait timer
|
_ = msg.InProgress() // resets the acknowledgement wait timer
|
||||||
defer eventsInProgress.Delete(index)
|
defer eventsInProgress.Delete(index)
|
||||||
defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec()
|
defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec()
|
||||||
if err := r.processRoomEvent(context.Background(), &inputRoomEvent); err != nil {
|
action, err := r.processRoomEventUsingUpdater(context.Background(), roomID, &inputRoomEvent)
|
||||||
|
if err != nil {
|
||||||
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
|
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
|
||||||
sentry.CaptureException(err)
|
sentry.CaptureException(err)
|
||||||
}
|
}
|
||||||
|
@ -111,7 +126,12 @@ func (r *Inputer) Start() error {
|
||||||
"type": inputRoomEvent.Event.Type(),
|
"type": inputRoomEvent.Event.Type(),
|
||||||
}).Warn("Roomserver failed to process async event")
|
}).Warn("Roomserver failed to process async event")
|
||||||
}
|
}
|
||||||
_ = msg.Ack()
|
switch action {
|
||||||
|
case retryLater:
|
||||||
|
_ = msg.Nak()
|
||||||
|
case doNotRetry:
|
||||||
|
_ = msg.Ack()
|
||||||
|
}
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
// NATS wants to acknowledge automatically by default when the message is
|
// NATS wants to acknowledge automatically by default when the message is
|
||||||
|
@ -131,6 +151,37 @@ func (r *Inputer) Start() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// processRoomEventUsingUpdater opens up a room updater and tries to
|
||||||
|
// process the event. It returns whether or not we should positively
|
||||||
|
// or negatively acknowledge the event (i.e. for NATS) and an error
|
||||||
|
// if it occurred.
|
||||||
|
func (r *Inputer) processRoomEventUsingUpdater(
|
||||||
|
ctx context.Context,
|
||||||
|
roomID string,
|
||||||
|
inputRoomEvent *api.InputRoomEvent,
|
||||||
|
) (retryAction, error) {
|
||||||
|
roomInfo, err := r.DB.RoomInfo(ctx, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return doNotRetry, fmt.Errorf("r.DB.RoomInfo: %w", err)
|
||||||
|
}
|
||||||
|
updater, err := r.DB.GetRoomUpdater(ctx, roomInfo)
|
||||||
|
if err != nil {
|
||||||
|
return retryLater, fmt.Errorf("r.DB.GetRoomUpdater: %w", err)
|
||||||
|
}
|
||||||
|
action, err := r.processRoomEvent(ctx, updater, inputRoomEvent)
|
||||||
|
switch action {
|
||||||
|
case commitTransaction:
|
||||||
|
if cerr := updater.Commit(); cerr != nil {
|
||||||
|
return retryLater, fmt.Errorf("updater.Commit: %w", cerr)
|
||||||
|
}
|
||||||
|
case rollbackTransaction:
|
||||||
|
if rerr := updater.Rollback(); rerr != nil {
|
||||||
|
return retryLater, fmt.Errorf("updater.Rollback: %w", rerr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return doNotRetry, err
|
||||||
|
}
|
||||||
|
|
||||||
// InputRoomEvents implements api.RoomserverInternalAPI
|
// InputRoomEvents implements api.RoomserverInternalAPI
|
||||||
func (r *Inputer) InputRoomEvents(
|
func (r *Inputer) InputRoomEvents(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
@ -177,7 +228,7 @@ func (r *Inputer) InputRoomEvents(
|
||||||
worker.Act(nil, func() {
|
worker.Act(nil, func() {
|
||||||
defer eventsInProgress.Delete(index)
|
defer eventsInProgress.Delete(index)
|
||||||
defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec()
|
defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec()
|
||||||
err := r.processRoomEvent(ctx, &inputRoomEvent)
|
_, err := r.processRoomEventUsingUpdater(ctx, roomID, &inputRoomEvent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
|
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
|
||||||
sentry.CaptureException(err)
|
sentry.CaptureException(err)
|
||||||
|
|
|
@ -29,6 +29,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
|
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
|
||||||
"github.com/matrix-org/dendrite/roomserver/state"
|
"github.com/matrix-org/dendrite/roomserver/state"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
|
@ -67,14 +68,15 @@ var processRoomEventDuration = prometheus.NewHistogramVec(
|
||||||
// nolint:gocyclo
|
// nolint:gocyclo
|
||||||
func (r *Inputer) processRoomEvent(
|
func (r *Inputer) processRoomEvent(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
updater *shared.RoomUpdater,
|
||||||
input *api.InputRoomEvent,
|
input *api.InputRoomEvent,
|
||||||
) (err error) {
|
) (commitAction, error) {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
// Before we do anything, make sure the context hasn't expired for this pending task.
|
// Before we do anything, make sure the context hasn't expired for this pending task.
|
||||||
// If it has then we'll give up straight away — it's probably a synchronous input
|
// If it has then we'll give up straight away — it's probably a synchronous input
|
||||||
// request and the caller has already given up, but the inbox task was still queued.
|
// request and the caller has already given up, but the inbox task was still queued.
|
||||||
return context.DeadlineExceeded
|
return rollbackTransaction, context.DeadlineExceeded
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -107,7 +109,7 @@ func (r *Inputer) processRoomEvent(
|
||||||
// if we have already got this event then do not process it again, if the input kind is an outlier.
|
// if we have already got this event then do not process it again, if the input kind is an outlier.
|
||||||
// Outliers contain no extra information which may warrant a re-processing.
|
// Outliers contain no extra information which may warrant a re-processing.
|
||||||
if input.Kind == api.KindOutlier {
|
if input.Kind == api.KindOutlier {
|
||||||
evs, err2 := r.DB.EventsFromIDs(ctx, []string{event.EventID()})
|
evs, err2 := updater.EventsFromIDs(ctx, []string{event.EventID()})
|
||||||
if err2 == nil && len(evs) == 1 {
|
if err2 == nil && len(evs) == 1 {
|
||||||
// check hash matches if we're on early room versions where the event ID was a random string
|
// check hash matches if we're on early room versions where the event ID was a random string
|
||||||
idFormat, err2 := headered.RoomVersion.EventIDFormat()
|
idFormat, err2 := headered.RoomVersion.EventIDFormat()
|
||||||
|
@ -116,11 +118,11 @@ func (r *Inputer) processRoomEvent(
|
||||||
case gomatrixserverlib.EventIDFormatV1:
|
case gomatrixserverlib.EventIDFormatV1:
|
||||||
if bytes.Equal(event.EventReference().EventSHA256, evs[0].EventReference().EventSHA256) {
|
if bytes.Equal(event.EventReference().EventSHA256, evs[0].EventReference().EventSHA256) {
|
||||||
logger.Debugf("Already processed event; ignoring")
|
logger.Debugf("Already processed event; ignoring")
|
||||||
return nil
|
return rollbackTransaction, nil
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
logger.Debugf("Already processed event; ignoring")
|
logger.Debugf("Already processed event; ignoring")
|
||||||
return nil
|
return rollbackTransaction, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -134,8 +136,8 @@ func (r *Inputer) processRoomEvent(
|
||||||
AuthEventIDs: event.AuthEventIDs(),
|
AuthEventIDs: event.AuthEventIDs(),
|
||||||
PrevEventIDs: event.PrevEventIDs(),
|
PrevEventIDs: event.PrevEventIDs(),
|
||||||
}
|
}
|
||||||
if err = r.Queryer.QueryMissingAuthPrevEvents(ctx, missingReq, missingRes); err != nil {
|
if err := r.Queryer.QueryMissingAuthPrevEvents(ctx, missingReq, missingRes); err != nil {
|
||||||
return fmt.Errorf("r.Queryer.QueryMissingAuthPrevEvents: %w", err)
|
return rollbackTransaction, fmt.Errorf("r.Queryer.QueryMissingAuthPrevEvents: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
missingAuth := len(missingRes.MissingAuthEventIDs) > 0
|
missingAuth := len(missingRes.MissingAuthEventIDs) > 0
|
||||||
|
@ -146,8 +148,8 @@ func (r *Inputer) processRoomEvent(
|
||||||
RoomID: event.RoomID(),
|
RoomID: event.RoomID(),
|
||||||
ExcludeSelf: true,
|
ExcludeSelf: true,
|
||||||
}
|
}
|
||||||
if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil {
|
if err := r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil {
|
||||||
return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err)
|
return rollbackTransaction, fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err)
|
||||||
}
|
}
|
||||||
// Sort all of the servers into a map so that we can randomise
|
// Sort all of the servers into a map so that we can randomise
|
||||||
// their order. Then make sure that the input origin and the
|
// their order. Then make sure that the input origin and the
|
||||||
|
@ -176,8 +178,8 @@ func (r *Inputer) processRoomEvent(
|
||||||
isRejected := false
|
isRejected := false
|
||||||
authEvents := gomatrixserverlib.NewAuthEvents(nil)
|
authEvents := gomatrixserverlib.NewAuthEvents(nil)
|
||||||
knownEvents := map[string]*types.Event{}
|
knownEvents := map[string]*types.Event{}
|
||||||
if err = r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil {
|
if err := r.fetchAuthEvents(ctx, updater, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil {
|
||||||
return fmt.Errorf("r.fetchAuthEvents: %w", err)
|
return rollbackTransaction, fmt.Errorf("r.fetchAuthEvents: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the event is allowed by its auth events. If it isn't then
|
// Check if the event is allowed by its auth events. If it isn't then
|
||||||
|
@ -193,7 +195,7 @@ func (r *Inputer) processRoomEvent(
|
||||||
authEventNIDs := make([]types.EventNID, 0, len(authEventIDs))
|
authEventNIDs := make([]types.EventNID, 0, len(authEventIDs))
|
||||||
for _, authEventID := range authEventIDs {
|
for _, authEventID := range authEventIDs {
|
||||||
if _, ok := knownEvents[authEventID]; !ok {
|
if _, ok := knownEvents[authEventID]; !ok {
|
||||||
return fmt.Errorf("missing auth event %s", authEventID)
|
return rollbackTransaction, fmt.Errorf("missing auth event %s", authEventID)
|
||||||
}
|
}
|
||||||
authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID)
|
authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID)
|
||||||
}
|
}
|
||||||
|
@ -202,7 +204,8 @@ func (r *Inputer) processRoomEvent(
|
||||||
if input.Kind == api.KindNew {
|
if input.Kind == api.KindNew {
|
||||||
// Check that the event passes authentication checks based on the
|
// Check that the event passes authentication checks based on the
|
||||||
// current room state.
|
// current room state.
|
||||||
softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs)
|
var err error
|
||||||
|
softfail, err = helpers.CheckForSoftFail(ctx, updater, headered, input.StateEventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).Warn("Error authing soft-failed event")
|
logger.WithError(err).Warn("Error authing soft-failed event")
|
||||||
}
|
}
|
||||||
|
@ -227,7 +230,7 @@ func (r *Inputer) processRoomEvent(
|
||||||
origin: input.Origin,
|
origin: input.Origin,
|
||||||
inputer: r,
|
inputer: r,
|
||||||
queryer: r.Queryer,
|
queryer: r.Queryer,
|
||||||
db: r.DB,
|
db: updater,
|
||||||
federation: r.FSAPI,
|
federation: r.FSAPI,
|
||||||
keys: r.KeyRing,
|
keys: r.KeyRing,
|
||||||
roomsMu: internal.NewMutexByRoom(),
|
roomsMu: internal.NewMutexByRoom(),
|
||||||
|
@ -235,7 +238,7 @@ func (r *Inputer) processRoomEvent(
|
||||||
hadEvents: map[string]bool{},
|
hadEvents: map[string]bool{},
|
||||||
haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{},
|
haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{},
|
||||||
}
|
}
|
||||||
if err = missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil {
|
if err := missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil {
|
||||||
isRejected = true
|
isRejected = true
|
||||||
rejectionErr = fmt.Errorf("missingState.processEventWithMissingState: %w", err)
|
rejectionErr = fmt.Errorf("missingState.processEventWithMissingState: %w", err)
|
||||||
} else {
|
} else {
|
||||||
|
@ -248,16 +251,16 @@ func (r *Inputer) processRoomEvent(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store the event.
|
// Store the event.
|
||||||
_, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected)
|
_, _, stateAtEvent, redactionEvent, redactedEventID, err := updater.StoreEvent(ctx, event, authEventNIDs, isRejected)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("r.DB.StoreEvent: %w", err)
|
return rollbackTransaction, fmt.Errorf("updater.StoreEvent: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// if storing this event results in it being redacted then do so.
|
// if storing this event results in it being redacted then do so.
|
||||||
if !isRejected && redactedEventID == event.EventID() {
|
if !isRejected && redactedEventID == event.EventID() {
|
||||||
r, rerr := eventutil.RedactEvent(redactionEvent, event)
|
r, rerr := eventutil.RedactEvent(redactionEvent, event)
|
||||||
if rerr != nil {
|
if rerr != nil {
|
||||||
return fmt.Errorf("eventutil.RedactEvent: %w", rerr)
|
return rollbackTransaction, fmt.Errorf("eventutil.RedactEvent: %w", rerr)
|
||||||
}
|
}
|
||||||
event = r
|
event = r
|
||||||
}
|
}
|
||||||
|
@ -268,23 +271,23 @@ func (r *Inputer) processRoomEvent(
|
||||||
if input.Kind == api.KindOutlier {
|
if input.Kind == api.KindOutlier {
|
||||||
logger.Debug("Stored outlier")
|
logger.Debug("Stored outlier")
|
||||||
hooks.Run(hooks.KindNewEventPersisted, headered)
|
hooks.Run(hooks.KindNewEventPersisted, headered)
|
||||||
return nil
|
return commitTransaction, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
roomInfo, err := r.DB.RoomInfo(ctx, event.RoomID())
|
roomInfo, err := updater.RoomInfo(ctx, event.RoomID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("r.DB.RoomInfo: %w", err)
|
return rollbackTransaction, fmt.Errorf("updater.RoomInfo: %w", err)
|
||||||
}
|
}
|
||||||
if roomInfo == nil {
|
if roomInfo == nil {
|
||||||
return fmt.Errorf("r.DB.RoomInfo missing for room %s", event.RoomID())
|
return rollbackTransaction, fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID())
|
||||||
}
|
}
|
||||||
|
|
||||||
if !missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0 {
|
if !missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0 {
|
||||||
// We haven't calculated a state for this event yet.
|
// We haven't calculated a state for this event yet.
|
||||||
// Lets calculate one.
|
// Lets calculate one.
|
||||||
err = r.calculateAndSetState(ctx, input, roomInfo, &stateAtEvent, event, isRejected)
|
err = r.calculateAndSetState(ctx, updater, input, roomInfo, &stateAtEvent, event, isRejected)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("r.calculateAndSetState: %w", err)
|
return rollbackTransaction, fmt.Errorf("r.calculateAndSetState: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -294,13 +297,14 @@ func (r *Inputer) processRoomEvent(
|
||||||
"soft_fail": softfail,
|
"soft_fail": softfail,
|
||||||
"missing_prev": missingPrev,
|
"missing_prev": missingPrev,
|
||||||
}).Warn("Stored rejected event")
|
}).Warn("Stored rejected event")
|
||||||
return rejectionErr
|
return commitTransaction, rejectionErr
|
||||||
}
|
}
|
||||||
|
|
||||||
switch input.Kind {
|
switch input.Kind {
|
||||||
case api.KindNew:
|
case api.KindNew:
|
||||||
if err = r.updateLatestEvents(
|
if err = r.updateLatestEvents(
|
||||||
ctx, // context
|
ctx, // context
|
||||||
|
updater, // room updater
|
||||||
roomInfo, // room info for the room being updated
|
roomInfo, // room info for the room being updated
|
||||||
stateAtEvent, // state at event (below)
|
stateAtEvent, // state at event (below)
|
||||||
event, // event
|
event, // event
|
||||||
|
@ -308,7 +312,7 @@ func (r *Inputer) processRoomEvent(
|
||||||
input.TransactionID, // transaction ID
|
input.TransactionID, // transaction ID
|
||||||
input.HasState, // rewrites state?
|
input.HasState, // rewrites state?
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return fmt.Errorf("r.updateLatestEvents: %w", err)
|
return rollbackTransaction, fmt.Errorf("r.updateLatestEvents: %w", err)
|
||||||
}
|
}
|
||||||
case api.KindOld:
|
case api.KindOld:
|
||||||
err = r.WriteOutputEvents(event.RoomID(), []api.OutputEvent{
|
err = r.WriteOutputEvents(event.RoomID(), []api.OutputEvent{
|
||||||
|
@ -320,7 +324,7 @@ func (r *Inputer) processRoomEvent(
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("r.WriteOutputEvents (old): %w", err)
|
return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (old): %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -339,14 +343,14 @@ func (r *Inputer) processRoomEvent(
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("r.WriteOutputEvents (redactions): %w", err)
|
return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (redactions): %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Everything was OK — the latest events updater didn't error and
|
// Everything was OK — the latest events updater didn't error and
|
||||||
// we've sent output events. Finally, generate a hook call.
|
// we've sent output events. Finally, generate a hook call.
|
||||||
hooks.Run(hooks.KindNewEventPersisted, headered)
|
hooks.Run(hooks.KindNewEventPersisted, headered)
|
||||||
return nil
|
return commitTransaction, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// fetchAuthEvents will check to see if any of the
|
// fetchAuthEvents will check to see if any of the
|
||||||
|
@ -358,6 +362,7 @@ func (r *Inputer) processRoomEvent(
|
||||||
// they are now in the database.
|
// they are now in the database.
|
||||||
func (r *Inputer) fetchAuthEvents(
|
func (r *Inputer) fetchAuthEvents(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
updater *shared.RoomUpdater,
|
||||||
logger *logrus.Entry,
|
logger *logrus.Entry,
|
||||||
event *gomatrixserverlib.HeaderedEvent,
|
event *gomatrixserverlib.HeaderedEvent,
|
||||||
auth *gomatrixserverlib.AuthEvents,
|
auth *gomatrixserverlib.AuthEvents,
|
||||||
|
@ -375,7 +380,7 @@ func (r *Inputer) fetchAuthEvents(
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, authEventID := range authEventIDs {
|
for _, authEventID := range authEventIDs {
|
||||||
authEvents, err := r.DB.EventsFromIDs(ctx, []string{authEventID})
|
authEvents, err := updater.EventsFromIDs(ctx, []string{authEventID})
|
||||||
if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil {
|
if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil {
|
||||||
unknown[authEventID] = struct{}{}
|
unknown[authEventID] = struct{}{}
|
||||||
continue
|
continue
|
||||||
|
@ -454,9 +459,9 @@ func (r *Inputer) fetchAuthEvents(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finally, store the event in the database.
|
// Finally, store the event in the database.
|
||||||
eventNID, _, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, authEventNIDs, isRejected)
|
eventNID, _, _, _, _, err := updater.StoreEvent(ctx, authEvent, authEventNIDs, isRejected)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("r.DB.StoreEvent: %w", err)
|
return fmt.Errorf("updater.StoreEvent: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now we know about this event, it was stored and the signatures were OK.
|
// Now we know about this event, it was stored and the signatures were OK.
|
||||||
|
@ -471,6 +476,7 @@ func (r *Inputer) fetchAuthEvents(
|
||||||
|
|
||||||
func (r *Inputer) calculateAndSetState(
|
func (r *Inputer) calculateAndSetState(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
updater *shared.RoomUpdater,
|
||||||
input *api.InputRoomEvent,
|
input *api.InputRoomEvent,
|
||||||
roomInfo *types.RoomInfo,
|
roomInfo *types.RoomInfo,
|
||||||
stateAtEvent *types.StateAtEvent,
|
stateAtEvent *types.StateAtEvent,
|
||||||
|
@ -478,14 +484,14 @@ func (r *Inputer) calculateAndSetState(
|
||||||
isRejected bool,
|
isRejected bool,
|
||||||
) error {
|
) error {
|
||||||
var err error
|
var err error
|
||||||
roomState := state.NewStateResolution(r.DB, roomInfo)
|
roomState := state.NewStateResolution(updater, roomInfo)
|
||||||
|
|
||||||
if input.HasState {
|
if input.HasState {
|
||||||
// Check here if we think we're in the room already.
|
// Check here if we think we're in the room already.
|
||||||
stateAtEvent.Overwrite = true
|
stateAtEvent.Overwrite = true
|
||||||
var joinEventNIDs []types.EventNID
|
var joinEventNIDs []types.EventNID
|
||||||
// Request join memberships only for local users only.
|
// Request join memberships only for local users only.
|
||||||
if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil {
|
if joinEventNIDs, err = updater.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil {
|
||||||
// If we have no local users that are joined to the room then any state about
|
// If we have no local users that are joined to the room then any state about
|
||||||
// the room that we have is quite possibly out of date. Therefore in that case
|
// the room that we have is quite possibly out of date. Therefore in that case
|
||||||
// we should overwrite it rather than merge it.
|
// we should overwrite it rather than merge it.
|
||||||
|
@ -495,13 +501,13 @@ func (r *Inputer) calculateAndSetState(
|
||||||
// We've been told what the state at the event is so we don't need to calculate it.
|
// We've been told what the state at the event is so we don't need to calculate it.
|
||||||
// Check that those state events are in the database and store the state.
|
// Check that those state events are in the database and store the state.
|
||||||
var entries []types.StateEntry
|
var entries []types.StateEntry
|
||||||
if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
|
if entries, err = updater.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
|
||||||
return fmt.Errorf("r.DB.StateEntriesForEventIDs: %w", err)
|
return fmt.Errorf("updater.StateEntriesForEventIDs: %w", err)
|
||||||
}
|
}
|
||||||
entries = types.DeduplicateStateEntries(entries)
|
entries = types.DeduplicateStateEntries(entries)
|
||||||
|
|
||||||
if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil {
|
if stateAtEvent.BeforeStateSnapshotNID, err = updater.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil {
|
||||||
return fmt.Errorf("r.DB.AddState: %w", err)
|
return fmt.Errorf("updater.AddState: %w", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
stateAtEvent.Overwrite = false
|
stateAtEvent.Overwrite = false
|
||||||
|
@ -512,7 +518,7 @@ func (r *Inputer) calculateAndSetState(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
|
err = updater.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("r.DB.SetState: %w", err)
|
return fmt.Errorf("r.DB.SetState: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,7 +20,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/state"
|
"github.com/matrix-org/dendrite/roomserver/state"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
|
@ -48,6 +47,7 @@ import (
|
||||||
// Can only be called once at a time
|
// Can only be called once at a time
|
||||||
func (r *Inputer) updateLatestEvents(
|
func (r *Inputer) updateLatestEvents(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
updater *shared.RoomUpdater,
|
||||||
roomInfo *types.RoomInfo,
|
roomInfo *types.RoomInfo,
|
||||||
stateAtEvent types.StateAtEvent,
|
stateAtEvent types.StateAtEvent,
|
||||||
event *gomatrixserverlib.Event,
|
event *gomatrixserverlib.Event,
|
||||||
|
@ -55,13 +55,6 @@ func (r *Inputer) updateLatestEvents(
|
||||||
transactionID *api.TransactionID,
|
transactionID *api.TransactionID,
|
||||||
rewritesState bool,
|
rewritesState bool,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
updater, err := r.DB.GetLatestEventsForUpdate(ctx, *roomInfo)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err)
|
|
||||||
}
|
|
||||||
succeeded := false
|
|
||||||
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
|
|
||||||
|
|
||||||
u := latestEventsUpdater{
|
u := latestEventsUpdater{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
api: r,
|
api: r,
|
||||||
|
@ -78,7 +71,6 @@ func (r *Inputer) updateLatestEvents(
|
||||||
return fmt.Errorf("u.doUpdateLatestEvents: %w", err)
|
return fmt.Errorf("u.doUpdateLatestEvents: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
succeeded = true
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,7 +81,7 @@ func (r *Inputer) updateLatestEvents(
|
||||||
type latestEventsUpdater struct {
|
type latestEventsUpdater struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
api *Inputer
|
api *Inputer
|
||||||
updater *shared.LatestEventsUpdater
|
updater *shared.RoomUpdater
|
||||||
roomInfo *types.RoomInfo
|
roomInfo *types.RoomInfo
|
||||||
stateAtEvent types.StateAtEvent
|
stateAtEvent types.StateAtEvent
|
||||||
event *gomatrixserverlib.Event
|
event *gomatrixserverlib.Event
|
||||||
|
@ -199,7 +191,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
||||||
|
|
||||||
func (u *latestEventsUpdater) latestState() error {
|
func (u *latestEventsUpdater) latestState() error {
|
||||||
var err error
|
var err error
|
||||||
roomState := state.NewStateResolution(u.api.DB, u.roomInfo)
|
roomState := state.NewStateResolution(u.updater, u.roomInfo)
|
||||||
|
|
||||||
// Work out if the state at the extremities has actually changed
|
// Work out if the state at the extremities has actually changed
|
||||||
// or not. If they haven't then we won't bother doing all of the
|
// or not. If they haven't then we won't bother doing all of the
|
||||||
|
@ -413,7 +405,7 @@ func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.Ro
|
||||||
if len(extraEventIDs) == 0 {
|
if len(extraEventIDs) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
extraEvents, err := u.api.DB.EventsFromIDs(u.ctx, extraEventIDs)
|
extraEvents, err := u.updater.EventsFromIDs(u.ctx, extraEventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -436,7 +428,7 @@ func (u *latestEventsUpdater) stateEventMap() (map[types.EventNID]string, error)
|
||||||
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
|
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
|
||||||
}
|
}
|
||||||
stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))]
|
stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))]
|
||||||
return u.api.DB.EventIDs(u.ctx, stateEventNIDs)
|
return u.updater.EventIDs(u.ctx, stateEventNIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
type eventNIDSorter []types.EventNID
|
type eventNIDSorter []types.EventNID
|
||||||
|
|
|
@ -31,7 +31,7 @@ import (
|
||||||
// consumers about the invites added or retired by the change in current state.
|
// consumers about the invites added or retired by the change in current state.
|
||||||
func (r *Inputer) updateMemberships(
|
func (r *Inputer) updateMemberships(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
updater *shared.LatestEventsUpdater,
|
updater *shared.RoomUpdater,
|
||||||
removed, added []types.StateEntry,
|
removed, added []types.StateEntry,
|
||||||
) ([]api.OutputEvent, error) {
|
) ([]api.OutputEvent, error) {
|
||||||
changes := membershipChanges(removed, added)
|
changes := membershipChanges(removed, added)
|
||||||
|
@ -79,7 +79,7 @@ func (r *Inputer) updateMemberships(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Inputer) updateMembership(
|
func (r *Inputer) updateMembership(
|
||||||
updater *shared.LatestEventsUpdater,
|
updater *shared.RoomUpdater,
|
||||||
targetUserNID types.EventStateKeyNID,
|
targetUserNID types.EventStateKeyNID,
|
||||||
remove, add *gomatrixserverlib.Event,
|
remove, add *gomatrixserverlib.Event,
|
||||||
updates []api.OutputEvent,
|
updates []api.OutputEvent,
|
||||||
|
|
|
@ -11,7 +11,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/internal/query"
|
"github.com/matrix-org/dendrite/roomserver/internal/query"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
@ -19,7 +19,7 @@ import (
|
||||||
|
|
||||||
type missingStateReq struct {
|
type missingStateReq struct {
|
||||||
origin gomatrixserverlib.ServerName
|
origin gomatrixserverlib.ServerName
|
||||||
db storage.Database
|
db *shared.RoomUpdater
|
||||||
inputer *Inputer
|
inputer *Inputer
|
||||||
queryer *query.Queryer
|
queryer *query.Queryer
|
||||||
keys gomatrixserverlib.JSONVerifier
|
keys gomatrixserverlib.JSONVerifier
|
||||||
|
@ -78,7 +78,7 @@ func (t *missingStateReq) processEventWithMissingState(
|
||||||
// we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled
|
// we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled
|
||||||
// in the gap in the DAG
|
// in the gap in the DAG
|
||||||
for _, newEvent := range newEvents {
|
for _, newEvent := range newEvents {
|
||||||
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{
|
_, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
|
||||||
Kind: api.KindNew,
|
Kind: api.KindNew,
|
||||||
Event: newEvent.Headered(roomVersion),
|
Event: newEvent.Headered(roomVersion),
|
||||||
Origin: t.origin,
|
Origin: t.origin,
|
||||||
|
@ -187,7 +187,7 @@ func (t *missingStateReq) processEventWithMissingState(
|
||||||
}
|
}
|
||||||
// TODO: we could do this concurrently?
|
// TODO: we could do this concurrently?
|
||||||
for _, ire := range outlierRoomEvents {
|
for _, ire := range outlierRoomEvents {
|
||||||
if err = t.inputer.processRoomEvent(ctx, &ire); err != nil {
|
if _, err = t.inputer.processRoomEvent(ctx, t.db, &ire); err != nil {
|
||||||
return fmt.Errorf("t.inputer.processRoomEvent[outlier]: %w", err)
|
return fmt.Errorf("t.inputer.processRoomEvent[outlier]: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -200,7 +200,7 @@ func (t *missingStateReq) processEventWithMissingState(
|
||||||
stateIDs = append(stateIDs, event.EventID())
|
stateIDs = append(stateIDs, event.EventID())
|
||||||
}
|
}
|
||||||
|
|
||||||
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{
|
_, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
|
||||||
Kind: api.KindOld,
|
Kind: api.KindOld,
|
||||||
Event: backwardsExtremity.Headered(roomVersion),
|
Event: backwardsExtremity.Headered(roomVersion),
|
||||||
Origin: t.origin,
|
Origin: t.origin,
|
||||||
|
@ -217,7 +217,7 @@ func (t *missingStateReq) processEventWithMissingState(
|
||||||
// they will automatically fast-forward based on the room state at the
|
// they will automatically fast-forward based on the room state at the
|
||||||
// extremity in the last step.
|
// extremity in the last step.
|
||||||
for _, newEvent := range newEvents {
|
for _, newEvent := range newEvents {
|
||||||
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{
|
_, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
|
||||||
Kind: api.KindOld,
|
Kind: api.KindOld,
|
||||||
Event: newEvent.Headered(roomVersion),
|
Event: newEvent.Headered(roomVersion),
|
||||||
Origin: t.origin,
|
Origin: t.origin,
|
||||||
|
|
|
@ -22,7 +22,6 @@ import (
|
||||||
"sort"
|
"sort"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
|
||||||
|
@ -30,13 +29,25 @@ import (
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type StateResolutionStorage interface {
|
||||||
|
EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
|
||||||
|
EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
|
||||||
|
StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
|
||||||
|
StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
|
||||||
|
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
|
||||||
|
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
|
||||||
|
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
|
||||||
|
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
|
||||||
|
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
|
||||||
|
}
|
||||||
|
|
||||||
type StateResolution struct {
|
type StateResolution struct {
|
||||||
db storage.Database
|
db StateResolutionStorage
|
||||||
roomInfo *types.RoomInfo
|
roomInfo *types.RoomInfo
|
||||||
events map[types.EventNID]*gomatrixserverlib.Event
|
events map[types.EventNID]*gomatrixserverlib.Event
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewStateResolution(db storage.Database, roomInfo *types.RoomInfo) StateResolution {
|
func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) StateResolution {
|
||||||
return StateResolution{
|
return StateResolution{
|
||||||
db: db,
|
db: db,
|
||||||
roomInfo: roomInfo,
|
roomInfo: roomInfo,
|
||||||
|
|
|
@ -86,11 +86,10 @@ type Database interface {
|
||||||
// Lookup the event IDs for a batch of event numeric IDs.
|
// Lookup the event IDs for a batch of event numeric IDs.
|
||||||
// Returns an error if the retrieval went wrong.
|
// Returns an error if the retrieval went wrong.
|
||||||
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
|
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
|
||||||
// Look up the latest events in a room in preparation for an update.
|
// Opens and returns a room updater, which locks the room and opens a transaction.
|
||||||
// The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error.
|
// The GetRoomUpdater must have Commit or Rollback called on it if this doesn't return an error.
|
||||||
// Returns the latest events in the room and the last eventID sent to the log along with an updater.
|
|
||||||
// If this returns an error then no further action is required.
|
// If this returns an error then no further action is required.
|
||||||
GetLatestEventsForUpdate(ctx context.Context, roomInfo types.RoomInfo) (*shared.LatestEventsUpdater, error)
|
GetRoomUpdater(ctx context.Context, roomInfo *types.RoomInfo) (*shared.RoomUpdater, error)
|
||||||
// Look up event references for the latest events in the room and the current state snapshot.
|
// Look up event references for the latest events in the room and the current state snapshot.
|
||||||
// Returns the latest events, the current state and the maximum depth of the latest events plus 1.
|
// Returns the latest events, the current state and the maximum depth of the latest events plus 1.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
|
|
|
@ -81,9 +81,10 @@ func (s *eventJSONStatements) InsertEventJSON(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventJSONStatements) BulkSelectEventJSON(
|
func (s *eventJSONStatements) BulkSelectEventJSON(
|
||||||
ctx context.Context, eventNIDs []types.EventNID,
|
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||||
) ([]tables.EventJSONPair, error) {
|
) ([]tables.EventJSONPair, error) {
|
||||||
rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
|
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventJSONStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -111,9 +111,10 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
|
func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
|
||||||
ctx context.Context, eventStateKeys []string,
|
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
|
||||||
) (map[string]types.EventStateKeyNID, error) {
|
) (map[string]types.EventStateKeyNID, error) {
|
||||||
rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext(
|
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyNIDStmt)
|
||||||
|
rows, err := stmt.QueryContext(
|
||||||
ctx, pq.StringArray(eventStateKeys),
|
ctx, pq.StringArray(eventStateKeys),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -134,13 +135,14 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStateKeyStatements) BulkSelectEventStateKey(
|
func (s *eventStateKeyStatements) BulkSelectEventStateKey(
|
||||||
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
|
ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||||
) (map[types.EventStateKeyNID]string, error) {
|
) (map[types.EventStateKeyNID]string, error) {
|
||||||
nIDs := make(pq.Int64Array, len(eventStateKeyNIDs))
|
nIDs := make(pq.Int64Array, len(eventStateKeyNIDs))
|
||||||
for i := range eventStateKeyNIDs {
|
for i := range eventStateKeyNIDs {
|
||||||
nIDs[i] = int64(eventStateKeyNIDs[i])
|
nIDs[i] = int64(eventStateKeyNIDs[i])
|
||||||
}
|
}
|
||||||
rows, err := s.bulkSelectEventStateKeyStmt.QueryContext(ctx, nIDs)
|
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, nIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -133,9 +133,10 @@ func (s *eventTypeStatements) SelectEventTypeNID(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventTypeStatements) BulkSelectEventTypeNID(
|
func (s *eventTypeStatements) BulkSelectEventTypeNID(
|
||||||
ctx context.Context, eventTypes []string,
|
ctx context.Context, txn *sql.Tx, eventTypes []string,
|
||||||
) (map[string]types.EventTypeNID, error) {
|
) (map[string]types.EventTypeNID, error) {
|
||||||
rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes))
|
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventTypeNIDStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventTypes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -212,9 +212,10 @@ func (s *eventStatements) SelectEvent(
|
||||||
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
||||||
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
||||||
func (s *eventStatements) BulkSelectStateEventByID(
|
func (s *eventStatements) BulkSelectStateEventByID(
|
||||||
ctx context.Context, eventIDs []string,
|
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||||
) ([]types.StateEntry, error) {
|
) ([]types.StateEntry, error) {
|
||||||
rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByIDStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -254,13 +255,14 @@ func (s *eventStatements) BulkSelectStateEventByID(
|
||||||
// bulkSelectStateEventByNID lookups a list of state events by event NID.
|
// bulkSelectStateEventByNID lookups a list of state events by event NID.
|
||||||
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
||||||
func (s *eventStatements) BulkSelectStateEventByNID(
|
func (s *eventStatements) BulkSelectStateEventByNID(
|
||||||
ctx context.Context, eventNIDs []types.EventNID,
|
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||||
stateKeyTuples []types.StateKeyTuple,
|
stateKeyTuples []types.StateKeyTuple,
|
||||||
) ([]types.StateEntry, error) {
|
) ([]types.StateEntry, error) {
|
||||||
tuples := stateKeyTupleSorter(stateKeyTuples)
|
tuples := stateKeyTupleSorter(stateKeyTuples)
|
||||||
sort.Sort(tuples)
|
sort.Sort(tuples)
|
||||||
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
|
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
|
||||||
rows, err := s.bulkSelectStateEventByNIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray)
|
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByNIDStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -291,9 +293,10 @@ func (s *eventStatements) BulkSelectStateEventByNID(
|
||||||
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
||||||
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
|
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
|
||||||
func (s *eventStatements) BulkSelectStateAtEventByID(
|
func (s *eventStatements) BulkSelectStateAtEventByID(
|
||||||
ctx context.Context, eventIDs []string,
|
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||||
) ([]types.StateAtEvent, error) {
|
) ([]types.StateAtEvent, error) {
|
||||||
rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateAtEventByIDStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -428,8 +431,9 @@ func (s *eventStatements) BulkSelectEventReference(
|
||||||
}
|
}
|
||||||
|
|
||||||
// bulkSelectEventID returns a map from numeric event ID to string event ID.
|
// bulkSelectEventID returns a map from numeric event ID to string event ID.
|
||||||
func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
|
func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
|
||||||
rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
|
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventIDStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -455,8 +459,9 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
|
||||||
|
|
||||||
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
|
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
|
||||||
// If an event ID is not in the database then it is omitted from the map.
|
// If an event ID is not in the database then it is omitted from the map.
|
||||||
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) {
|
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
|
||||||
rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -484,9 +489,10 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) SelectRoomNIDsForEventNIDs(
|
func (s *eventStatements) SelectRoomNIDsForEventNIDs(
|
||||||
ctx context.Context, eventNIDs []types.EventNID,
|
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||||
) (map[types.EventNID]types.RoomNID, error) {
|
) (map[types.EventNID]types.RoomNID, error) {
|
||||||
rows, err := s.selectRoomNIDsForEventNIDsStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
|
stmt := sqlutil.TxStmt(txn, s.selectRoomNIDsForEventNIDsStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -97,8 +97,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *inviteStatements) InsertInviteEvent(
|
func (s *inviteStatements) InsertInviteEvent(
|
||||||
ctx context.Context,
|
ctx context.Context, txn *sql.Tx,
|
||||||
txn *sql.Tx, inviteEventID string, roomNID types.RoomNID,
|
inviteEventID string, roomNID types.RoomNID,
|
||||||
targetUserNID, senderUserNID types.EventStateKeyNID,
|
targetUserNID, senderUserNID types.EventStateKeyNID,
|
||||||
inviteEventJSON []byte,
|
inviteEventJSON []byte,
|
||||||
) (bool, error) {
|
) (bool, error) {
|
||||||
|
@ -116,8 +116,8 @@ func (s *inviteStatements) InsertInviteEvent(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *inviteStatements) UpdateInviteRetired(
|
func (s *inviteStatements) UpdateInviteRetired(
|
||||||
ctx context.Context,
|
ctx context.Context, txn *sql.Tx,
|
||||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
) ([]string, error) {
|
) ([]string, error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
|
stmt := sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
|
||||||
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
|
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
|
||||||
|
@ -139,10 +139,11 @@ func (s *inviteStatements) UpdateInviteRetired(
|
||||||
|
|
||||||
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs
|
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs
|
||||||
func (s *inviteStatements) SelectInviteActiveForUserInRoom(
|
func (s *inviteStatements) SelectInviteActiveForUserInRoom(
|
||||||
ctx context.Context,
|
ctx context.Context, txn *sql.Tx,
|
||||||
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
|
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
|
||||||
) ([]types.EventStateKeyNID, []string, error) {
|
) ([]types.EventStateKeyNID, []string, error) {
|
||||||
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
|
stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt)
|
||||||
|
rows, err := stmt.QueryContext(
|
||||||
ctx, targetUserNID, roomNID,
|
ctx, targetUserNID, roomNID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -186,8 +186,8 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) InsertMembership(
|
func (s *membershipStatements) InsertMembership(
|
||||||
ctx context.Context,
|
ctx context.Context, txn *sql.Tx,
|
||||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
localTarget bool,
|
localTarget bool,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
|
||||||
|
@ -196,8 +196,8 @@ func (s *membershipStatements) InsertMembership(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) SelectMembershipForUpdate(
|
func (s *membershipStatements) SelectMembershipForUpdate(
|
||||||
ctx context.Context,
|
ctx context.Context, txn *sql.Tx,
|
||||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
) (membership tables.MembershipState, err error) {
|
) (membership tables.MembershipState, err error) {
|
||||||
err = sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext(
|
err = sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext(
|
||||||
ctx, roomNID, targetUserNID,
|
ctx, roomNID, targetUserNID,
|
||||||
|
@ -206,17 +206,19 @@ func (s *membershipStatements) SelectMembershipForUpdate(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
|
func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
|
||||||
ctx context.Context,
|
ctx context.Context, txn *sql.Tx,
|
||||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
|
) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
|
||||||
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
|
stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt)
|
||||||
|
err = stmt.QueryRowContext(
|
||||||
ctx, roomNID, targetUserNID,
|
ctx, roomNID, targetUserNID,
|
||||||
).Scan(&membership, &eventNID, &forgotten)
|
).Scan(&membership, &eventNID, &forgotten)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) SelectMembershipsFromRoom(
|
func (s *membershipStatements) SelectMembershipsFromRoom(
|
||||||
ctx context.Context, roomNID types.RoomNID, localOnly bool,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
roomNID types.RoomNID, localOnly bool,
|
||||||
) (eventNIDs []types.EventNID, err error) {
|
) (eventNIDs []types.EventNID, err error) {
|
||||||
var stmt *sql.Stmt
|
var stmt *sql.Stmt
|
||||||
if localOnly {
|
if localOnly {
|
||||||
|
@ -224,6 +226,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
|
||||||
} else {
|
} else {
|
||||||
stmt = s.selectMembershipsFromRoomStmt
|
stmt = s.selectMembershipsFromRoomStmt
|
||||||
}
|
}
|
||||||
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
rows, err := stmt.QueryContext(ctx, roomNID)
|
rows, err := stmt.QueryContext(ctx, roomNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -241,7 +244,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
|
func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
|
||||||
ctx context.Context,
|
ctx context.Context, txn *sql.Tx,
|
||||||
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
|
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
|
||||||
) (eventNIDs []types.EventNID, err error) {
|
) (eventNIDs []types.EventNID, err error) {
|
||||||
var rows *sql.Rows
|
var rows *sql.Rows
|
||||||
|
@ -251,6 +254,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
|
||||||
} else {
|
} else {
|
||||||
stmt = s.selectMembershipsFromRoomAndMembershipStmt
|
stmt = s.selectMembershipsFromRoomAndMembershipStmt
|
||||||
}
|
}
|
||||||
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
rows, err = stmt.QueryContext(ctx, roomNID, membership)
|
rows, err = stmt.QueryContext(ctx, roomNID, membership)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -268,8 +272,8 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) UpdateMembership(
|
func (s *membershipStatements) UpdateMembership(
|
||||||
ctx context.Context,
|
ctx context.Context, txn *sql.Tx,
|
||||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
|
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
|
||||||
eventNID types.EventNID, forgotten bool,
|
eventNID types.EventNID, forgotten bool,
|
||||||
) error {
|
) error {
|
||||||
_, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext(
|
_, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext(
|
||||||
|
@ -279,9 +283,11 @@ func (s *membershipStatements) UpdateMembership(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) SelectRoomsWithMembership(
|
func (s *membershipStatements) SelectRoomsWithMembership(
|
||||||
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||||
) ([]types.RoomNID, error) {
|
) ([]types.RoomNID, error) {
|
||||||
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID)
|
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, membershipState, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -297,12 +303,16 @@ func (s *membershipStatements) SelectRoomsWithMembership(
|
||||||
return roomNIDs, nil
|
return roomNIDs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
|
func (s *membershipStatements) SelectJoinedUsersSetForRooms(
|
||||||
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
roomNIDs []types.RoomNID,
|
||||||
|
) (map[types.EventStateKeyNID]int, error) {
|
||||||
roomIDarray := make([]int64, len(roomNIDs))
|
roomIDarray := make([]int64, len(roomNIDs))
|
||||||
for i := range roomNIDs {
|
for i := range roomNIDs {
|
||||||
roomIDarray[i] = int64(roomNIDs[i])
|
roomIDarray[i] = int64(roomNIDs[i])
|
||||||
}
|
}
|
||||||
rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
|
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -319,8 +329,12 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
|
||||||
return result, rows.Err()
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
|
func (s *membershipStatements) SelectKnownUsers(
|
||||||
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
userID types.EventStateKeyNID, searchString string, limit int,
|
||||||
|
) ([]string, error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -337,9 +351,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) UpdateForgetMembership(
|
func (s *membershipStatements) UpdateForgetMembership(
|
||||||
ctx context.Context,
|
ctx context.Context, txn *sql.Tx,
|
||||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool,
|
||||||
forget bool,
|
|
||||||
) error {
|
) error {
|
||||||
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
|
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
|
||||||
ctx, roomNID, targetUserNID, forget,
|
ctx, roomNID, targetUserNID, forget,
|
||||||
|
@ -347,9 +360,13 @@ func (s *membershipStatements) UpdateForgetMembership(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
|
func (s *membershipStatements) SelectLocalServerInRoom(
|
||||||
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
roomNID types.RoomNID,
|
||||||
|
) (bool, error) {
|
||||||
var nid types.RoomNID
|
var nid types.RoomNID
|
||||||
err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
|
stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt)
|
||||||
|
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return false, nil
|
return false, nil
|
||||||
|
@ -360,9 +377,13 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room
|
||||||
return found, nil
|
return found, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
|
func (s *membershipStatements) SelectServerInRoom(
|
||||||
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
roomNID types.RoomNID, serverName gomatrixserverlib.ServerName,
|
||||||
|
) (bool, error) {
|
||||||
var nid types.RoomNID
|
var nid types.RoomNID
|
||||||
err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
|
stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt)
|
||||||
|
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return false, nil
|
return false, nil
|
||||||
|
|
|
@ -73,9 +73,10 @@ func (s *publishedStatements) UpsertRoomPublished(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *publishedStatements) SelectPublishedFromRoomID(
|
func (s *publishedStatements) SelectPublishedFromRoomID(
|
||||||
ctx context.Context, roomID string,
|
ctx context.Context, txn *sql.Tx, roomID string,
|
||||||
) (published bool, err error) {
|
) (published bool, err error) {
|
||||||
err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published)
|
stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt)
|
||||||
|
err = stmt.QueryRowContext(ctx, roomID).Scan(&published)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
@ -83,9 +84,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *publishedStatements) SelectAllPublishedRooms(
|
func (s *publishedStatements) SelectAllPublishedRooms(
|
||||||
ctx context.Context, published bool,
|
ctx context.Context, txn *sql.Tx, published bool,
|
||||||
) ([]string, error) {
|
) ([]string, error) {
|
||||||
rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published)
|
stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, published)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -87,9 +87,10 @@ func (s *roomAliasesStatements) InsertRoomAlias(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomAliasesStatements) SelectRoomIDFromAlias(
|
func (s *roomAliasesStatements) SelectRoomIDFromAlias(
|
||||||
ctx context.Context, alias string,
|
ctx context.Context, txn *sql.Tx, alias string,
|
||||||
) (roomID string, err error) {
|
) (roomID string, err error) {
|
||||||
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID)
|
stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt)
|
||||||
|
err = stmt.QueryRowContext(ctx, alias).Scan(&roomID)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
@ -97,9 +98,10 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomAliasesStatements) SelectAliasesFromRoomID(
|
func (s *roomAliasesStatements) SelectAliasesFromRoomID(
|
||||||
ctx context.Context, roomID string,
|
ctx context.Context, txn *sql.Tx, roomID string,
|
||||||
) ([]string, error) {
|
) ([]string, error) {
|
||||||
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
|
stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -118,9 +120,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
|
func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
|
||||||
ctx context.Context, alias string,
|
ctx context.Context, txn *sql.Tx, alias string,
|
||||||
) (creatorID string, err error) {
|
) (creatorID string, err error) {
|
||||||
err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID)
|
stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt)
|
||||||
|
err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -117,8 +117,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
|
func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
|
||||||
rows, err := s.selectRoomIDsStmt.QueryContext(ctx)
|
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -143,10 +144,11 @@ func (s *roomStatements) InsertRoomNID(
|
||||||
return types.RoomNID(roomNID), err
|
return types.RoomNID(roomNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
|
func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
|
||||||
var info types.RoomInfo
|
var info types.RoomInfo
|
||||||
var latestNIDs pq.Int64Array
|
var latestNIDs pq.Int64Array
|
||||||
err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan(
|
stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt)
|
||||||
|
err := stmt.QueryRowContext(ctx, roomID).Scan(
|
||||||
&info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDs,
|
&info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDs,
|
||||||
)
|
)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
|
@ -170,7 +172,7 @@ func (s *roomStatements) SelectLatestEventNIDs(
|
||||||
) ([]types.EventNID, types.StateSnapshotNID, error) {
|
) ([]types.EventNID, types.StateSnapshotNID, error) {
|
||||||
var nids pq.Int64Array
|
var nids pq.Int64Array
|
||||||
var stateSnapshotNID int64
|
var stateSnapshotNID int64
|
||||||
stmt := s.selectLatestEventNIDsStmt
|
stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsStmt)
|
||||||
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID)
|
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
|
@ -220,9 +222,10 @@ func (s *roomStatements) UpdateLatestEventNIDs(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
|
func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
|
||||||
ctx context.Context, roomNIDs []types.RoomNID,
|
ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID,
|
||||||
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
|
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
|
||||||
rows, err := s.selectRoomVersionsForRoomNIDsStmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs))
|
stmt := sqlutil.TxStmt(txn, s.selectRoomVersionsForRoomNIDsStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -239,12 +242,13 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
|
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) {
|
||||||
var array pq.Int64Array
|
var array pq.Int64Array
|
||||||
for _, nid := range roomNIDs {
|
for _, nid := range roomNIDs {
|
||||||
array = append(array, int64(nid))
|
array = append(array, int64(nid))
|
||||||
}
|
}
|
||||||
rows, err := s.bulkSelectRoomIDsStmt.QueryContext(ctx, array)
|
stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomIDsStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, array)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -260,12 +264,13 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types
|
||||||
return roomIDs, nil
|
return roomIDs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) {
|
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) {
|
||||||
var array pq.StringArray
|
var array pq.StringArray
|
||||||
for _, roomID := range roomIDs {
|
for _, roomID := range roomIDs {
|
||||||
array = append(array, roomID)
|
array = append(array, roomID)
|
||||||
}
|
}
|
||||||
rows, err := s.bulkSelectRoomNIDsStmt.QueryContext(ctx, array)
|
stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomNIDsStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, array)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -86,8 +86,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stateBlockStatements) BulkInsertStateData(
|
func (s *stateBlockStatements) BulkInsertStateData(
|
||||||
ctx context.Context,
|
ctx context.Context, txn *sql.Tx,
|
||||||
txn *sql.Tx,
|
|
||||||
entries types.StateEntries,
|
entries types.StateEntries,
|
||||||
) (id types.StateBlockNID, err error) {
|
) (id types.StateBlockNID, err error) {
|
||||||
entries = entries[:util.SortAndUnique(entries)]
|
entries = entries[:util.SortAndUnique(entries)]
|
||||||
|
@ -95,16 +94,18 @@ func (s *stateBlockStatements) BulkInsertStateData(
|
||||||
for _, e := range entries {
|
for _, e := range entries {
|
||||||
nids = append(nids, e.EventNID)
|
nids = append(nids, e.EventNID)
|
||||||
}
|
}
|
||||||
err = s.insertStateDataStmt.QueryRowContext(
|
stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt)
|
||||||
|
err = stmt.QueryRowContext(
|
||||||
ctx, nids.Hash(), eventNIDsAsArray(nids),
|
ctx, nids.Hash(), eventNIDsAsArray(nids),
|
||||||
).Scan(&id)
|
).Scan(&id)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
||||||
ctx context.Context, stateBlockNIDs types.StateBlockNIDs,
|
ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs,
|
||||||
) ([][]types.EventNID, error) {
|
) ([][]types.EventNID, error) {
|
||||||
rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs))
|
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockEntriesStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -105,13 +105,14 @@ func (s *stateSnapshotStatements) InsertState(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
||||||
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
|
||||||
) ([]types.StateBlockNIDList, error) {
|
) ([]types.StateBlockNIDList, error) {
|
||||||
nids := make([]int64, len(stateNIDs))
|
nids := make([]int64, len(stateNIDs))
|
||||||
for i := range stateNIDs {
|
for i := range stateNIDs {
|
||||||
nids[i] = int64(stateNIDs[i])
|
nids[i] = int64(stateNIDs[i])
|
||||||
}
|
}
|
||||||
rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, pq.Int64Array(nids))
|
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockNIDsStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, pq.Int64Array(nids))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,133 +0,0 @@
|
||||||
package shared
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
|
||||||
|
|
||||||
type LatestEventsUpdater struct {
|
|
||||||
transaction
|
|
||||||
d *Database
|
|
||||||
roomInfo types.RoomInfo
|
|
||||||
latestEvents []types.StateAtEventAndReference
|
|
||||||
lastEventIDSent string
|
|
||||||
currentStateSnapshotNID types.StateSnapshotNID
|
|
||||||
}
|
|
||||||
|
|
||||||
func rollback(txn *sql.Tx) {
|
|
||||||
if txn == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
txn.Rollback() // nolint: errcheck
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) {
|
|
||||||
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
|
|
||||||
d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID)
|
|
||||||
if err != nil {
|
|
||||||
rollback(txn)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
|
|
||||||
if err != nil {
|
|
||||||
rollback(txn)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var lastEventIDSent string
|
|
||||||
if lastEventNIDSent != 0 {
|
|
||||||
lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent)
|
|
||||||
if err != nil {
|
|
||||||
rollback(txn)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &LatestEventsUpdater{
|
|
||||||
transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RoomVersion implements types.RoomRecentEventsUpdater
|
|
||||||
func (u *LatestEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) {
|
|
||||||
return u.roomInfo.RoomVersion
|
|
||||||
}
|
|
||||||
|
|
||||||
// LatestEvents implements types.RoomRecentEventsUpdater
|
|
||||||
func (u *LatestEventsUpdater) LatestEvents() []types.StateAtEventAndReference {
|
|
||||||
return u.latestEvents
|
|
||||||
}
|
|
||||||
|
|
||||||
// LastEventIDSent implements types.RoomRecentEventsUpdater
|
|
||||||
func (u *LatestEventsUpdater) LastEventIDSent() string {
|
|
||||||
return u.lastEventIDSent
|
|
||||||
}
|
|
||||||
|
|
||||||
// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater
|
|
||||||
func (u *LatestEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
|
|
||||||
return u.currentStateSnapshotNID
|
|
||||||
}
|
|
||||||
|
|
||||||
// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer
|
|
||||||
func (u *LatestEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
|
|
||||||
for _, ref := range previousEventReferences {
|
|
||||||
if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
|
|
||||||
return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsReferenced implements types.RoomRecentEventsUpdater
|
|
||||||
func (u *LatestEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
|
|
||||||
err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
|
|
||||||
if err == nil {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetLatestEvents implements types.RoomRecentEventsUpdater
|
|
||||||
func (u *LatestEventsUpdater) SetLatestEvents(
|
|
||||||
roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
|
|
||||||
currentStateSnapshotNID types.StateSnapshotNID,
|
|
||||||
) error {
|
|
||||||
eventNIDs := make([]types.EventNID, len(latest))
|
|
||||||
for i := range latest {
|
|
||||||
eventNIDs[i] = latest[i].EventNID
|
|
||||||
}
|
|
||||||
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
|
|
||||||
if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil {
|
|
||||||
return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err)
|
|
||||||
}
|
|
||||||
if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok {
|
|
||||||
if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok {
|
|
||||||
roomInfo.StateSnapshotNID = currentStateSnapshotNID
|
|
||||||
roomInfo.IsStub = false
|
|
||||||
u.d.Cache.StoreRoomInfo(roomID, roomInfo)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasEventBeenSent implements types.RoomRecentEventsUpdater
|
|
||||||
func (u *LatestEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
|
|
||||||
return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkEventAsSent implements types.RoomRecentEventsUpdater
|
|
||||||
func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
|
|
||||||
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
|
|
||||||
return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
|
|
||||||
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
|
|
||||||
}
|
|
262
roomserver/storage/shared/room_updater.go
Normal file
262
roomserver/storage/shared/room_updater.go
Normal file
|
@ -0,0 +1,262 @@
|
||||||
|
package shared
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RoomUpdater struct {
|
||||||
|
transaction
|
||||||
|
d *Database
|
||||||
|
roomInfo *types.RoomInfo
|
||||||
|
latestEvents []types.StateAtEventAndReference
|
||||||
|
lastEventIDSent string
|
||||||
|
currentStateSnapshotNID types.StateSnapshotNID
|
||||||
|
}
|
||||||
|
|
||||||
|
func rollback(txn *sql.Tx) {
|
||||||
|
if txn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
txn.Rollback() // nolint: errcheck
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRoomUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo *types.RoomInfo) (*RoomUpdater, error) {
|
||||||
|
// If the roomInfo is nil then that means that the room doesn't exist
|
||||||
|
// yet, so we can't do `SelectLatestEventsNIDsForUpdate` because that
|
||||||
|
// would involve locking a row on the table that doesn't exist. Instead
|
||||||
|
// we will just run with a normal database transaction. It'll either
|
||||||
|
// succeed, processing a create event which creates the room, or it won't.
|
||||||
|
if roomInfo == nil {
|
||||||
|
return &RoomUpdater{
|
||||||
|
transaction{ctx, txn}, d, nil, nil, "", 0,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
|
||||||
|
d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID)
|
||||||
|
if err != nil {
|
||||||
|
rollback(txn)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
|
||||||
|
if err != nil {
|
||||||
|
rollback(txn)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var lastEventIDSent string
|
||||||
|
if lastEventNIDSent != 0 {
|
||||||
|
lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent)
|
||||||
|
if err != nil {
|
||||||
|
rollback(txn)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &RoomUpdater{
|
||||||
|
transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implements sqlutil.Transaction
|
||||||
|
func (u *RoomUpdater) Commit() error {
|
||||||
|
if u.txn == nil { // SQLite mode probably
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return u.txn.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implements sqlutil.Transaction
|
||||||
|
func (u *RoomUpdater) Rollback() error {
|
||||||
|
if u.txn == nil { // SQLite mode probably
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return u.txn.Rollback()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoomVersion implements types.RoomRecentEventsUpdater
|
||||||
|
func (u *RoomUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) {
|
||||||
|
return u.roomInfo.RoomVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
// LatestEvents implements types.RoomRecentEventsUpdater
|
||||||
|
func (u *RoomUpdater) LatestEvents() []types.StateAtEventAndReference {
|
||||||
|
return u.latestEvents
|
||||||
|
}
|
||||||
|
|
||||||
|
// LastEventIDSent implements types.RoomRecentEventsUpdater
|
||||||
|
func (u *RoomUpdater) LastEventIDSent() string {
|
||||||
|
return u.lastEventIDSent
|
||||||
|
}
|
||||||
|
|
||||||
|
// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater
|
||||||
|
func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
|
||||||
|
return u.currentStateSnapshotNID
|
||||||
|
}
|
||||||
|
|
||||||
|
// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer
|
||||||
|
func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
|
||||||
|
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
|
||||||
|
for _, ref := range previousEventReferences {
|
||||||
|
if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
|
||||||
|
return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RoomUpdater) Events(
|
||||||
|
ctx context.Context, eventNIDs []types.EventNID,
|
||||||
|
) ([]types.Event, error) {
|
||||||
|
return u.d.events(ctx, u.txn, eventNIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RoomUpdater) SnapshotNIDFromEventID(
|
||||||
|
ctx context.Context, eventID string,
|
||||||
|
) (types.StateSnapshotNID, error) {
|
||||||
|
return u.d.snapshotNIDFromEventID(ctx, u.txn, eventID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RoomUpdater) StoreEvent(
|
||||||
|
ctx context.Context, event *gomatrixserverlib.Event,
|
||||||
|
authEventNIDs []types.EventNID, isRejected bool,
|
||||||
|
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
|
||||||
|
return u.d.storeEvent(ctx, u, event, authEventNIDs, isRejected)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RoomUpdater) StateBlockNIDs(
|
||||||
|
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
||||||
|
) ([]types.StateBlockNIDList, error) {
|
||||||
|
return u.d.stateBlockNIDs(ctx, u.txn, stateNIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RoomUpdater) StateEntries(
|
||||||
|
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
|
||||||
|
) ([]types.StateEntryList, error) {
|
||||||
|
return u.d.stateEntries(ctx, u.txn, stateBlockNIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RoomUpdater) StateEntriesForTuples(
|
||||||
|
ctx context.Context,
|
||||||
|
stateBlockNIDs []types.StateBlockNID,
|
||||||
|
stateKeyTuples []types.StateKeyTuple,
|
||||||
|
) ([]types.StateEntryList, error) {
|
||||||
|
return u.d.stateEntriesForTuples(ctx, u.txn, stateBlockNIDs, stateKeyTuples)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RoomUpdater) AddState(
|
||||||
|
ctx context.Context,
|
||||||
|
roomNID types.RoomNID,
|
||||||
|
stateBlockNIDs []types.StateBlockNID,
|
||||||
|
state []types.StateEntry,
|
||||||
|
) (stateNID types.StateSnapshotNID, err error) {
|
||||||
|
return u.d.addState(ctx, u.txn, roomNID, stateBlockNIDs, state)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RoomUpdater) SetState(
|
||||||
|
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
||||||
|
) error {
|
||||||
|
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
|
||||||
|
return u.d.EventsTable.UpdateEventState(ctx, txn, eventNID, stateNID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RoomUpdater) EventTypeNIDs(
|
||||||
|
ctx context.Context, eventTypes []string,
|
||||||
|
) (map[string]types.EventTypeNID, error) {
|
||||||
|
return u.d.eventTypeNIDs(ctx, u.txn, eventTypes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RoomUpdater) EventStateKeyNIDs(
|
||||||
|
ctx context.Context, eventStateKeys []string,
|
||||||
|
) (map[string]types.EventStateKeyNID, error) {
|
||||||
|
return u.d.eventStateKeyNIDs(ctx, u.txn, eventStateKeys)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RoomUpdater) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
|
||||||
|
return u.d.roomInfo(ctx, u.txn, roomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RoomUpdater) EventIDs(
|
||||||
|
ctx context.Context, eventNIDs []types.EventNID,
|
||||||
|
) (map[types.EventNID]string, error) {
|
||||||
|
return u.d.EventsTable.BulkSelectEventID(ctx, u.txn, eventNIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RoomUpdater) StateAtEventIDs(
|
||||||
|
ctx context.Context, eventIDs []string,
|
||||||
|
) ([]types.StateAtEvent, error) {
|
||||||
|
return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RoomUpdater) StateEntriesForEventIDs(
|
||||||
|
ctx context.Context, eventIDs []string,
|
||||||
|
) ([]types.StateEntry, error) {
|
||||||
|
return u.d.EventsTable.BulkSelectStateEventByID(ctx, u.txn, eventIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
|
||||||
|
return u.d.eventsFromIDs(ctx, u.txn, eventIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RoomUpdater) GetMembershipEventNIDsForRoom(
|
||||||
|
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
|
||||||
|
) ([]types.EventNID, error) {
|
||||||
|
return u.d.getMembershipEventNIDsForRoom(ctx, u.txn, roomNID, joinOnly, localOnly)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsReferenced implements types.RoomRecentEventsUpdater
|
||||||
|
func (u *RoomUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
|
||||||
|
err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
|
||||||
|
if err == nil {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLatestEvents implements types.RoomRecentEventsUpdater
|
||||||
|
func (u *RoomUpdater) SetLatestEvents(
|
||||||
|
roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
|
||||||
|
currentStateSnapshotNID types.StateSnapshotNID,
|
||||||
|
) error {
|
||||||
|
eventNIDs := make([]types.EventNID, len(latest))
|
||||||
|
for i := range latest {
|
||||||
|
eventNIDs[i] = latest[i].EventNID
|
||||||
|
}
|
||||||
|
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
|
||||||
|
if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil {
|
||||||
|
return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err)
|
||||||
|
}
|
||||||
|
if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok {
|
||||||
|
if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok {
|
||||||
|
roomInfo.StateSnapshotNID = currentStateSnapshotNID
|
||||||
|
roomInfo.IsStub = false
|
||||||
|
u.d.Cache.StoreRoomInfo(roomID, roomInfo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasEventBeenSent implements types.RoomRecentEventsUpdater
|
||||||
|
func (u *RoomUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
|
||||||
|
return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkEventAsSent implements types.RoomRecentEventsUpdater
|
||||||
|
func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error {
|
||||||
|
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
|
||||||
|
return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
|
||||||
|
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
|
||||||
|
}
|
|
@ -26,23 +26,23 @@ import (
|
||||||
const redactionsArePermanent = true
|
const redactionsArePermanent = true
|
||||||
|
|
||||||
type Database struct {
|
type Database struct {
|
||||||
DB *sql.DB
|
DB *sql.DB
|
||||||
Cache caching.RoomServerCaches
|
Cache caching.RoomServerCaches
|
||||||
Writer sqlutil.Writer
|
Writer sqlutil.Writer
|
||||||
EventsTable tables.Events
|
EventsTable tables.Events
|
||||||
EventJSONTable tables.EventJSON
|
EventJSONTable tables.EventJSON
|
||||||
EventTypesTable tables.EventTypes
|
EventTypesTable tables.EventTypes
|
||||||
EventStateKeysTable tables.EventStateKeys
|
EventStateKeysTable tables.EventStateKeys
|
||||||
RoomsTable tables.Rooms
|
RoomsTable tables.Rooms
|
||||||
StateSnapshotTable tables.StateSnapshot
|
StateSnapshotTable tables.StateSnapshot
|
||||||
StateBlockTable tables.StateBlock
|
StateBlockTable tables.StateBlock
|
||||||
RoomAliasesTable tables.RoomAliases
|
RoomAliasesTable tables.RoomAliases
|
||||||
PrevEventsTable tables.PreviousEvents
|
PrevEventsTable tables.PreviousEvents
|
||||||
InvitesTable tables.Invites
|
InvitesTable tables.Invites
|
||||||
MembershipTable tables.Membership
|
MembershipTable tables.Membership
|
||||||
PublishedTable tables.Published
|
PublishedTable tables.Published
|
||||||
RedactionsTable tables.Redactions
|
RedactionsTable tables.Redactions
|
||||||
GetLatestEventsForUpdateFn func(ctx context.Context, roomInfo types.RoomInfo) (*LatestEventsUpdater, error)
|
GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) SupportsConcurrentRoomInputs() bool {
|
func (d *Database) SupportsConcurrentRoomInputs() bool {
|
||||||
|
@ -51,6 +51,12 @@ func (d *Database) SupportsConcurrentRoomInputs() bool {
|
||||||
|
|
||||||
func (d *Database) EventTypeNIDs(
|
func (d *Database) EventTypeNIDs(
|
||||||
ctx context.Context, eventTypes []string,
|
ctx context.Context, eventTypes []string,
|
||||||
|
) (map[string]types.EventTypeNID, error) {
|
||||||
|
return d.eventTypeNIDs(ctx, nil, eventTypes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) eventTypeNIDs(
|
||||||
|
ctx context.Context, txn *sql.Tx, eventTypes []string,
|
||||||
) (map[string]types.EventTypeNID, error) {
|
) (map[string]types.EventTypeNID, error) {
|
||||||
result := make(map[string]types.EventTypeNID)
|
result := make(map[string]types.EventTypeNID)
|
||||||
remaining := []string{}
|
remaining := []string{}
|
||||||
|
@ -62,7 +68,7 @@ func (d *Database) EventTypeNIDs(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(remaining) > 0 {
|
if len(remaining) > 0 {
|
||||||
nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, remaining)
|
nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, remaining)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -77,11 +83,17 @@ func (d *Database) EventTypeNIDs(
|
||||||
func (d *Database) EventStateKeys(
|
func (d *Database) EventStateKeys(
|
||||||
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
|
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||||
) (map[types.EventStateKeyNID]string, error) {
|
) (map[types.EventStateKeyNID]string, error) {
|
||||||
return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, eventStateKeyNIDs)
|
return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, eventStateKeyNIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) EventStateKeyNIDs(
|
func (d *Database) EventStateKeyNIDs(
|
||||||
ctx context.Context, eventStateKeys []string,
|
ctx context.Context, eventStateKeys []string,
|
||||||
|
) (map[string]types.EventStateKeyNID, error) {
|
||||||
|
return d.eventStateKeyNIDs(ctx, nil, eventStateKeys)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) eventStateKeyNIDs(
|
||||||
|
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
|
||||||
) (map[string]types.EventStateKeyNID, error) {
|
) (map[string]types.EventStateKeyNID, error) {
|
||||||
result := make(map[string]types.EventStateKeyNID)
|
result := make(map[string]types.EventStateKeyNID)
|
||||||
remaining := []string{}
|
remaining := []string{}
|
||||||
|
@ -93,7 +105,7 @@ func (d *Database) EventStateKeyNIDs(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(remaining) > 0 {
|
if len(remaining) > 0 {
|
||||||
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, remaining)
|
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, remaining)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -108,23 +120,31 @@ func (d *Database) EventStateKeyNIDs(
|
||||||
func (d *Database) StateEntriesForEventIDs(
|
func (d *Database) StateEntriesForEventIDs(
|
||||||
ctx context.Context, eventIDs []string,
|
ctx context.Context, eventIDs []string,
|
||||||
) ([]types.StateEntry, error) {
|
) ([]types.StateEntry, error) {
|
||||||
return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs)
|
return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) StateEntriesForTuples(
|
func (d *Database) StateEntriesForTuples(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
stateBlockNIDs []types.StateBlockNID,
|
stateBlockNIDs []types.StateBlockNID,
|
||||||
stateKeyTuples []types.StateKeyTuple,
|
stateKeyTuples []types.StateKeyTuple,
|
||||||
|
) ([]types.StateEntryList, error) {
|
||||||
|
return d.stateEntriesForTuples(ctx, nil, stateBlockNIDs, stateKeyTuples)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) stateEntriesForTuples(
|
||||||
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
stateBlockNIDs []types.StateBlockNID,
|
||||||
|
stateKeyTuples []types.StateKeyTuple,
|
||||||
) ([]types.StateEntryList, error) {
|
) ([]types.StateEntryList, error) {
|
||||||
entries, err := d.StateBlockTable.BulkSelectStateBlockEntries(
|
entries, err := d.StateBlockTable.BulkSelectStateBlockEntries(
|
||||||
ctx, stateBlockNIDs,
|
ctx, txn, stateBlockNIDs,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
|
return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
|
||||||
}
|
}
|
||||||
lists := []types.StateEntryList{}
|
lists := []types.StateEntryList{}
|
||||||
for i, entry := range entries {
|
for i, entry := range entries {
|
||||||
entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, stateKeyTuples)
|
entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, txn, entry, stateKeyTuples)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
|
return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -137,10 +157,14 @@ func (d *Database) StateEntriesForTuples(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
|
func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
|
||||||
|
return d.roomInfo(ctx, nil, roomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) roomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
|
||||||
if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
|
if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
|
||||||
return &roomInfo, nil
|
return &roomInfo, nil
|
||||||
}
|
}
|
||||||
roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, roomID)
|
roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, txn, roomID)
|
||||||
if err == nil && roomInfo != nil {
|
if err == nil && roomInfo != nil {
|
||||||
d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID)
|
d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID)
|
||||||
d.Cache.StoreRoomInfo(roomID, *roomInfo)
|
d.Cache.StoreRoomInfo(roomID, *roomInfo)
|
||||||
|
@ -153,13 +177,22 @@ func (d *Database) AddState(
|
||||||
roomNID types.RoomNID,
|
roomNID types.RoomNID,
|
||||||
stateBlockNIDs []types.StateBlockNID,
|
stateBlockNIDs []types.StateBlockNID,
|
||||||
state []types.StateEntry,
|
state []types.StateEntry,
|
||||||
|
) (stateNID types.StateSnapshotNID, err error) {
|
||||||
|
return d.addState(ctx, nil, roomNID, stateBlockNIDs, state)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) addState(
|
||||||
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
roomNID types.RoomNID,
|
||||||
|
stateBlockNIDs []types.StateBlockNID,
|
||||||
|
state []types.StateEntry,
|
||||||
) (stateNID types.StateSnapshotNID, err error) {
|
) (stateNID types.StateSnapshotNID, err error) {
|
||||||
if len(stateBlockNIDs) > 0 && len(state) > 0 {
|
if len(stateBlockNIDs) > 0 && len(state) > 0 {
|
||||||
// Check to see if the event already appears in any of the existing state
|
// Check to see if the event already appears in any of the existing state
|
||||||
// blocks. If it does then we should not add it again, as this will just
|
// blocks. If it does then we should not add it again, as this will just
|
||||||
// result in excess state blocks and snapshots.
|
// result in excess state blocks and snapshots.
|
||||||
// TODO: Investigate why this is happening - probably input_events.go!
|
// TODO: Investigate why this is happening - probably input_events.go!
|
||||||
blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs)
|
blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs)
|
||||||
if berr != nil {
|
if berr != nil {
|
||||||
return 0, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", berr)
|
return 0, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", berr)
|
||||||
}
|
}
|
||||||
|
@ -180,7 +213,7 @@ func (d *Database) AddState(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
|
||||||
if len(state) > 0 {
|
if len(state) > 0 {
|
||||||
// If there's any state left to add then let's add new blocks.
|
// If there's any state left to add then let's add new blocks.
|
||||||
var stateBlockNID types.StateBlockNID
|
var stateBlockNID types.StateBlockNID
|
||||||
|
@ -205,7 +238,13 @@ func (d *Database) AddState(
|
||||||
func (d *Database) EventNIDs(
|
func (d *Database) EventNIDs(
|
||||||
ctx context.Context, eventIDs []string,
|
ctx context.Context, eventIDs []string,
|
||||||
) (map[string]types.EventNID, error) {
|
) (map[string]types.EventNID, error) {
|
||||||
return d.EventsTable.BulkSelectEventNID(ctx, eventIDs)
|
return d.eventNIDs(ctx, nil, eventIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) eventNIDs(
|
||||||
|
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||||
|
) (map[string]types.EventNID, error) {
|
||||||
|
return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) SetState(
|
func (d *Database) SetState(
|
||||||
|
@ -219,24 +258,34 @@ func (d *Database) SetState(
|
||||||
func (d *Database) StateAtEventIDs(
|
func (d *Database) StateAtEventIDs(
|
||||||
ctx context.Context, eventIDs []string,
|
ctx context.Context, eventIDs []string,
|
||||||
) ([]types.StateAtEvent, error) {
|
) ([]types.StateAtEvent, error) {
|
||||||
return d.EventsTable.BulkSelectStateAtEventByID(ctx, eventIDs)
|
return d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, eventIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) SnapshotNIDFromEventID(
|
func (d *Database) SnapshotNIDFromEventID(
|
||||||
ctx context.Context, eventID string,
|
ctx context.Context, eventID string,
|
||||||
) (types.StateSnapshotNID, error) {
|
) (types.StateSnapshotNID, error) {
|
||||||
_, stateNID, err := d.EventsTable.SelectEvent(ctx, nil, eventID)
|
return d.snapshotNIDFromEventID(ctx, nil, eventID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) snapshotNIDFromEventID(
|
||||||
|
ctx context.Context, txn *sql.Tx, eventID string,
|
||||||
|
) (types.StateSnapshotNID, error) {
|
||||||
|
_, stateNID, err := d.EventsTable.SelectEvent(ctx, txn, eventID)
|
||||||
return stateNID, err
|
return stateNID, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) EventIDs(
|
func (d *Database) EventIDs(
|
||||||
ctx context.Context, eventNIDs []types.EventNID,
|
ctx context.Context, eventNIDs []types.EventNID,
|
||||||
) (map[types.EventNID]string, error) {
|
) (map[types.EventNID]string, error) {
|
||||||
return d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
|
return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
|
func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
|
||||||
nidMap, err := d.EventNIDs(ctx, eventIDs)
|
return d.eventsFromIDs(ctx, nil, eventIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.Event, error) {
|
||||||
|
nidMap, err := d.eventNIDs(ctx, txn, eventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -246,7 +295,7 @@ func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]type
|
||||||
nids = append(nids, nid)
|
nids = append(nids, nid)
|
||||||
}
|
}
|
||||||
|
|
||||||
return d.Events(ctx, nids)
|
return d.events(ctx, txn, nids)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) LatestEventIDs(
|
func (d *Database) LatestEventIDs(
|
||||||
|
@ -271,21 +320,33 @@ func (d *Database) LatestEventIDs(
|
||||||
func (d *Database) StateBlockNIDs(
|
func (d *Database) StateBlockNIDs(
|
||||||
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
||||||
) ([]types.StateBlockNIDList, error) {
|
) ([]types.StateBlockNIDList, error) {
|
||||||
return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, stateNIDs)
|
return d.stateBlockNIDs(ctx, nil, stateNIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) stateBlockNIDs(
|
||||||
|
ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
|
||||||
|
) ([]types.StateBlockNIDList, error) {
|
||||||
|
return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, txn, stateNIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) StateEntries(
|
func (d *Database) StateEntries(
|
||||||
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
|
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
|
||||||
|
) ([]types.StateEntryList, error) {
|
||||||
|
return d.stateEntries(ctx, nil, stateBlockNIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) stateEntries(
|
||||||
|
ctx context.Context, txn *sql.Tx, stateBlockNIDs []types.StateBlockNID,
|
||||||
) ([]types.StateEntryList, error) {
|
) ([]types.StateEntryList, error) {
|
||||||
entries, err := d.StateBlockTable.BulkSelectStateBlockEntries(
|
entries, err := d.StateBlockTable.BulkSelectStateBlockEntries(
|
||||||
ctx, stateBlockNIDs,
|
ctx, txn, stateBlockNIDs,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
|
return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
|
||||||
}
|
}
|
||||||
lists := make([]types.StateEntryList, 0, len(entries))
|
lists := make([]types.StateEntryList, 0, len(entries))
|
||||||
for i, entry := range entries {
|
for i, entry := range entries {
|
||||||
eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, nil)
|
eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, txn, entry, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
|
return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -304,17 +365,17 @@ func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) {
|
func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) {
|
||||||
return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, alias)
|
return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, nil, alias)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) {
|
func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) {
|
||||||
return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, roomID)
|
return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, nil, roomID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetCreatorIDForAlias(
|
func (d *Database) GetCreatorIDForAlias(
|
||||||
ctx context.Context, alias string,
|
ctx context.Context, alias string,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, alias)
|
return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, nil, alias)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
|
func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
|
||||||
|
@ -335,7 +396,7 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req
|
||||||
|
|
||||||
senderMembershipEventNID, senderMembership, isRoomforgotten, err :=
|
senderMembershipEventNID, senderMembership, isRoomforgotten, err :=
|
||||||
d.MembershipTable.SelectMembershipFromRoomAndTarget(
|
d.MembershipTable.SelectMembershipFromRoomAndTarget(
|
||||||
ctx, roomNID, requestSenderUserNID,
|
ctx, nil, roomNID, requestSenderUserNID,
|
||||||
)
|
)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
// The user has never been a member of that room
|
// The user has never been a member of that room
|
||||||
|
@ -349,14 +410,20 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req
|
||||||
|
|
||||||
func (d *Database) GetMembershipEventNIDsForRoom(
|
func (d *Database) GetMembershipEventNIDsForRoom(
|
||||||
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
|
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
|
||||||
|
) ([]types.EventNID, error) {
|
||||||
|
return d.getMembershipEventNIDsForRoom(ctx, nil, roomNID, joinOnly, localOnly)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) getMembershipEventNIDsForRoom(
|
||||||
|
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, joinOnly bool, localOnly bool,
|
||||||
) ([]types.EventNID, error) {
|
) ([]types.EventNID, error) {
|
||||||
if joinOnly {
|
if joinOnly {
|
||||||
return d.MembershipTable.SelectMembershipsFromRoomAndMembership(
|
return d.MembershipTable.SelectMembershipsFromRoomAndMembership(
|
||||||
ctx, roomNID, tables.MembershipStateJoin, localOnly,
|
ctx, txn, roomNID, tables.MembershipStateJoin, localOnly,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return d.MembershipTable.SelectMembershipsFromRoom(ctx, roomNID, localOnly)
|
return d.MembershipTable.SelectMembershipsFromRoom(ctx, txn, roomNID, localOnly)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetInvitesForUser(
|
func (d *Database) GetInvitesForUser(
|
||||||
|
@ -364,22 +431,28 @@ func (d *Database) GetInvitesForUser(
|
||||||
roomNID types.RoomNID,
|
roomNID types.RoomNID,
|
||||||
targetUserNID types.EventStateKeyNID,
|
targetUserNID types.EventStateKeyNID,
|
||||||
) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error) {
|
) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error) {
|
||||||
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
|
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) Events(
|
func (d *Database) Events(
|
||||||
ctx context.Context, eventNIDs []types.EventNID,
|
ctx context.Context, eventNIDs []types.EventNID,
|
||||||
) ([]types.Event, error) {
|
) ([]types.Event, error) {
|
||||||
eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs)
|
return d.events(ctx, nil, eventNIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) events(
|
||||||
|
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||||
|
) ([]types.Event, error) {
|
||||||
|
eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, txn, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
|
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, txn, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
eventIDs = map[types.EventNID]string{}
|
eventIDs = map[types.EventNID]string{}
|
||||||
}
|
}
|
||||||
var roomNIDs map[types.EventNID]types.RoomNID
|
var roomNIDs map[types.EventNID]types.RoomNID
|
||||||
roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, eventNIDs)
|
roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, txn, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -398,7 +471,7 @@ func (d *Database) Events(
|
||||||
}
|
}
|
||||||
fetchNIDList = append(fetchNIDList, n)
|
fetchNIDList = append(fetchNIDList, n)
|
||||||
}
|
}
|
||||||
dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, fetchNIDList)
|
dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, fetchNIDList)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -440,19 +513,19 @@ func (d *Database) MembershipUpdater(
|
||||||
return updater, err
|
return updater, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetLatestEventsForUpdate(
|
func (d *Database) GetRoomUpdater(
|
||||||
ctx context.Context, roomInfo types.RoomInfo,
|
ctx context.Context, roomInfo *types.RoomInfo,
|
||||||
) (*LatestEventsUpdater, error) {
|
) (*RoomUpdater, error) {
|
||||||
if d.GetLatestEventsForUpdateFn != nil {
|
if d.GetRoomUpdaterFn != nil {
|
||||||
return d.GetLatestEventsForUpdateFn(ctx, roomInfo)
|
return d.GetRoomUpdaterFn(ctx, roomInfo)
|
||||||
}
|
}
|
||||||
txn, err := d.DB.Begin()
|
txn, err := d.DB.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var updater *LatestEventsUpdater
|
var updater *RoomUpdater
|
||||||
_ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
|
_ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
|
||||||
updater, err = NewLatestEventsUpdater(ctx, d, txn, roomInfo)
|
updater, err = NewRoomUpdater(ctx, d, txn, roomInfo)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return updater, err
|
return updater, err
|
||||||
|
@ -461,6 +534,13 @@ func (d *Database) GetLatestEventsForUpdate(
|
||||||
func (d *Database) StoreEvent(
|
func (d *Database) StoreEvent(
|
||||||
ctx context.Context, event *gomatrixserverlib.Event,
|
ctx context.Context, event *gomatrixserverlib.Event,
|
||||||
authEventNIDs []types.EventNID, isRejected bool,
|
authEventNIDs []types.EventNID, isRejected bool,
|
||||||
|
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
|
||||||
|
return d.storeEvent(ctx, nil, event, authEventNIDs, isRejected)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) storeEvent(
|
||||||
|
ctx context.Context, updater *RoomUpdater, event *gomatrixserverlib.Event,
|
||||||
|
authEventNIDs []types.EventNID, isRejected bool,
|
||||||
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
|
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
|
||||||
var (
|
var (
|
||||||
roomNID types.RoomNID
|
roomNID types.RoomNID
|
||||||
|
@ -472,8 +552,11 @@ func (d *Database) StoreEvent(
|
||||||
redactedEventID string
|
redactedEventID string
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
var txn *sql.Tx
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
if updater != nil {
|
||||||
|
txn = updater.txn
|
||||||
|
}
|
||||||
|
err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
|
||||||
// TODO: Here we should aim to have two different code paths for new rooms
|
// TODO: Here we should aim to have two different code paths for new rooms
|
||||||
// vs existing ones.
|
// vs existing ones.
|
||||||
|
|
||||||
|
@ -546,42 +629,32 @@ func (d *Database) StoreEvent(
|
||||||
// events updater because it somewhat works as a mutex, ensuring
|
// events updater because it somewhat works as a mutex, ensuring
|
||||||
// that there's a row-level lock on the latest room events (well,
|
// that there's a row-level lock on the latest room events (well,
|
||||||
// on Postgres at least).
|
// on Postgres at least).
|
||||||
var roomInfo *types.RoomInfo
|
|
||||||
var updater *LatestEventsUpdater
|
|
||||||
if prevEvents := event.PrevEvents(); len(prevEvents) > 0 {
|
if prevEvents := event.PrevEvents(); len(prevEvents) > 0 {
|
||||||
roomInfo, err = d.RoomInfo(ctx, event.RoomID())
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err)
|
|
||||||
}
|
|
||||||
if roomInfo == nil && len(prevEvents) > 0 {
|
|
||||||
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID())
|
|
||||||
}
|
|
||||||
// Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of
|
// Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of
|
||||||
// GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This
|
// GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This
|
||||||
// function only does SELECTs though so the created txn (at this point) is just a read txn like
|
// function only does SELECTs though so the created txn (at this point) is just a read txn like
|
||||||
// any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater
|
// any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater
|
||||||
// to do writes however then this will need to go inside `Writer.Do`.
|
// to do writes however then this will need to go inside `Writer.Do`.
|
||||||
updater, err = d.GetLatestEventsForUpdate(ctx, *roomInfo)
|
succeeded := false
|
||||||
if err != nil {
|
if updater == nil {
|
||||||
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("NewLatestEventsUpdater: %w", err)
|
var roomInfo *types.RoomInfo
|
||||||
}
|
roomInfo, err = d.RoomInfo(ctx, event.RoomID())
|
||||||
// Ensure that we atomically store prev events AND commit them. If we don't wrap StorePreviousEvents
|
if err != nil {
|
||||||
// and EndTransaction in a writer then it's possible for a new write txn to be made between the two
|
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err)
|
||||||
// function calls which will then fail with 'database is locked'. This new write txn would HAVE to be
|
|
||||||
// something like SetRoomAlias/RemoveRoomAlias as normal input events are already done sequentially due to
|
|
||||||
// SupportsConcurrentRoomInputs() == false on sqlite, though this does not apply to setting room aliases
|
|
||||||
// as they don't go via InputRoomEvents
|
|
||||||
err = d.Writer.Do(d.DB, updater.txn, func(txn *sql.Tx) error {
|
|
||||||
if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil {
|
|
||||||
return fmt.Errorf("updater.StorePreviousEvents: %w", err)
|
|
||||||
}
|
}
|
||||||
succeeded := true
|
if roomInfo == nil && len(prevEvents) > 0 {
|
||||||
err = sqlutil.EndTransaction(updater, &succeeded)
|
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID())
|
||||||
return err
|
}
|
||||||
})
|
updater, err = d.GetRoomUpdater(ctx, roomInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, types.StateAtEvent{}, nil, "", err
|
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err)
|
||||||
|
}
|
||||||
|
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
|
||||||
}
|
}
|
||||||
|
if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil {
|
||||||
|
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err)
|
||||||
|
}
|
||||||
|
succeeded = true
|
||||||
}
|
}
|
||||||
|
|
||||||
return eventNID, roomNID, types.StateAtEvent{
|
return eventNID, roomNID, types.StateAtEvent{
|
||||||
|
@ -603,7 +676,7 @@ func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) {
|
func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) {
|
||||||
return d.PublishedTable.SelectAllPublishedRooms(ctx, true)
|
return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) assignRoomNID(
|
func (d *Database) assignRoomNID(
|
||||||
|
@ -875,14 +948,14 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
|
||||||
eventNIDs = append(eventNIDs, e.EventNID)
|
eventNIDs = append(eventNIDs, e.EventNID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
|
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
eventIDs = map[types.EventNID]string{}
|
eventIDs = map[types.EventNID]string{}
|
||||||
}
|
}
|
||||||
// return the event requested
|
// return the event requested
|
||||||
for _, e := range entries {
|
for _, e := range entries {
|
||||||
if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID {
|
if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID {
|
||||||
data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, []types.EventNID{e.EventNID})
|
data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, []types.EventNID{e.EventNID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -922,11 +995,11 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err)
|
return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err)
|
||||||
}
|
}
|
||||||
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, stateKeyNID, membershipState)
|
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNID, membershipState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err)
|
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err)
|
||||||
}
|
}
|
||||||
roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, roomNIDs)
|
roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, roomNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("GetRoomsByMembership: failed to lookup room nids: %w", err)
|
return nil, fmt.Errorf("GetRoomsByMembership: failed to lookup room nids: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -945,7 +1018,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
|
||||||
}
|
}
|
||||||
// we don't bother failing the request if we get asked for event types we don't know about, as all that would result in is no matches which
|
// we don't bother failing the request if we get asked for event types we don't know about, as all that would result in is no matches which
|
||||||
// isn't a failure.
|
// isn't a failure.
|
||||||
eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, eventTypes)
|
eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, nil, eventTypes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("GetBulkStateContent: failed to map event type nids: %w", err)
|
return nil, fmt.Errorf("GetBulkStateContent: failed to map event type nids: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -965,7 +1038,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, eventStateKeys)
|
eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, eventStateKeys)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("GetBulkStateContent: failed to map state key nids: %w", err)
|
return nil, fmt.Errorf("GetBulkStateContent: failed to map state key nids: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -999,11 +1072,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
|
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
eventIDs = map[types.EventNID]string{}
|
eventIDs = map[types.EventNID]string{}
|
||||||
}
|
}
|
||||||
events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs)
|
events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event nids: %w", err)
|
return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event nids: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -1027,11 +1100,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
|
||||||
|
|
||||||
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
|
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
|
||||||
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
|
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
|
||||||
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, roomIDs)
|
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, roomNIDs)
|
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -1041,7 +1114,7 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
|
||||||
stateKeyNIDs[i] = nid
|
stateKeyNIDs[i] = nid
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, stateKeyNIDs)
|
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -1057,12 +1130,12 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
|
||||||
|
|
||||||
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
|
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
|
||||||
func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
|
func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
|
||||||
return d.MembershipTable.SelectLocalServerInRoom(ctx, roomNID)
|
return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetServerInRoom returns true if we think a server is in a given room or false otherwise.
|
// GetServerInRoom returns true if we think a server is in a given room or false otherwise.
|
||||||
func (d *Database) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
|
func (d *Database) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
|
||||||
return d.MembershipTable.SelectServerInRoom(ctx, roomNID, serverName)
|
return d.MembershipTable.SelectServerInRoom(ctx, nil, roomNID, serverName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetKnownUsers searches all users that userID knows about.
|
// GetKnownUsers searches all users that userID knows about.
|
||||||
|
@ -1071,17 +1144,17 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return d.MembershipTable.SelectKnownUsers(ctx, stateKeyNID, searchString, limit)
|
return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetKnownRooms returns a list of all rooms we know about.
|
// GetKnownRooms returns a list of all rooms we know about.
|
||||||
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
|
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
|
||||||
return d.RoomsTable.SelectRoomIDs(ctx)
|
return d.RoomsTable.SelectRoomIDs(ctx, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ForgetRoom sets a users room to forgotten
|
// ForgetRoom sets a users room to forgotten
|
||||||
func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error {
|
func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error {
|
||||||
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, []string{roomID})
|
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, []string{roomID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -76,15 +76,20 @@ func (s *eventJSONStatements) InsertEventJSON(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventJSONStatements) BulkSelectEventJSON(
|
func (s *eventJSONStatements) BulkSelectEventJSON(
|
||||||
ctx context.Context, eventNIDs []types.EventNID,
|
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||||
) ([]tables.EventJSONPair, error) {
|
) ([]tables.EventJSONPair, error) {
|
||||||
iEventNIDs := make([]interface{}, len(eventNIDs))
|
iEventNIDs := make([]interface{}, len(eventNIDs))
|
||||||
for k, v := range eventNIDs {
|
for k, v := range eventNIDs {
|
||||||
iEventNIDs[k] = v
|
iEventNIDs[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
|
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
|
||||||
|
var rows *sql.Rows
|
||||||
rows, err := s.db.QueryContext(ctx, selectOrig, iEventNIDs...)
|
var err error
|
||||||
|
if txn != nil {
|
||||||
|
rows, err = txn.QueryContext(ctx, selectOrig, iEventNIDs...)
|
||||||
|
} else {
|
||||||
|
rows, err = s.db.QueryContext(ctx, selectOrig, iEventNIDs...)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -112,15 +112,20 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
|
func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
|
||||||
ctx context.Context, eventStateKeys []string,
|
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
|
||||||
) (map[string]types.EventStateKeyNID, error) {
|
) (map[string]types.EventStateKeyNID, error) {
|
||||||
iEventStateKeys := make([]interface{}, len(eventStateKeys))
|
iEventStateKeys := make([]interface{}, len(eventStateKeys))
|
||||||
for k, v := range eventStateKeys {
|
for k, v := range eventStateKeys {
|
||||||
iEventStateKeys[k] = v
|
iEventStateKeys[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeys)), 1)
|
selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeys)), 1)
|
||||||
|
var rows *sql.Rows
|
||||||
rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeys...)
|
var err error
|
||||||
|
if txn != nil {
|
||||||
|
rows, err = txn.QueryContext(ctx, selectOrig, iEventStateKeys...)
|
||||||
|
} else {
|
||||||
|
rows, err = s.db.QueryContext(ctx, selectOrig, iEventStateKeys...)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -138,15 +143,19 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStateKeyStatements) BulkSelectEventStateKey(
|
func (s *eventStateKeyStatements) BulkSelectEventStateKey(
|
||||||
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
|
ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||||
) (map[types.EventStateKeyNID]string, error) {
|
) (map[types.EventStateKeyNID]string, error) {
|
||||||
iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs))
|
iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs))
|
||||||
for k, v := range eventStateKeyNIDs {
|
for k, v := range eventStateKeyNIDs {
|
||||||
iEventStateKeyNIDs[k] = v
|
iEventStateKeyNIDs[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeyNIDs)), 1)
|
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeyNIDs)), 1)
|
||||||
|
selectPrep, err := s.db.Prepare(selectOrig)
|
||||||
rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...)
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stmt := sqlutil.TxStmt(txn, selectPrep)
|
||||||
|
rows, err := stmt.QueryContext(ctx, iEventStateKeyNIDs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -128,7 +128,7 @@ func (s *eventTypeStatements) SelectEventTypeNID(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventTypeStatements) BulkSelectEventTypeNID(
|
func (s *eventTypeStatements) BulkSelectEventTypeNID(
|
||||||
ctx context.Context, eventTypes []string,
|
ctx context.Context, txn *sql.Tx, eventTypes []string,
|
||||||
) (map[string]types.EventTypeNID, error) {
|
) (map[string]types.EventTypeNID, error) {
|
||||||
///////////////
|
///////////////
|
||||||
iEventTypes := make([]interface{}, len(eventTypes))
|
iEventTypes := make([]interface{}, len(eventTypes))
|
||||||
|
@ -140,9 +140,10 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
stmt := sqlutil.TxStmt(txn, selectPrep)
|
||||||
///////////////
|
///////////////
|
||||||
|
|
||||||
rows, err := selectPrep.QueryContext(ctx, iEventTypes...)
|
rows, err := stmt.QueryContext(ctx, iEventTypes...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -184,7 +184,7 @@ func (s *eventStatements) SelectEvent(
|
||||||
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
||||||
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
||||||
func (s *eventStatements) BulkSelectStateEventByID(
|
func (s *eventStatements) BulkSelectStateEventByID(
|
||||||
ctx context.Context, eventIDs []string,
|
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||||
) ([]types.StateEntry, error) {
|
) ([]types.StateEntry, error) {
|
||||||
///////////////
|
///////////////
|
||||||
iEventIDs := make([]interface{}, len(eventIDs))
|
iEventIDs := make([]interface{}, len(eventIDs))
|
||||||
|
@ -196,6 +196,7 @@ func (s *eventStatements) BulkSelectStateEventByID(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||||
///////////////
|
///////////////
|
||||||
|
|
||||||
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
|
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
|
||||||
|
@ -235,7 +236,7 @@ func (s *eventStatements) BulkSelectStateEventByID(
|
||||||
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
||||||
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
||||||
func (s *eventStatements) BulkSelectStateEventByNID(
|
func (s *eventStatements) BulkSelectStateEventByNID(
|
||||||
ctx context.Context, eventNIDs []types.EventNID,
|
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||||
stateKeyTuples []types.StateKeyTuple,
|
stateKeyTuples []types.StateKeyTuple,
|
||||||
) ([]types.StateEntry, error) {
|
) ([]types.StateEntry, error) {
|
||||||
tuples := stateKeyTupleSorter(stateKeyTuples)
|
tuples := stateKeyTupleSorter(stateKeyTuples)
|
||||||
|
@ -263,6 +264,7 @@ func (s *eventStatements) BulkSelectStateEventByNID(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("s.db.Prepare: %w", err)
|
return nil, fmt.Errorf("s.db.Prepare: %w", err)
|
||||||
}
|
}
|
||||||
|
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||||
rows, err := selectStmt.QueryContext(ctx, params...)
|
rows, err := selectStmt.QueryContext(ctx, params...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("selectStmt.QueryContext: %w", err)
|
return nil, fmt.Errorf("selectStmt.QueryContext: %w", err)
|
||||||
|
@ -291,7 +293,7 @@ func (s *eventStatements) BulkSelectStateEventByNID(
|
||||||
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
||||||
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
|
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
|
||||||
func (s *eventStatements) BulkSelectStateAtEventByID(
|
func (s *eventStatements) BulkSelectStateAtEventByID(
|
||||||
ctx context.Context, eventIDs []string,
|
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||||
) ([]types.StateAtEvent, error) {
|
) ([]types.StateAtEvent, error) {
|
||||||
///////////////
|
///////////////
|
||||||
iEventIDs := make([]interface{}, len(eventIDs))
|
iEventIDs := make([]interface{}, len(eventIDs))
|
||||||
|
@ -303,6 +305,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||||
///////////////
|
///////////////
|
||||||
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
|
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -381,6 +384,7 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
selectPrep = sqlutil.TxStmt(txn, selectPrep)
|
||||||
//////////////
|
//////////////
|
||||||
|
|
||||||
rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...)
|
rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...)
|
||||||
|
@ -454,7 +458,7 @@ func (s *eventStatements) BulkSelectEventReference(
|
||||||
}
|
}
|
||||||
|
|
||||||
// bulkSelectEventID returns a map from numeric event ID to string event ID.
|
// bulkSelectEventID returns a map from numeric event ID to string event ID.
|
||||||
func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
|
func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
|
||||||
///////////////
|
///////////////
|
||||||
iEventNIDs := make([]interface{}, len(eventNIDs))
|
iEventNIDs := make([]interface{}, len(eventNIDs))
|
||||||
for k, v := range eventNIDs {
|
for k, v := range eventNIDs {
|
||||||
|
@ -465,6 +469,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||||
///////////////
|
///////////////
|
||||||
|
|
||||||
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
|
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
|
||||||
|
@ -490,7 +495,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
|
||||||
|
|
||||||
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
|
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
|
||||||
// If an event ID is not in the database then it is omitted from the map.
|
// If an event ID is not in the database then it is omitted from the map.
|
||||||
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) {
|
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
|
||||||
///////////////
|
///////////////
|
||||||
iEventIDs := make([]interface{}, len(eventIDs))
|
iEventIDs := make([]interface{}, len(eventIDs))
|
||||||
for k, v := range eventIDs {
|
for k, v := range eventIDs {
|
||||||
|
@ -501,6 +506,7 @@ func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []str
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||||
///////////////
|
///////////////
|
||||||
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
|
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -538,13 +544,14 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) SelectRoomNIDsForEventNIDs(
|
func (s *eventStatements) SelectRoomNIDsForEventNIDs(
|
||||||
ctx context.Context, eventNIDs []types.EventNID,
|
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||||
) (map[types.EventNID]types.RoomNID, error) {
|
) (map[types.EventNID]types.RoomNID, error) {
|
||||||
sqlStr := strings.Replace(selectRoomNIDsForEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
|
sqlStr := strings.Replace(selectRoomNIDsForEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
|
||||||
sqlPrep, err := s.db.Prepare(sqlStr)
|
sqlPrep, err := s.db.Prepare(sqlStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
sqlPrep = sqlutil.TxStmt(txn, sqlPrep)
|
||||||
iEventNIDs := make([]interface{}, len(eventNIDs))
|
iEventNIDs := make([]interface{}, len(eventNIDs))
|
||||||
for i, v := range eventNIDs {
|
for i, v := range eventNIDs {
|
||||||
iEventNIDs[i] = v
|
iEventNIDs[i] = v
|
||||||
|
|
|
@ -88,8 +88,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *inviteStatements) InsertInviteEvent(
|
func (s *inviteStatements) InsertInviteEvent(
|
||||||
ctx context.Context,
|
ctx context.Context, txn *sql.Tx,
|
||||||
txn *sql.Tx, inviteEventID string, roomNID types.RoomNID,
|
inviteEventID string, roomNID types.RoomNID,
|
||||||
targetUserNID, senderUserNID types.EventStateKeyNID,
|
targetUserNID, senderUserNID types.EventStateKeyNID,
|
||||||
inviteEventJSON []byte,
|
inviteEventJSON []byte,
|
||||||
) (bool, error) {
|
) (bool, error) {
|
||||||
|
@ -109,8 +109,8 @@ func (s *inviteStatements) InsertInviteEvent(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *inviteStatements) UpdateInviteRetired(
|
func (s *inviteStatements) UpdateInviteRetired(
|
||||||
ctx context.Context,
|
ctx context.Context, txn *sql.Tx,
|
||||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
) (eventIDs []string, err error) {
|
) (eventIDs []string, err error) {
|
||||||
// gather all the event IDs we will retire
|
// gather all the event IDs we will retire
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
|
stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
|
||||||
|
@ -134,10 +134,11 @@ func (s *inviteStatements) UpdateInviteRetired(
|
||||||
|
|
||||||
// selectInviteActiveForUserInRoom returns a list of sender state key NIDs
|
// selectInviteActiveForUserInRoom returns a list of sender state key NIDs
|
||||||
func (s *inviteStatements) SelectInviteActiveForUserInRoom(
|
func (s *inviteStatements) SelectInviteActiveForUserInRoom(
|
||||||
ctx context.Context,
|
ctx context.Context, txn *sql.Tx,
|
||||||
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
|
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
|
||||||
) ([]types.EventStateKeyNID, []string, error) {
|
) ([]types.EventStateKeyNID, []string, error) {
|
||||||
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
|
stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt)
|
||||||
|
rows, err := stmt.QueryContext(
|
||||||
ctx, targetUserNID, roomNID,
|
ctx, targetUserNID, roomNID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -184,17 +184,18 @@ func (s *membershipStatements) SelectMembershipForUpdate(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
|
func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
|
||||||
ctx context.Context,
|
ctx context.Context, txn *sql.Tx,
|
||||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
|
) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
|
||||||
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
|
stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt)
|
||||||
|
err = stmt.QueryRowContext(
|
||||||
ctx, roomNID, targetUserNID,
|
ctx, roomNID, targetUserNID,
|
||||||
).Scan(&membership, &eventNID, &forgotten)
|
).Scan(&membership, &eventNID, &forgotten)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) SelectMembershipsFromRoom(
|
func (s *membershipStatements) SelectMembershipsFromRoom(
|
||||||
ctx context.Context,
|
ctx context.Context, txn *sql.Tx,
|
||||||
roomNID types.RoomNID, localOnly bool,
|
roomNID types.RoomNID, localOnly bool,
|
||||||
) (eventNIDs []types.EventNID, err error) {
|
) (eventNIDs []types.EventNID, err error) {
|
||||||
var selectStmt *sql.Stmt
|
var selectStmt *sql.Stmt
|
||||||
|
@ -203,6 +204,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
|
||||||
} else {
|
} else {
|
||||||
selectStmt = s.selectMembershipsFromRoomStmt
|
selectStmt = s.selectMembershipsFromRoomStmt
|
||||||
}
|
}
|
||||||
|
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||||
rows, err := selectStmt.QueryContext(ctx, roomNID)
|
rows, err := selectStmt.QueryContext(ctx, roomNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -220,7 +222,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
|
func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
|
||||||
ctx context.Context,
|
ctx context.Context, txn *sql.Tx,
|
||||||
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
|
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
|
||||||
) (eventNIDs []types.EventNID, err error) {
|
) (eventNIDs []types.EventNID, err error) {
|
||||||
var stmt *sql.Stmt
|
var stmt *sql.Stmt
|
||||||
|
@ -229,6 +231,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
|
||||||
} else {
|
} else {
|
||||||
stmt = s.selectMembershipsFromRoomAndMembershipStmt
|
stmt = s.selectMembershipsFromRoomAndMembershipStmt
|
||||||
}
|
}
|
||||||
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
rows, err := stmt.QueryContext(ctx, roomNID, membership)
|
rows, err := stmt.QueryContext(ctx, roomNID, membership)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -258,9 +261,10 @@ func (s *membershipStatements) UpdateMembership(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) SelectRoomsWithMembership(
|
func (s *membershipStatements) SelectRoomsWithMembership(
|
||||||
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||||
) ([]types.RoomNID, error) {
|
) ([]types.RoomNID, error) {
|
||||||
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID)
|
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, membershipState, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -276,13 +280,19 @@ func (s *membershipStatements) SelectRoomsWithMembership(
|
||||||
return roomNIDs, nil
|
return roomNIDs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
|
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
|
||||||
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
||||||
for i, v := range roomNIDs {
|
for i, v := range roomNIDs {
|
||||||
iRoomNIDs[i] = v
|
iRoomNIDs[i] = v
|
||||||
}
|
}
|
||||||
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1)
|
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1)
|
||||||
rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...)
|
var rows *sql.Rows
|
||||||
|
var err error
|
||||||
|
if txn != nil {
|
||||||
|
rows, err = txn.QueryContext(ctx, query, iRoomNIDs...)
|
||||||
|
} else {
|
||||||
|
rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -299,8 +309,9 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
|
||||||
return result, rows.Err()
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
|
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
|
||||||
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
|
stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -317,8 +328,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) UpdateForgetMembership(
|
func (s *membershipStatements) UpdateForgetMembership(
|
||||||
ctx context.Context,
|
ctx context.Context, txn *sql.Tx,
|
||||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
forget bool,
|
forget bool,
|
||||||
) error {
|
) error {
|
||||||
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
|
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
|
||||||
|
@ -327,9 +338,10 @@ func (s *membershipStatements) UpdateForgetMembership(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
|
func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) {
|
||||||
var nid types.RoomNID
|
var nid types.RoomNID
|
||||||
err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
|
stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt)
|
||||||
|
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return false, nil
|
return false, nil
|
||||||
|
@ -340,9 +352,10 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room
|
||||||
return found, nil
|
return found, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
|
func (s *membershipStatements) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
|
||||||
var nid types.RoomNID
|
var nid types.RoomNID
|
||||||
err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
|
stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt)
|
||||||
|
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return false, nil
|
return false, nil
|
||||||
|
|
|
@ -75,9 +75,10 @@ func (s *publishedStatements) UpsertRoomPublished(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *publishedStatements) SelectPublishedFromRoomID(
|
func (s *publishedStatements) SelectPublishedFromRoomID(
|
||||||
ctx context.Context, roomID string,
|
ctx context.Context, txn *sql.Tx, roomID string,
|
||||||
) (published bool, err error) {
|
) (published bool, err error) {
|
||||||
err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published)
|
stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt)
|
||||||
|
err = stmt.QueryRowContext(ctx, roomID).Scan(&published)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
@ -85,9 +86,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *publishedStatements) SelectAllPublishedRooms(
|
func (s *publishedStatements) SelectAllPublishedRooms(
|
||||||
ctx context.Context, published bool,
|
ctx context.Context, txn *sql.Tx, published bool,
|
||||||
) ([]string, error) {
|
) ([]string, error) {
|
||||||
rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published)
|
stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, published)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -91,9 +91,10 @@ func (s *roomAliasesStatements) InsertRoomAlias(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomAliasesStatements) SelectRoomIDFromAlias(
|
func (s *roomAliasesStatements) SelectRoomIDFromAlias(
|
||||||
ctx context.Context, alias string,
|
ctx context.Context, txn *sql.Tx, alias string,
|
||||||
) (roomID string, err error) {
|
) (roomID string, err error) {
|
||||||
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID)
|
stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt)
|
||||||
|
err = stmt.QueryRowContext(ctx, alias).Scan(&roomID)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
@ -101,10 +102,11 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomAliasesStatements) SelectAliasesFromRoomID(
|
func (s *roomAliasesStatements) SelectAliasesFromRoomID(
|
||||||
ctx context.Context, roomID string,
|
ctx context.Context, txn *sql.Tx, roomID string,
|
||||||
) (aliases []string, err error) {
|
) (aliases []string, err error) {
|
||||||
aliases = []string{}
|
aliases = []string{}
|
||||||
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
|
stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -124,9 +126,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
|
func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
|
||||||
ctx context.Context, alias string,
|
ctx context.Context, txn *sql.Tx, alias string,
|
||||||
) (creatorID string, err error) {
|
) (creatorID string, err error) {
|
||||||
err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID)
|
stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt)
|
||||||
|
err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -107,8 +107,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
|
func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
|
||||||
rows, err := s.selectRoomIDsStmt.QueryContext(ctx)
|
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -124,10 +125,11 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
|
||||||
return roomIDs, nil
|
return roomIDs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
|
func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
|
||||||
var info types.RoomInfo
|
var info types.RoomInfo
|
||||||
var latestNIDsJSON string
|
var latestNIDsJSON string
|
||||||
err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan(
|
stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt)
|
||||||
|
err := stmt.QueryRowContext(ctx, roomID).Scan(
|
||||||
&info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDsJSON,
|
&info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDsJSON,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -224,13 +226,14 @@ func (s *roomStatements) UpdateLatestEventNIDs(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
|
func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
|
||||||
ctx context.Context, roomNIDs []types.RoomNID,
|
ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID,
|
||||||
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
|
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
|
||||||
sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
||||||
sqlPrep, err := s.db.Prepare(sqlStr)
|
sqlPrep, err := s.db.Prepare(sqlStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
sqlPrep = sqlutil.TxStmt(txn, sqlPrep)
|
||||||
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
||||||
for i, v := range roomNIDs {
|
for i, v := range roomNIDs {
|
||||||
iRoomNIDs[i] = v
|
iRoomNIDs[i] = v
|
||||||
|
@ -252,13 +255,19 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
|
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) {
|
||||||
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
||||||
for i, v := range roomNIDs {
|
for i, v := range roomNIDs {
|
||||||
iRoomNIDs[i] = v
|
iRoomNIDs[i] = v
|
||||||
}
|
}
|
||||||
sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
||||||
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...)
|
var rows *sql.Rows
|
||||||
|
var err error
|
||||||
|
if txn != nil {
|
||||||
|
rows, err = txn.QueryContext(ctx, sqlQuery, iRoomNIDs...)
|
||||||
|
} else {
|
||||||
|
rows, err = s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -274,13 +283,19 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types
|
||||||
return roomIDs, nil
|
return roomIDs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) {
|
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) {
|
||||||
iRoomIDs := make([]interface{}, len(roomIDs))
|
iRoomIDs := make([]interface{}, len(roomIDs))
|
||||||
for i, v := range roomIDs {
|
for i, v := range roomIDs {
|
||||||
iRoomIDs[i] = v
|
iRoomIDs[i] = v
|
||||||
}
|
}
|
||||||
sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1)
|
sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1)
|
||||||
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...)
|
var rows *sql.Rows
|
||||||
|
var err error
|
||||||
|
if txn != nil {
|
||||||
|
rows, err = txn.QueryContext(ctx, sqlQuery, iRoomIDs...)
|
||||||
|
} else {
|
||||||
|
rows, err = s.db.QueryContext(ctx, sqlQuery, iRoomIDs...)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -81,8 +81,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stateBlockStatements) BulkInsertStateData(
|
func (s *stateBlockStatements) BulkInsertStateData(
|
||||||
ctx context.Context,
|
ctx context.Context, txn *sql.Tx,
|
||||||
txn *sql.Tx,
|
|
||||||
entries types.StateEntries,
|
entries types.StateEntries,
|
||||||
) (id types.StateBlockNID, err error) {
|
) (id types.StateBlockNID, err error) {
|
||||||
entries = entries[:util.SortAndUnique(entries)]
|
entries = entries[:util.SortAndUnique(entries)]
|
||||||
|
@ -94,14 +93,15 @@ func (s *stateBlockStatements) BulkInsertStateData(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("json.Marshal: %w", err)
|
return 0, fmt.Errorf("json.Marshal: %w", err)
|
||||||
}
|
}
|
||||||
err = s.insertStateDataStmt.QueryRowContext(
|
stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt)
|
||||||
|
err = stmt.QueryRowContext(
|
||||||
ctx, nids.Hash(), js,
|
ctx, nids.Hash(), js,
|
||||||
).Scan(&id)
|
).Scan(&id)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
||||||
ctx context.Context, stateBlockNIDs types.StateBlockNIDs,
|
ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs,
|
||||||
) ([][]types.EventNID, error) {
|
) ([][]types.EventNID, error) {
|
||||||
intfs := make([]interface{}, len(stateBlockNIDs))
|
intfs := make([]interface{}, len(stateBlockNIDs))
|
||||||
for i := range stateBlockNIDs {
|
for i := range stateBlockNIDs {
|
||||||
|
@ -112,6 +112,7 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||||
rows, err := selectStmt.QueryContext(ctx, intfs...)
|
rows, err := selectStmt.QueryContext(ctx, intfs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -106,7 +106,7 @@ func (s *stateSnapshotStatements) InsertState(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
||||||
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
|
||||||
) ([]types.StateBlockNIDList, error) {
|
) ([]types.StateBlockNIDList, error) {
|
||||||
nids := make([]interface{}, len(stateNIDs))
|
nids := make([]interface{}, len(stateNIDs))
|
||||||
for k, v := range stateNIDs {
|
for k, v := range stateNIDs {
|
||||||
|
@ -117,6 +117,7 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||||
|
|
||||||
rows, err := selectStmt.QueryContext(ctx, nids...)
|
rows, err := selectStmt.QueryContext(ctx, nids...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -172,23 +172,23 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
d.Database = shared.Database{
|
d.Database = shared.Database{
|
||||||
DB: db,
|
DB: db,
|
||||||
Cache: cache,
|
Cache: cache,
|
||||||
Writer: sqlutil.NewExclusiveWriter(),
|
Writer: sqlutil.NewExclusiveWriter(),
|
||||||
EventsTable: events,
|
EventsTable: events,
|
||||||
EventTypesTable: eventTypes,
|
EventTypesTable: eventTypes,
|
||||||
EventStateKeysTable: eventStateKeys,
|
EventStateKeysTable: eventStateKeys,
|
||||||
EventJSONTable: eventJSON,
|
EventJSONTable: eventJSON,
|
||||||
RoomsTable: rooms,
|
RoomsTable: rooms,
|
||||||
StateBlockTable: stateBlock,
|
StateBlockTable: stateBlock,
|
||||||
StateSnapshotTable: stateSnapshot,
|
StateSnapshotTable: stateSnapshot,
|
||||||
PrevEventsTable: prevEvents,
|
PrevEventsTable: prevEvents,
|
||||||
RoomAliasesTable: roomAliases,
|
RoomAliasesTable: roomAliases,
|
||||||
InvitesTable: invites,
|
InvitesTable: invites,
|
||||||
MembershipTable: membership,
|
MembershipTable: membership,
|
||||||
PublishedTable: published,
|
PublishedTable: published,
|
||||||
RedactionsTable: redactions,
|
RedactionsTable: redactions,
|
||||||
GetLatestEventsForUpdateFn: d.GetLatestEventsForUpdate,
|
GetRoomUpdaterFn: d.GetRoomUpdater,
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -201,16 +201,16 @@ func (d *Database) SupportsConcurrentRoomInputs() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetLatestEventsForUpdate(
|
func (d *Database) GetRoomUpdater(
|
||||||
ctx context.Context, roomInfo types.RoomInfo,
|
ctx context.Context, roomInfo *types.RoomInfo,
|
||||||
) (*shared.LatestEventsUpdater, error) {
|
) (*shared.RoomUpdater, error) {
|
||||||
// TODO: Do not use transactions. We should be holding open this transaction but we cannot have
|
// TODO: Do not use transactions. We should be holding open this transaction but we cannot have
|
||||||
// multiple write transactions on sqlite. The code will perform additional
|
// multiple write transactions on sqlite. The code will perform additional
|
||||||
// write transactions independent of this one which will consistently cause
|
// write transactions independent of this one which will consistently cause
|
||||||
// 'database is locked' errors. As sqlite doesn't support multi-process on the
|
// 'database is locked' errors. As sqlite doesn't support multi-process on the
|
||||||
// same DB anyway, and we only execute updates sequentially, the only worries
|
// same DB anyway, and we only execute updates sequentially, the only worries
|
||||||
// are for rolling back when things go wrong. (atomicity)
|
// are for rolling back when things go wrong. (atomicity)
|
||||||
return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomInfo)
|
return shared.NewRoomUpdater(ctx, &d.Database, nil, roomInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) MembershipUpdater(
|
func (d *Database) MembershipUpdater(
|
||||||
|
|
|
@ -18,20 +18,20 @@ type EventJSONPair struct {
|
||||||
type EventJSON interface {
|
type EventJSON interface {
|
||||||
// Insert the event JSON. On conflict, replace the event JSON with the new value (for redactions).
|
// Insert the event JSON. On conflict, replace the event JSON with the new value (for redactions).
|
||||||
InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error
|
InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error
|
||||||
BulkSelectEventJSON(ctx context.Context, eventNIDs []types.EventNID) ([]EventJSONPair, error)
|
BulkSelectEventJSON(ctx context.Context, tx *sql.Tx, eventNIDs []types.EventNID) ([]EventJSONPair, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type EventTypes interface {
|
type EventTypes interface {
|
||||||
InsertEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error)
|
InsertEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error)
|
||||||
SelectEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error)
|
SelectEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error)
|
||||||
BulkSelectEventTypeNID(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
|
BulkSelectEventTypeNID(ctx context.Context, txn *sql.Tx, eventTypes []string) (map[string]types.EventTypeNID, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type EventStateKeys interface {
|
type EventStateKeys interface {
|
||||||
InsertEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error)
|
InsertEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error)
|
||||||
SelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error)
|
SelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error)
|
||||||
BulkSelectEventStateKeyNID(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
|
BulkSelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
|
||||||
BulkSelectEventStateKey(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
|
BulkSelectEventStateKey(ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Events interface {
|
type Events interface {
|
||||||
|
@ -42,12 +42,12 @@ type Events interface {
|
||||||
SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error)
|
SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error)
|
||||||
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
||||||
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
||||||
BulkSelectStateEventByID(ctx context.Context, eventIDs []string) ([]types.StateEntry, error)
|
BulkSelectStateEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateEntry, error)
|
||||||
BulkSelectStateEventByNID(ctx context.Context, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error)
|
BulkSelectStateEventByNID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error)
|
||||||
// BulkSelectStateAtEventByID lookups the state at a list of events by event ID.
|
// BulkSelectStateAtEventByID lookups the state at a list of events by event ID.
|
||||||
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
||||||
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
|
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
|
||||||
BulkSelectStateAtEventByID(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
|
BulkSelectStateAtEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateAtEvent, error)
|
||||||
UpdateEventState(ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID) error
|
UpdateEventState(ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID) error
|
||||||
SelectEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error)
|
SelectEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error)
|
||||||
UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error
|
UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error
|
||||||
|
@ -55,12 +55,12 @@ type Events interface {
|
||||||
BulkSelectStateAtEventAndReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error)
|
BulkSelectStateAtEventAndReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error)
|
||||||
BulkSelectEventReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error)
|
BulkSelectEventReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error)
|
||||||
// BulkSelectEventID returns a map from numeric event ID to string event ID.
|
// BulkSelectEventID returns a map from numeric event ID to string event ID.
|
||||||
BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
|
BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
|
||||||
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
|
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
|
||||||
// If an event ID is not in the database then it is omitted from the map.
|
// If an event ID is not in the database then it is omitted from the map.
|
||||||
BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error)
|
BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error)
|
||||||
SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error)
|
SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error)
|
||||||
SelectRoomNIDsForEventNIDs(ctx context.Context, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
|
SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Rooms interface {
|
type Rooms interface {
|
||||||
|
@ -69,29 +69,29 @@ type Rooms interface {
|
||||||
SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error)
|
SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error)
|
||||||
SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error)
|
SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error)
|
||||||
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
|
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
|
||||||
SelectRoomVersionsForRoomNIDs(ctx context.Context, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error)
|
SelectRoomVersionsForRoomNIDs(ctx context.Context, txn *sql.Tx, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error)
|
||||||
SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
|
SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error)
|
||||||
SelectRoomIDs(ctx context.Context) ([]string, error)
|
SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error)
|
||||||
BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error)
|
BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error)
|
||||||
BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error)
|
BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type StateSnapshot interface {
|
type StateSnapshot interface {
|
||||||
InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error)
|
InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error)
|
||||||
BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
|
BulkSelectStateBlockNIDs(ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type StateBlock interface {
|
type StateBlock interface {
|
||||||
BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries types.StateEntries) (types.StateBlockNID, error)
|
BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries types.StateEntries) (types.StateBlockNID, error)
|
||||||
BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error)
|
BulkSelectStateBlockEntries(ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error)
|
||||||
//BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
|
//BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type RoomAliases interface {
|
type RoomAliases interface {
|
||||||
InsertRoomAlias(ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string) (err error)
|
InsertRoomAlias(ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string) (err error)
|
||||||
SelectRoomIDFromAlias(ctx context.Context, alias string) (roomID string, err error)
|
SelectRoomIDFromAlias(ctx context.Context, txn *sql.Tx, alias string) (roomID string, err error)
|
||||||
SelectAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error)
|
SelectAliasesFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) ([]string, error)
|
||||||
SelectCreatorIDFromAlias(ctx context.Context, alias string) (creatorID string, err error)
|
SelectCreatorIDFromAlias(ctx context.Context, txn *sql.Tx, alias string) (creatorID string, err error)
|
||||||
DeleteRoomAlias(ctx context.Context, txn *sql.Tx, alias string) (err error)
|
DeleteRoomAlias(ctx context.Context, txn *sql.Tx, alias string) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -106,7 +106,7 @@ type Invites interface {
|
||||||
InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error)
|
InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error)
|
||||||
UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error)
|
UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error)
|
||||||
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs and invite event IDs matching those nids.
|
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs and invite event IDs matching those nids.
|
||||||
SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error)
|
SelectInviteActiveForUserInRoom(ctx context.Context, txn *sql.Tx, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type MembershipState int64
|
type MembershipState int64
|
||||||
|
@ -121,24 +121,24 @@ const (
|
||||||
type Membership interface {
|
type Membership interface {
|
||||||
InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error
|
InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error
|
||||||
SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error)
|
SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error)
|
||||||
SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error)
|
SelectMembershipFromRoomAndTarget(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error)
|
||||||
SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
|
SelectMembershipsFromRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||||
SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
|
SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||||
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error
|
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error
|
||||||
SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
|
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
|
||||||
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
|
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
|
||||||
// counts of how many rooms they are joined.
|
// counts of how many rooms they are joined.
|
||||||
SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
|
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
|
||||||
SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
|
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
|
||||||
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
|
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
|
||||||
SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error)
|
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
|
||||||
SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error)
|
SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Published interface {
|
type Published interface {
|
||||||
UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID string, published bool) (err error)
|
UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID string, published bool) (err error)
|
||||||
SelectPublishedFromRoomID(ctx context.Context, roomID string) (published bool, err error)
|
SelectPublishedFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) (published bool, err error)
|
||||||
SelectAllPublishedRooms(ctx context.Context, published bool) ([]string, error)
|
SelectAllPublishedRooms(ctx context.Context, txn *sql.Tx, published bool) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type RedactionInfo struct {
|
type RedactionInfo struct {
|
||||||
|
|
Loading…
Reference in a new issue