Complete snapshot isolation for sync

This commit is contained in:
Neil Alexander 2022-09-28 12:55:00 +01:00
parent 3f9e38e80a
commit 09a3c807f9
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
32 changed files with 544 additions and 383 deletions

View file

@ -34,6 +34,7 @@ import (
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
) )
@ -46,7 +47,7 @@ type OutputClientDataConsumer struct {
topic string topic string
topicReIndex string topicReIndex string
db storage.Database db storage.Database
stream types.StreamProvider stream streams.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
fts *fulltext.Search fts *fulltext.Search
@ -61,7 +62,7 @@ func NewOutputClientDataConsumer(
nats *nats.Conn, nats *nats.Conn,
store storage.Database, store storage.Database,
notifier *notifier.Notifier, notifier *notifier.Notifier,
stream types.StreamProvider, stream streams.StreamProvider,
fts *fulltext.Search, fts *fulltext.Search,
) *OutputClientDataConsumer { ) *OutputClientDataConsumer {
return &OutputClientDataConsumer{ return &OutputClientDataConsumer{

View file

@ -26,6 +26,7 @@ import (
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
@ -40,7 +41,7 @@ type OutputKeyChangeEventConsumer struct {
topic string topic string
db storage.Database db storage.Database
notifier *notifier.Notifier notifier *notifier.Notifier
stream types.StreamProvider stream streams.StreamProvider
serverName gomatrixserverlib.ServerName // our server name serverName gomatrixserverlib.ServerName // our server name
rsAPI roomserverAPI.SyncRoomserverAPI rsAPI roomserverAPI.SyncRoomserverAPI
} }
@ -55,7 +56,7 @@ func NewOutputKeyChangeEventConsumer(
rsAPI roomserverAPI.SyncRoomserverAPI, rsAPI roomserverAPI.SyncRoomserverAPI,
store storage.Database, store storage.Database,
notifier *notifier.Notifier, notifier *notifier.Notifier,
stream types.StreamProvider, stream streams.StreamProvider,
) *OutputKeyChangeEventConsumer { ) *OutputKeyChangeEventConsumer {
s := &OutputKeyChangeEventConsumer{ s := &OutputKeyChangeEventConsumer{
ctx: process.Context(), ctx: process.Context(),

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -39,7 +40,7 @@ type PresenceConsumer struct {
requestTopic string requestTopic string
presenceTopic string presenceTopic string
db storage.Database db storage.Database
stream types.StreamProvider stream streams.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
deviceAPI api.SyncUserAPI deviceAPI api.SyncUserAPI
cfg *config.SyncAPI cfg *config.SyncAPI
@ -54,7 +55,7 @@ func NewPresenceConsumer(
nats *nats.Conn, nats *nats.Conn,
db storage.Database, db storage.Database,
notifier *notifier.Notifier, notifier *notifier.Notifier,
stream types.StreamProvider, stream streams.StreamProvider,
deviceAPI api.SyncUserAPI, deviceAPI api.SyncUserAPI,
) *PresenceConsumer { ) *PresenceConsumer {
return &PresenceConsumer{ return &PresenceConsumer{

View file

@ -28,6 +28,7 @@ import (
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
) )
@ -38,7 +39,7 @@ type OutputReceiptEventConsumer struct {
durable string durable string
topic string topic string
db storage.Database db storage.Database
stream types.StreamProvider stream streams.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
@ -51,7 +52,7 @@ func NewOutputReceiptEventConsumer(
js nats.JetStreamContext, js nats.JetStreamContext,
store storage.Database, store storage.Database,
notifier *notifier.Notifier, notifier *notifier.Notifier,
stream types.StreamProvider, stream streams.StreamProvider,
) *OutputReceiptEventConsumer { ) *OutputReceiptEventConsumer {
return &OutputReceiptEventConsumer{ return &OutputReceiptEventConsumer{
ctx: process.Context(), ctx: process.Context(),

View file

@ -33,6 +33,7 @@ import (
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
) )
@ -45,8 +46,8 @@ type OutputRoomEventConsumer struct {
durable string durable string
topic string topic string
db storage.Database db storage.Database
pduStream types.StreamProvider pduStream streams.StreamProvider
inviteStream types.StreamProvider inviteStream streams.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
fts *fulltext.Search fts *fulltext.Search
} }
@ -58,8 +59,8 @@ func NewOutputRoomEventConsumer(
js nats.JetStreamContext, js nats.JetStreamContext,
store storage.Database, store storage.Database,
notifier *notifier.Notifier, notifier *notifier.Notifier,
pduStream types.StreamProvider, pduStream streams.StreamProvider,
inviteStream types.StreamProvider, inviteStream streams.StreamProvider,
rsAPI api.SyncRoomserverAPI, rsAPI api.SyncRoomserverAPI,
fts *fulltext.Search, fts *fulltext.Search,
) *OutputRoomEventConsumer { ) *OutputRoomEventConsumer {
@ -449,8 +450,14 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *gomatrixserverlib.Head
} }
stateKey := *event.StateKey() stateKey := *event.StateKey()
prevEvent, err := s.db.GetStateEvent( snapshot, err := s.db.NewDatabaseSnapshot(s.ctx)
context.TODO(), event.RoomID(), event.Type(), stateKey, if err != nil {
return nil, err
}
defer snapshot.Rollback() // nolint:err
prevEvent, err := snapshot.GetStateEvent(
s.ctx, event.RoomID(), event.Type(), stateKey,
) )
if err != nil { if err != nil {
return event, err return event, err

View file

@ -31,6 +31,7 @@ import (
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
) )
@ -43,7 +44,7 @@ type OutputSendToDeviceEventConsumer struct {
db storage.Database db storage.Database
keyAPI keyapi.SyncKeyAPI keyAPI keyapi.SyncKeyAPI
serverName gomatrixserverlib.ServerName // our server name serverName gomatrixserverlib.ServerName // our server name
stream types.StreamProvider stream streams.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
} }
@ -56,7 +57,7 @@ func NewOutputSendToDeviceEventConsumer(
store storage.Database, store storage.Database,
keyAPI keyapi.SyncKeyAPI, keyAPI keyapi.SyncKeyAPI,
notifier *notifier.Notifier, notifier *notifier.Notifier,
stream types.StreamProvider, stream streams.StreamProvider,
) *OutputSendToDeviceEventConsumer { ) *OutputSendToDeviceEventConsumer {
return &OutputSendToDeviceEventConsumer{ return &OutputSendToDeviceEventConsumer{
ctx: process.Context(), ctx: process.Context(),

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -36,7 +37,7 @@ type OutputTypingEventConsumer struct {
durable string durable string
topic string topic string
eduCache *caching.EDUCache eduCache *caching.EDUCache
stream types.StreamProvider stream streams.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
} }
@ -48,7 +49,7 @@ func NewOutputTypingEventConsumer(
js nats.JetStreamContext, js nats.JetStreamContext,
eduCache *caching.EDUCache, eduCache *caching.EDUCache,
notifier *notifier.Notifier, notifier *notifier.Notifier,
stream types.StreamProvider, stream streams.StreamProvider,
) *OutputTypingEventConsumer { ) *OutputTypingEventConsumer {
return &OutputTypingEventConsumer{ return &OutputTypingEventConsumer{
ctx: process.Context(), ctx: process.Context(),

View file

@ -28,6 +28,7 @@ import (
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
) )
@ -40,7 +41,7 @@ type OutputNotificationDataConsumer struct {
topic string topic string
db storage.Database db storage.Database
notifier *notifier.Notifier notifier *notifier.Notifier
stream types.StreamProvider stream streams.StreamProvider
} }
// NewOutputNotificationDataConsumer creates a new consumer. Call // NewOutputNotificationDataConsumer creates a new consumer. Call
@ -51,7 +52,7 @@ func NewOutputNotificationDataConsumer(
js nats.JetStreamContext, js nats.JetStreamContext,
store storage.Database, store storage.Database,
notifier *notifier.Notifier, notifier *notifier.Notifier,
stream types.StreamProvider, stream streams.StreamProvider,
) *OutputNotificationDataConsumer { ) *OutputNotificationDataConsumer {
s := &OutputNotificationDataConsumer{ s := &OutputNotificationDataConsumer{
ctx: process.Context(), ctx: process.Context(),

View file

@ -100,7 +100,7 @@ func (ev eventVisibility) allowed() (allowed bool) {
// Returns the filtered events and an error, if any. // Returns the filtered events and an error, if any.
func ApplyHistoryVisibilityFilter( func ApplyHistoryVisibilityFilter(
ctx context.Context, ctx context.Context,
syncDB storage.Database, syncDB storage.DatabaseSnapshot,
rsAPI api.SyncRoomserverAPI, rsAPI api.SyncRoomserverAPI,
events []*gomatrixserverlib.HeaderedEvent, events []*gomatrixserverlib.HeaderedEvent,
alwaysIncludeEventIDs map[string]struct{}, alwaysIncludeEventIDs map[string]struct{},

View file

@ -318,13 +318,20 @@ func (n *Notifier) GetListener(req types.SyncRequest) UserDeviceStreamListener {
func (n *Notifier) Load(ctx context.Context, db storage.Database) error { func (n *Notifier) Load(ctx context.Context, db storage.Database) error {
n.lock.Lock() n.lock.Lock()
defer n.lock.Unlock() defer n.lock.Unlock()
roomToUsers, err := db.AllJoinedUsersInRooms(ctx)
snapshot, err := db.NewDatabaseSnapshot(ctx)
if err != nil {
return err
}
defer snapshot.Rollback() // nolint:err
roomToUsers, err := snapshot.AllJoinedUsersInRooms(ctx)
if err != nil { if err != nil {
return err return err
} }
n.setUsersJoinedToRooms(roomToUsers) n.setUsersJoinedToRooms(roomToUsers)
roomToPeekingDevices, err := db.AllPeekingDevicesInRooms(ctx) roomToPeekingDevices, err := snapshot.AllPeekingDevicesInRooms(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -338,7 +345,13 @@ func (n *Notifier) LoadRooms(ctx context.Context, db storage.Database, roomIDs [
n.lock.Lock() n.lock.Lock()
defer n.lock.Unlock() defer n.lock.Unlock()
roomToUsers, err := db.AllJoinedUsersInRoom(ctx, roomIDs) snapshot, err := db.NewDatabaseSnapshot(ctx)
if err != nil {
return err
}
defer snapshot.Rollback() // nolint:err
roomToUsers, err := snapshot.AllJoinedUsersInRoom(ctx, roomIDs)
if err != nil { if err != nil {
return err return err
} }

View file

@ -51,6 +51,12 @@ func Context(
roomID, eventID string, roomID, eventID string,
lazyLoadCache caching.LazyLoadCache, lazyLoadCache caching.LazyLoadCache,
) util.JSONResponse { ) util.JSONResponse {
snapshot, err := syncDB.NewDatabaseSnapshot(req.Context())
if err != nil {
return jsonerror.InternalServerError()
}
defer snapshot.Rollback() // nolint:err
filter, err := parseRoomEventFilter(req) filter, err := parseRoomEventFilter(req)
if err != nil { if err != nil {
errMsg := "" errMsg := ""
@ -97,7 +103,7 @@ func Context(
ContainsURL: filter.ContainsURL, ContainsURL: filter.ContainsURL,
} }
id, requestedEvent, err := syncDB.SelectContextEvent(ctx, roomID, eventID) id, requestedEvent, err := snapshot.SelectContextEvent(ctx, roomID, eventID)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return util.JSONResponse{ return util.JSONResponse{
@ -111,7 +117,7 @@ func Context(
// verify the user is allowed to see the context for this room/event // verify the user is allowed to see the context for this room/event
startTime := time.Now() startTime := time.Now()
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, []*gomatrixserverlib.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context") filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, []*gomatrixserverlib.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context")
if err != nil { if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter") logrus.WithError(err).Error("unable to apply history visibility filter")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -127,20 +133,20 @@ func Context(
} }
} }
eventsBefore, err := syncDB.SelectContextBeforeEvent(ctx, id, roomID, filter) eventsBefore, err := snapshot.SelectContextBeforeEvent(ctx, id, roomID, filter)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
logrus.WithError(err).Error("unable to fetch before events") logrus.WithError(err).Error("unable to fetch before events")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
_, eventsAfter, err := syncDB.SelectContextAfterEvent(ctx, id, roomID, filter) _, eventsAfter, err := snapshot.SelectContextAfterEvent(ctx, id, roomID, filter)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
logrus.WithError(err).Error("unable to fetch after events") logrus.WithError(err).Error("unable to fetch after events")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
startTime = time.Now() startTime = time.Now()
eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, syncDB, rsAPI, eventsBefore, eventsAfter, device.UserID) eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, snapshot, rsAPI, eventsBefore, eventsAfter, device.UserID)
if err != nil { if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter") logrus.WithError(err).Error("unable to apply history visibility filter")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -152,7 +158,7 @@ func Context(
}).Debug("applied history visibility (context eventsBefore/eventsAfter)") }).Debug("applied history visibility (context eventsBefore/eventsAfter)")
// TODO: Get the actual state at the last event returned by SelectContextAfterEvent // TODO: Get the actual state at the last event returned by SelectContextAfterEvent
state, err := syncDB.CurrentState(ctx, roomID, &stateFilter, nil) state, err := snapshot.CurrentState(ctx, roomID, &stateFilter, nil)
if err != nil { if err != nil {
logrus.WithError(err).Error("unable to fetch current room state") logrus.WithError(err).Error("unable to fetch current room state")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -173,7 +179,7 @@ func Context(
if len(response.State) > filter.Limit { if len(response.State) > filter.Limit {
response.State = response.State[len(response.State)-filter.Limit:] response.State = response.State[len(response.State)-filter.Limit:]
} }
start, end, err := getStartEnd(ctx, syncDB, eventsBefore, eventsAfter) start, end, err := getStartEnd(ctx, snapshot, eventsBefore, eventsAfter)
if err == nil { if err == nil {
response.End = end.String() response.End = end.String()
response.Start = start.String() response.Start = start.String()
@ -188,7 +194,7 @@ func Context(
// by combining the events before and after the context event. Returns the filtered events, // by combining the events before and after the context event. Returns the filtered events,
// and an error, if any. // and an error, if any.
func applyHistoryVisibilityOnContextEvents( func applyHistoryVisibilityOnContextEvents(
ctx context.Context, syncDB storage.Database, rsAPI roomserver.SyncRoomserverAPI, ctx context.Context, snapshot storage.DatabaseSnapshot, rsAPI roomserver.SyncRoomserverAPI,
eventsBefore, eventsAfter []*gomatrixserverlib.HeaderedEvent, eventsBefore, eventsAfter []*gomatrixserverlib.HeaderedEvent,
userID string, userID string,
) (filteredBefore, filteredAfter []*gomatrixserverlib.HeaderedEvent, err error) { ) (filteredBefore, filteredAfter []*gomatrixserverlib.HeaderedEvent, err error) {
@ -205,7 +211,7 @@ func applyHistoryVisibilityOnContextEvents(
} }
allEvents := append(eventsBefore, eventsAfter...) allEvents := append(eventsBefore, eventsAfter...)
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, allEvents, nil, userID, "context") filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, allEvents, nil, userID, "context")
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -222,15 +228,15 @@ func applyHistoryVisibilityOnContextEvents(
return filteredBefore, filteredAfter, nil return filteredBefore, filteredAfter, nil
} }
func getStartEnd(ctx context.Context, syncDB storage.Database, startEvents, endEvents []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) { func getStartEnd(ctx context.Context, snapshot storage.DatabaseSnapshot, startEvents, endEvents []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) {
if len(startEvents) > 0 { if len(startEvents) > 0 {
start, err = syncDB.EventPositionInTopology(ctx, startEvents[0].EventID()) start, err = snapshot.EventPositionInTopology(ctx, startEvents[0].EventID())
if err != nil { if err != nil {
return return
} }
} }
if len(endEvents) > 0 { if len(endEvents) > 0 {
end, err = syncDB.EventPositionInTopology(ctx, endEvents[0].EventID()) end, err = snapshot.EventPositionInTopology(ctx, endEvents[0].EventID())
} }
return return
} }

View file

@ -45,8 +45,14 @@ func GetFilter(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
snapshot, err := syncDB.NewDatabaseSnapshot(req.Context())
if err != nil {
return jsonerror.InternalServerError()
}
defer snapshot.Rollback() // nolint:err
filter := gomatrixserverlib.DefaultFilter() filter := gomatrixserverlib.DefaultFilter()
if err := syncDB.GetFilter(req.Context(), &filter, localpart, filterID); err != nil { if err := snapshot.GetFilter(req.Context(), &filter, localpart, filterID); err != nil {
//TODO better error handling. This error message is *probably* right, //TODO better error handling. This error message is *probably* right,
// but if there are obscure db errors, this will also be returned, // but if there are obscure db errors, this will also be returned,
// even though it is not correct. // even though it is not correct.

View file

@ -39,6 +39,7 @@ import (
type messagesReq struct { type messagesReq struct {
ctx context.Context ctx context.Context
db storage.Database db storage.Database
snapshot storage.DatabaseSnapshot
rsAPI api.SyncRoomserverAPI rsAPI api.SyncRoomserverAPI
cfg *config.SyncAPI cfg *config.SyncAPI
roomID string roomID string
@ -70,6 +71,12 @@ func OnIncomingMessagesRequest(
) util.JSONResponse { ) util.JSONResponse {
var err error var err error
snapshot, err := db.NewDatabaseSnapshot(req.Context())
if err != nil {
return jsonerror.InternalServerError()
}
defer snapshot.Rollback() // nolint:err
// check if the user has already forgotten about this room // check if the user has already forgotten about this room
isForgotten, roomExists, err := checkIsRoomForgotten(req.Context(), roomID, device.UserID, rsAPI) isForgotten, roomExists, err := checkIsRoomForgotten(req.Context(), roomID, device.UserID, rsAPI)
if err != nil { if err != nil {
@ -132,7 +139,7 @@ func OnIncomingMessagesRequest(
} }
} else { } else {
fromStream = &streamToken fromStream = &streamToken
from, err = db.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, backwardOrdering) from, err = snapshot.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, backwardOrdering)
if err != nil { if err != nil {
logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken) logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken)
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -154,7 +161,7 @@ func OnIncomingMessagesRequest(
JSON: jsonerror.InvalidArgumentValue("Invalid to parameter: " + err.Error()), JSON: jsonerror.InvalidArgumentValue("Invalid to parameter: " + err.Error()),
} }
} else { } else {
to, err = db.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, !backwardOrdering) to, err = snapshot.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, !backwardOrdering)
if err != nil { if err != nil {
logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken) logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken)
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -165,7 +172,7 @@ func OnIncomingMessagesRequest(
// If "to" isn't provided, it defaults to either the earliest stream // If "to" isn't provided, it defaults to either the earliest stream
// position (if we're going backward) or to the latest one (if we're // position (if we're going backward) or to the latest one (if we're
// going forward). // going forward).
to, err = setToDefault(req.Context(), db, backwardOrdering, roomID) to, err = setToDefault(req.Context(), snapshot, backwardOrdering, roomID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("setToDefault failed") util.GetLogger(req.Context()).WithError(err).Error("setToDefault failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -186,6 +193,7 @@ func OnIncomingMessagesRequest(
mReq := messagesReq{ mReq := messagesReq{
ctx: req.Context(), ctx: req.Context(),
db: db, db: db,
snapshot: snapshot,
rsAPI: rsAPI, rsAPI: rsAPI,
cfg: cfg, cfg: cfg,
roomID: roomID, roomID: roomID,
@ -217,7 +225,7 @@ func OnIncomingMessagesRequest(
Start: start.String(), Start: start.String(),
End: end.String(), End: end.String(),
} }
res.applyLazyLoadMembers(req.Context(), db, roomID, device, filter.LazyLoadMembers, lazyLoadCache) res.applyLazyLoadMembers(req.Context(), snapshot, roomID, device, filter.LazyLoadMembers, lazyLoadCache)
// If we didn't return any events, set the end to an empty string, so it will be omitted // If we didn't return any events, set the end to an empty string, so it will be omitted
// in the response JSON. // in the response JSON.
@ -239,7 +247,7 @@ func OnIncomingMessagesRequest(
// LazyLoadMembers enabled. // LazyLoadMembers enabled.
func (m *messagesResp) applyLazyLoadMembers( func (m *messagesResp) applyLazyLoadMembers(
ctx context.Context, ctx context.Context,
db storage.Database, db storage.DatabaseSnapshot,
roomID string, roomID string,
device *userapi.Device, device *userapi.Device,
lazyLoad bool, lazyLoad bool,
@ -292,7 +300,7 @@ func (r *messagesReq) retrieveEvents() (
end types.TopologyToken, err error, end types.TopologyToken, err error,
) { ) {
// Retrieve the events from the local database. // Retrieve the events from the local database.
streamEvents, err := r.db.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering) streamEvents, err := r.snapshot.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering)
if err != nil { if err != nil {
err = fmt.Errorf("GetEventsInRange: %w", err) err = fmt.Errorf("GetEventsInRange: %w", err)
return return
@ -348,7 +356,7 @@ func (r *messagesReq) retrieveEvents() (
// Apply room history visibility filter // Apply room history visibility filter
startTime := time.Now() startTime := time.Now()
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.db, r.rsAPI, events, nil, r.device.UserID, "messages") filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.snapshot, r.rsAPI, events, nil, r.device.UserID, "messages")
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"duration": time.Since(startTime), "duration": time.Since(startTime),
"room_id": r.roomID, "room_id": r.roomID,
@ -366,7 +374,7 @@ func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (st
// else to go. This seems to fix Element iOS from looping on /messages endlessly. // else to go. This seems to fix Element iOS from looping on /messages endlessly.
end = types.TopologyToken{} end = types.TopologyToken{}
} else { } else {
end, err = r.db.EventPositionInTopology( end, err = r.snapshot.EventPositionInTopology(
r.ctx, events[0].EventID(), r.ctx, events[0].EventID(),
) )
// A stream/topological position is a cursor located between two events. // A stream/topological position is a cursor located between two events.
@ -378,7 +386,7 @@ func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (st
} }
} else { } else {
start = *r.from start = *r.from
end, err = r.db.EventPositionInTopology( end, err = r.snapshot.EventPositionInTopology(
r.ctx, events[len(events)-1].EventID(), r.ctx, events[len(events)-1].EventID(),
) )
} }
@ -399,7 +407,7 @@ func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (st
func (r *messagesReq) handleEmptyEventsSlice() ( func (r *messagesReq) handleEmptyEventsSlice() (
events []*gomatrixserverlib.HeaderedEvent, err error, events []*gomatrixserverlib.HeaderedEvent, err error,
) { ) {
backwardExtremities, err := r.db.BackwardExtremitiesForRoom(r.ctx, r.roomID) backwardExtremities, err := r.snapshot.BackwardExtremitiesForRoom(r.ctx, r.roomID)
// Check if we have backward extremities for this room. // Check if we have backward extremities for this room.
if len(backwardExtremities) > 0 { if len(backwardExtremities) > 0 {
@ -443,7 +451,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
} }
// Check if the slice contains a backward extremity. // Check if the slice contains a backward extremity.
backwardExtremities, err := r.db.BackwardExtremitiesForRoom(r.ctx, r.roomID) backwardExtremities, err := r.snapshot.BackwardExtremitiesForRoom(r.ctx, r.roomID)
if err != nil { if err != nil {
return return
} }
@ -463,7 +471,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
} }
// Append the events ve previously retrieved locally. // Append the events ve previously retrieved locally.
events = append(events, r.db.StreamEventsToEvents(nil, streamEvents)...) events = append(events, r.snapshot.StreamEventsToEvents(nil, streamEvents)...)
sort.Sort(eventsByDepth(events)) sort.Sort(eventsByDepth(events))
return return
@ -553,7 +561,7 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]
// Returns an error if there was an issue with retrieving the latest position // Returns an error if there was an issue with retrieving the latest position
// from the database // from the database
func setToDefault( func setToDefault(
ctx context.Context, db storage.Database, backwardOrdering bool, ctx context.Context, snapshot storage.DatabaseSnapshot, backwardOrdering bool,
roomID string, roomID string,
) (to types.TopologyToken, err error) { ) (to types.TopologyToken, err error) {
if backwardOrdering { if backwardOrdering {
@ -561,7 +569,7 @@ func setToDefault(
// this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound. // this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound.
to = types.TopologyToken{} to = types.TopologyToken{}
} else { } else {
to, err = db.MaxTopologicalPosition(ctx, roomID) to, err = snapshot.MaxTopologicalPosition(ctx, roomID)
} }
return return

View file

@ -61,8 +61,14 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
searchReq.SearchCategories.RoomEvents.Filter.Limit = 5 searchReq.SearchCategories.RoomEvents.Filter.Limit = 5
} }
snapshot, err := syncDB.NewDatabaseSnapshot(req.Context())
if err != nil {
return jsonerror.InternalServerError()
}
defer snapshot.Rollback() // nolint:err
// only search rooms the user is actually joined to // only search rooms the user is actually joined to
joinedRooms, err := syncDB.RoomIDsWithMembership(ctx, device.UserID, "join") joinedRooms, err := snapshot.RoomIDsWithMembership(ctx, device.UserID, "join")
if err != nil { if err != nil {
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -161,12 +167,12 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
stateForRooms := make(map[string][]gomatrixserverlib.ClientEvent) stateForRooms := make(map[string][]gomatrixserverlib.ClientEvent)
for _, event := range evs { for _, event := range evs {
eventsBefore, eventsAfter, err := contextEvents(ctx, syncDB, event, roomFilter, searchReq) eventsBefore, eventsAfter, err := contextEvents(ctx, snapshot, event, roomFilter, searchReq)
if err != nil { if err != nil {
logrus.WithError(err).Error("failed to get context events") logrus.WithError(err).Error("failed to get context events")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
startToken, endToken, err := getStartEnd(ctx, syncDB, eventsBefore, eventsAfter) startToken, endToken, err := getStartEnd(ctx, snapshot, eventsBefore, eventsAfter)
if err != nil { if err != nil {
logrus.WithError(err).Error("failed to get start/end") logrus.WithError(err).Error("failed to get start/end")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -176,7 +182,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
for _, ev := range append(eventsBefore, eventsAfter...) { for _, ev := range append(eventsBefore, eventsAfter...) {
profile, ok := knownUsersProfiles[event.Sender()] profile, ok := knownUsersProfiles[event.Sender()]
if !ok { if !ok {
stateEvent, err := syncDB.GetStateEvent(ctx, ev.RoomID(), gomatrixserverlib.MRoomMember, ev.Sender()) stateEvent, err := snapshot.GetStateEvent(ctx, ev.RoomID(), gomatrixserverlib.MRoomMember, ev.Sender())
if err != nil { if err != nil {
logrus.WithError(err).WithField("user_id", event.Sender()).Warn("failed to query userprofile") logrus.WithError(err).WithField("user_id", event.Sender()).Warn("failed to query userprofile")
continue continue
@ -209,7 +215,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
groups[event.RoomID()] = roomGroup groups[event.RoomID()] = roomGroup
if _, ok := stateForRooms[event.RoomID()]; searchReq.SearchCategories.RoomEvents.IncludeState && !ok { if _, ok := stateForRooms[event.RoomID()]; searchReq.SearchCategories.RoomEvents.IncludeState && !ok {
stateFilter := gomatrixserverlib.DefaultStateFilter() stateFilter := gomatrixserverlib.DefaultStateFilter()
state, err := syncDB.CurrentState(ctx, event.RoomID(), &stateFilter, nil) state, err := snapshot.CurrentState(ctx, event.RoomID(), &stateFilter, nil)
if err != nil { if err != nil {
logrus.WithError(err).Error("unable to get current state") logrus.WithError(err).Error("unable to get current state")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -252,24 +258,24 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
// contextEvents returns the events around a given eventID // contextEvents returns the events around a given eventID
func contextEvents( func contextEvents(
ctx context.Context, ctx context.Context,
syncDB storage.Database, snapshot storage.DatabaseSnapshot,
event *gomatrixserverlib.HeaderedEvent, event *gomatrixserverlib.HeaderedEvent,
roomFilter *gomatrixserverlib.RoomEventFilter, roomFilter *gomatrixserverlib.RoomEventFilter,
searchReq SearchRequest, searchReq SearchRequest,
) ([]*gomatrixserverlib.HeaderedEvent, []*gomatrixserverlib.HeaderedEvent, error) { ) ([]*gomatrixserverlib.HeaderedEvent, []*gomatrixserverlib.HeaderedEvent, error) {
id, _, err := syncDB.SelectContextEvent(ctx, event.RoomID(), event.EventID()) id, _, err := snapshot.SelectContextEvent(ctx, event.RoomID(), event.EventID())
if err != nil { if err != nil {
logrus.WithError(err).Error("failed to query context event") logrus.WithError(err).Error("failed to query context event")
return nil, nil, err return nil, nil, err
} }
roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.BeforeLimit roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.BeforeLimit
eventsBefore, err := syncDB.SelectContextBeforeEvent(ctx, id, event.RoomID(), roomFilter) eventsBefore, err := snapshot.SelectContextBeforeEvent(ctx, id, event.RoomID(), roomFilter)
if err != nil { if err != nil {
logrus.WithError(err).Error("failed to query before context event") logrus.WithError(err).Error("failed to query before context event")
return nil, nil, err return nil, nil, err
} }
roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.AfterLimit roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.AfterLimit
_, eventsAfter, err := syncDB.SelectContextAfterEvent(ctx, id, event.RoomID(), roomFilter) _, eventsAfter, err := snapshot.SelectContextAfterEvent(ctx, id, event.RoomID(), roomFilter)
if err != nil { if err != nil {
logrus.WithError(err).Error("failed to query after context event") logrus.WithError(err).Error("failed to query after context event")
return nil, nil, err return nil, nil, err

View file

@ -17,19 +17,17 @@ package storage
import ( import (
"context" "context"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage/shared"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
) )
type Database interface { type DatabaseSnapshot interface {
Presence
SharedUsers SharedUsers
Notifications
MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error)
@ -37,6 +35,7 @@ type Database interface {
MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error)
CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error)
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
@ -44,23 +43,78 @@ type Database interface {
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error)
GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error)
RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
GetBackwardTopologyPos(ctx context.Context, events []types.StreamEvent) (types.TopologyToken, error) GetBackwardTopologyPos(ctx context.Context, events []types.StreamEvent) (types.TopologyToken, error)
PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error)
InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error)
PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error)
RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error)
// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. // AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs.
AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error)
// AllJoinedUsersInRoom returns a map of room ID to a list of all joined user IDs for a given room. // AllJoinedUsersInRoom returns a map of room ID to a list of all joined user IDs for a given room.
AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error)
// AllPeekingDevicesInRooms returns a map of room ID to a list of all peeking devices. // AllPeekingDevicesInRooms returns a map of room ID to a list of all peeking devices.
AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error)
// Events lookups a list of event by their event ID.
// Returns a list of events matching the requested IDs found in the database.
// If an event is not found in the database then it will be omitted from the list.
// Returns an error if there was a problem talking with the database.
// Does not include any transaction IDs in the returned events.
Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error)
// GetStateEvent returns the Matrix state event of a given type for a given room with a given state key
// If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
// GetStateEventsForRoom fetches the state events for a given room.
// Returns an empty slice if no state events could be found for this room.
// Returns an error if there was an issue with the retrieval.
GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error)
// GetAccountDataInRange returns all account data for a given user inserted or
// updated between two given positions
// Returns a map following the format data[roomID] = []dataTypes
// If no data is retrieved, returns an empty map
// If there was an issue with the retrieval, returns an error
GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, types.StreamPosition, error)
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last.
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error)
// EventPositionInTopology returns the depth and stream position of the given event.
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
// BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events.
BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities map[string][]string, err error)
// MaxTopologicalPosition returns the highest topological position for a given room.
MaxTopologicalPosition(ctx context.Context, roomID string) (types.TopologyToken, error)
// StreamEventsToEvents converts streamEvent to Event. If device is non-nil and
// matches the streamevent.transactionID device then the transaction ID gets
// added to the unsigned section of the output event.
StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent
// SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the
// relevant events within the given ranges for the supplied user ID and device ID.
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error)
// GetFilter looks up the filter associated with a given local user and filter ID
// and populates the target filter. Otherwise returns an error if no such filter exists
// or if there was an error talking to the database.
GetFilter(ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string) error
// GetRoomReceipts gets all receipts for a given roomID
GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error)
SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)
StreamToTopologicalPosition(ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool) (types.TopologyToken, error)
IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error)
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
// string as the membership.
SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
// getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms
GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error)
}
type Database interface {
Presence
Notifications
NewDatabaseSnapshot(ctx context.Context) (*shared.DatabaseSnapshot, error)
// Events lookups a list of event by their event ID. // Events lookups a list of event by their event ID.
// Returns a list of events matching the requested IDs found in the database. // Returns a list of events matching the requested IDs found in the database.
// If an event is not found in the database then it will be omitted from the list. // If an event is not found in the database then it will be omitted from the list.
@ -77,20 +131,6 @@ type Database interface {
// PurgeRoomState completely purges room state from the sync API. This is done when // PurgeRoomState completely purges room state from the sync API. This is done when
// receiving an output event that completely resets the state. // receiving an output event that completely resets the state.
PurgeRoomState(ctx context.Context, roomID string) error PurgeRoomState(ctx context.Context, roomID string) error
// GetStateEvent returns the Matrix state event of a given type for a given room with a given state key
// If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
// GetStateEventsForRoom fetches the state events for a given room.
// Returns an empty slice if no state events could be found for this room.
// Returns an error if there was an issue with the retrieval.
GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error)
// GetAccountDataInRange returns all account data for a given user inserted or
// updated between two given positions
// Returns a map following the format data[roomID] = []dataTypes
// If no data is retrieved, returns an empty map
// If there was an issue with the retrieval, returns an error
GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, types.StreamPosition, error)
// UpsertAccountData keeps track of new or updated account data, by saving the type // UpsertAccountData keeps track of new or updated account data, by saving the type
// of the new/updated data, and the user ID and room ID the data is related to (empty) // of the new/updated data, and the user ID and room ID the data is related to (empty)
// room ID means the data isn't specific to any room) // room ID means the data isn't specific to any room)
@ -114,30 +154,11 @@ type Database interface {
// DeletePeek deletes all peeks for a given room by a given user // DeletePeek deletes all peeks for a given room by a given user
// Returns an error if there was a problem communicating with the database. // Returns an error if there was a problem communicating with the database.
DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error) DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last.
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error)
// EventPositionInTopology returns the depth and stream position of the given event.
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
// BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events.
BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities map[string][]string, err error)
// MaxTopologicalPosition returns the highest topological position for a given room.
MaxTopologicalPosition(ctx context.Context, roomID string) (types.TopologyToken, error)
// StreamEventsToEvents converts streamEvent to Event. If device is non-nil and
// matches the streamevent.transactionID device then the transaction ID gets
// added to the unsigned section of the output event.
StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent
// SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the
// relevant events within the given ranges for the supplied user ID and device ID.
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error)
// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device. // StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error) StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
// CleanSendToDeviceUpdates removes all send-to-device messages BEFORE the specified // CleanSendToDeviceUpdates removes all send-to-device messages BEFORE the specified
// from position, preventing the send-to-device table from growing indefinitely. // from position, preventing the send-to-device table from growing indefinitely.
CleanSendToDeviceUpdates(ctx context.Context, userID, deviceID string, before types.StreamPosition) (err error) CleanSendToDeviceUpdates(ctx context.Context, userID, deviceID string, before types.StreamPosition) (err error)
// GetFilter looks up the filter associated with a given local user and filter ID
// and populates the target filter. Otherwise returns an error if no such filter exists
// or if there was an error talking to the database.
GetFilter(ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string) error
// PutFilter puts the passed filter into the database. // PutFilter puts the passed filter into the database.
// Returns the filterID as a string. Otherwise returns an error if something // Returns the filterID as a string. Otherwise returns an error if something
// goes wrong. // goes wrong.
@ -146,21 +167,7 @@ type Database interface {
RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error
// StoreReceipt stores new receipt events // StoreReceipt stores new receipt events
StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error)
// GetRoomReceipts gets all receipts for a given roomID
GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error)
SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)
StreamToTopologicalPosition(ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool) (types.TopologyToken, error)
IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error)
UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
// string as the membership.
SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error)
} }
@ -168,7 +175,6 @@ type Presence interface {
UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error)
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error)
} }
type SharedUsers interface { type SharedUsers interface {
@ -179,7 +185,4 @@ type SharedUsers interface {
type Notifications interface { type Notifications interface {
// UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key. // UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key.
UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
// getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms
GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error)
} }

View file

@ -55,8 +55,13 @@ type Database struct {
Presence tables.Presence Presence tables.Presence
} }
func (d *Database) readOnlySnapshot(ctx context.Context) (*sql.Tx, error) { type DatabaseSnapshot struct {
return d.DB.BeginTx(ctx, &sql.TxOptions{ *Database
*sql.Tx
}
func (d *Database) NewDatabaseSnapshot(ctx context.Context) (*DatabaseSnapshot, error) {
txn, err := d.DB.BeginTx(ctx, &sql.TxOptions{
// Set the isolation level so that we see a snapshot of the database. // Set the isolation level so that we see a snapshot of the database.
// In PostgreSQL repeatable read transactions will see a snapshot taken // In PostgreSQL repeatable read transactions will see a snapshot taken
// at the first query, and since the transaction is read-only it can't // at the first query, and since the transaction is read-only it can't
@ -65,9 +70,16 @@ func (d *Database) readOnlySnapshot(ctx context.Context) (*sql.Tx, error) {
Isolation: sql.LevelRepeatableRead, Isolation: sql.LevelRepeatableRead,
ReadOnly: true, ReadOnly: true,
}) })
if err != nil {
return nil, err
}
return &DatabaseSnapshot{
Database: d,
Tx: txn,
}, nil
} }
func (d *Database) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) { func (d *DatabaseSnapshot) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) {
id, err := d.OutputEvents.SelectMaxEventID(ctx, nil) id, err := d.OutputEvents.SelectMaxEventID(ctx, nil)
if err != nil { if err != nil {
return 0, fmt.Errorf("d.OutputEvents.SelectMaxEventID: %w", err) return 0, fmt.Errorf("d.OutputEvents.SelectMaxEventID: %w", err)
@ -75,7 +87,7 @@ func (d *Database) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPo
return types.StreamPosition(id), nil return types.StreamPosition(id), nil
} }
func (d *Database) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) { func (d *DatabaseSnapshot) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) {
id, err := d.Receipts.SelectMaxReceiptID(ctx, nil) id, err := d.Receipts.SelectMaxReceiptID(ctx, nil)
if err != nil { if err != nil {
return 0, fmt.Errorf("d.Receipts.SelectMaxReceiptID: %w", err) return 0, fmt.Errorf("d.Receipts.SelectMaxReceiptID: %w", err)
@ -83,7 +95,7 @@ func (d *Database) MaxStreamPositionForReceipts(ctx context.Context) (types.Stre
return types.StreamPosition(id), nil return types.StreamPosition(id), nil
} }
func (d *Database) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) { func (d *DatabaseSnapshot) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) {
id, err := d.Invites.SelectMaxInviteID(ctx, nil) id, err := d.Invites.SelectMaxInviteID(ctx, nil)
if err != nil { if err != nil {
return 0, fmt.Errorf("d.Invites.SelectMaxInviteID: %w", err) return 0, fmt.Errorf("d.Invites.SelectMaxInviteID: %w", err)
@ -91,7 +103,7 @@ func (d *Database) MaxStreamPositionForInvites(ctx context.Context) (types.Strea
return types.StreamPosition(id), nil return types.StreamPosition(id), nil
} }
func (d *Database) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) { func (d *DatabaseSnapshot) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) {
id, err := d.SendToDevice.SelectMaxSendToDeviceMessageID(ctx, nil) id, err := d.SendToDevice.SelectMaxSendToDeviceMessageID(ctx, nil)
if err != nil { if err != nil {
return 0, fmt.Errorf("d.SendToDevice.SelectMaxSendToDeviceMessageID: %w", err) return 0, fmt.Errorf("d.SendToDevice.SelectMaxSendToDeviceMessageID: %w", err)
@ -99,7 +111,7 @@ func (d *Database) MaxStreamPositionForSendToDeviceMessages(ctx context.Context)
return types.StreamPosition(id), nil return types.StreamPosition(id), nil
} }
func (d *Database) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) { func (d *DatabaseSnapshot) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) {
id, err := d.AccountData.SelectMaxAccountDataID(ctx, nil) id, err := d.AccountData.SelectMaxAccountDataID(ctx, nil)
if err != nil { if err != nil {
return 0, fmt.Errorf("d.Invites.SelectMaxAccountDataID: %w", err) return 0, fmt.Errorf("d.Invites.SelectMaxAccountDataID: %w", err)
@ -107,7 +119,7 @@ func (d *Database) MaxStreamPositionForAccountData(ctx context.Context) (types.S
return types.StreamPosition(id), nil return types.StreamPosition(id), nil
} }
func (d *Database) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) { func (d *DatabaseSnapshot) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) {
id, err := d.NotificationData.SelectMaxID(ctx, nil) id, err := d.NotificationData.SelectMaxID(ctx, nil)
if err != nil { if err != nil {
return 0, fmt.Errorf("d.NotificationData.SelectMaxID: %w", err) return 0, fmt.Errorf("d.NotificationData.SelectMaxID: %w", err)
@ -115,40 +127,40 @@ func (d *Database) MaxStreamPositionForNotificationData(ctx context.Context) (ty
return types.StreamPosition(id), nil return types.StreamPosition(id), nil
} }
func (d *Database) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { func (d *DatabaseSnapshot) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
return d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilterPart, excludeEventIDs) return d.CurrentRoomState.SelectCurrentState(ctx, d.Tx, roomID, stateFilterPart, excludeEventIDs)
} }
func (d *Database) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) { func (d *DatabaseSnapshot) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) {
return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership) return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, d.Tx, userID, membership)
} }
func (d *Database) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) { func (d *DatabaseSnapshot) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) {
return d.Memberships.SelectMembershipCount(ctx, nil, roomID, membership, pos) return d.Memberships.SelectMembershipCount(ctx, d.Tx, roomID, membership, pos)
} }
func (d *Database) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) { func (d *DatabaseSnapshot) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) {
return d.Memberships.SelectHeroes(ctx, nil, roomID, userID, memberships) return d.Memberships.SelectHeroes(ctx, d.Tx, roomID, userID, memberships)
} }
func (d *Database) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) { func (d *DatabaseSnapshot) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) {
return d.OutputEvents.SelectRecentEvents(ctx, nil, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents) return d.OutputEvents.SelectRecentEvents(ctx, d.Tx, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents)
} }
func (d *Database) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) { func (d *DatabaseSnapshot) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) {
return d.Topology.SelectPositionInTopology(ctx, nil, eventID) return d.Topology.SelectPositionInTopology(ctx, d.Tx, eventID)
} }
func (d *Database) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) { func (d *DatabaseSnapshot) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) {
return d.Invites.SelectInviteEventsInRange(ctx, nil, targetUserID, r) return d.Invites.SelectInviteEventsInRange(ctx, d.Tx, targetUserID, r)
} }
func (d *Database) PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) { func (d *DatabaseSnapshot) PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) {
return d.Peeks.SelectPeeksInRange(ctx, nil, userID, deviceID, r) return d.Peeks.SelectPeeksInRange(ctx, d.Tx, userID, deviceID, r)
} }
func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) { func (d *DatabaseSnapshot) RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) {
return d.Receipts.SelectRoomReceiptsAfter(ctx, nil, roomIDs, streamPos) return d.Receipts.SelectRoomReceiptsAfter(ctx, d.Tx, roomIDs, streamPos)
} }
// Events lookups a list of event by their event ID. // Events lookups a list of event by their event ID.
@ -156,6 +168,17 @@ func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, stre
// If an event is not found in the database then it will be omitted from the list. // If an event is not found in the database then it will be omitted from the list.
// Returns an error if there was a problem talking with the database. // Returns an error if there was a problem talking with the database.
// Does not include any transaction IDs in the returned events. // Does not include any transaction IDs in the returned events.
func (d *DatabaseSnapshot) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
streamEvents, err := d.OutputEvents.SelectEvents(ctx, d.Tx, eventIDs, nil, false)
if err != nil {
return nil, err
}
// We don't include a device here as we only include transaction IDs in
// incremental syncs.
return d.StreamEventsToEvents(nil, streamEvents), nil
}
func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, nil, false) streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, nil, false)
if err != nil { if err != nil {
@ -167,32 +190,32 @@ func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixse
return d.StreamEventsToEvents(nil, streamEvents), nil return d.StreamEventsToEvents(nil, streamEvents), nil
} }
func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { func (d *DatabaseSnapshot) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
return d.CurrentRoomState.SelectJoinedUsers(ctx, nil) return d.CurrentRoomState.SelectJoinedUsers(ctx, d.Tx)
} }
func (d *Database) AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) { func (d *DatabaseSnapshot) AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) {
return d.CurrentRoomState.SelectJoinedUsersInRoom(ctx, nil, roomIDs) return d.CurrentRoomState.SelectJoinedUsersInRoom(ctx, d.Tx, roomIDs)
} }
func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) { func (d *DatabaseSnapshot) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) {
return d.Peeks.SelectPeekingDevices(ctx, nil) return d.Peeks.SelectPeekingDevices(ctx, d.Tx)
} }
func (d *Database) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) { func (d *DatabaseSnapshot) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) {
return d.CurrentRoomState.SelectSharedUsers(ctx, nil, userID, otherUserIDs) return d.CurrentRoomState.SelectSharedUsers(ctx, d.Tx, userID, otherUserIDs)
} }
func (d *Database) GetStateEvent( func (d *DatabaseSnapshot) GetStateEvent(
ctx context.Context, roomID, evType, stateKey string, ctx context.Context, roomID, evType, stateKey string,
) (*gomatrixserverlib.HeaderedEvent, error) { ) (*gomatrixserverlib.HeaderedEvent, error) {
return d.CurrentRoomState.SelectStateEvent(ctx, nil, roomID, evType, stateKey) return d.CurrentRoomState.SelectStateEvent(ctx, d.Tx, roomID, evType, stateKey)
} }
func (d *Database) GetStateEventsForRoom( func (d *DatabaseSnapshot) GetStateEventsForRoom(
ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter, ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter,
) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) { ) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) {
stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilter, nil) stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, d.Tx, roomID, stateFilter, nil)
return return
} }
@ -273,11 +296,11 @@ func (d *Database) DeletePeeks(
// Returns a map following the format data[roomID] = []dataTypes // Returns a map following the format data[roomID] = []dataTypes
// If no data is retrieved, returns an empty map // If no data is retrieved, returns an empty map
// If there was an issue with the retrieval, returns an error // If there was an issue with the retrieval, returns an error
func (d *Database) GetAccountDataInRange( func (d *DatabaseSnapshot) GetAccountDataInRange(
ctx context.Context, userID string, r types.Range, ctx context.Context, userID string, r types.Range,
accountDataFilterPart *gomatrixserverlib.EventFilter, accountDataFilterPart *gomatrixserverlib.EventFilter,
) (map[string][]string, types.StreamPosition, error) { ) (map[string][]string, types.StreamPosition, error) {
return d.AccountData.SelectAccountDataInRange(ctx, nil, userID, r, accountDataFilterPart) return d.AccountData.SelectAccountDataInRange(ctx, d.Tx, userID, r, accountDataFilterPart)
} }
// UpsertAccountData keeps track of new or updated account data, by saving the type // UpsertAccountData keeps track of new or updated account data, by saving the type
@ -445,7 +468,7 @@ func (d *Database) updateRoomState(
return nil return nil
} }
func (d *Database) GetEventsInTopologicalRange( func (d *DatabaseSnapshot) GetEventsInTopologicalRange(
ctx context.Context, ctx context.Context,
from, to *types.TopologyToken, from, to *types.TopologyToken,
roomID string, roomID string,
@ -470,52 +493,52 @@ func (d *Database) GetEventsInTopologicalRange(
// Select the event IDs from the defined range. // Select the event IDs from the defined range.
var eIDs []string var eIDs []string
eIDs, err = d.Topology.SelectEventIDsInRange( eIDs, err = d.Topology.SelectEventIDsInRange(
ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering, ctx, d.Tx, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering,
) )
if err != nil { if err != nil {
return return
} }
// Retrieve the events' contents using their IDs. // Retrieve the events' contents using their IDs.
events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, filter, true) events, err = d.OutputEvents.SelectEvents(ctx, d.Tx, eIDs, filter, true)
return return
} }
func (d *Database) BackwardExtremitiesForRoom( func (d *DatabaseSnapshot) BackwardExtremitiesForRoom(
ctx context.Context, roomID string, ctx context.Context, roomID string,
) (backwardExtremities map[string][]string, err error) { ) (backwardExtremities map[string][]string, err error) {
return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, nil, roomID) return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, d.Tx, roomID)
} }
func (d *Database) MaxTopologicalPosition( func (d *DatabaseSnapshot) MaxTopologicalPosition(
ctx context.Context, roomID string, ctx context.Context, roomID string,
) (types.TopologyToken, error) { ) (types.TopologyToken, error) {
depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, nil, roomID) depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, d.Tx, roomID)
if err != nil { if err != nil {
return types.TopologyToken{}, err return types.TopologyToken{}, err
} }
return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil
} }
func (d *Database) EventPositionInTopology( func (d *DatabaseSnapshot) EventPositionInTopology(
ctx context.Context, eventID string, ctx context.Context, eventID string,
) (types.TopologyToken, error) { ) (types.TopologyToken, error) {
depth, stream, err := d.Topology.SelectPositionInTopology(ctx, nil, eventID) depth, stream, err := d.Topology.SelectPositionInTopology(ctx, d.Tx, eventID)
if err != nil { if err != nil {
return types.TopologyToken{}, err return types.TopologyToken{}, err
} }
return types.TopologyToken{Depth: depth, PDUPosition: stream}, nil return types.TopologyToken{Depth: depth, PDUPosition: stream}, nil
} }
func (d *Database) StreamToTopologicalPosition( func (d *DatabaseSnapshot) StreamToTopologicalPosition(
ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool, ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool,
) (types.TopologyToken, error) { ) (types.TopologyToken, error) {
topoPos, err := d.Topology.SelectStreamToTopologicalPosition(ctx, nil, roomID, streamPos, backwardOrdering) topoPos, err := d.Topology.SelectStreamToTopologicalPosition(ctx, d.Tx, roomID, streamPos, backwardOrdering)
switch { switch {
case err == sql.ErrNoRows && backwardOrdering: // no events in range, going backward case err == sql.ErrNoRows && backwardOrdering: // no events in range, going backward
return types.TopologyToken{PDUPosition: streamPos}, nil return types.TopologyToken{PDUPosition: streamPos}, nil
case err == sql.ErrNoRows && !backwardOrdering: // no events in range, going forward case err == sql.ErrNoRows && !backwardOrdering: // no events in range, going forward
topoPos, streamPos, err = d.Topology.SelectMaxPositionInTopology(ctx, nil, roomID) topoPos, streamPos, err = d.Topology.SelectMaxPositionInTopology(ctx, d.Tx, roomID)
if err != nil { if err != nil {
return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectMaxPositionInTopology: %w", err) return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectMaxPositionInTopology: %w", err)
} }
@ -527,10 +550,10 @@ func (d *Database) StreamToTopologicalPosition(
} }
} }
func (d *Database) GetFilter( func (d *DatabaseSnapshot) GetFilter(
ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string, ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string,
) error { ) error {
return d.Filter.SelectFilter(ctx, nil, target, localpart, filterID) return d.Filter.SelectFilter(ctx, d.Tx, target, localpart, filterID)
} }
func (d *Database) PutFilter( func (d *Database) PutFilter(
@ -569,7 +592,7 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda
// GetBackwardTopologyPos retrieves the backward topology position, i.e. the position of the // GetBackwardTopologyPos retrieves the backward topology position, i.e. the position of the
// oldest event in the room's topology. // oldest event in the room's topology.
func (d *Database) GetBackwardTopologyPos( func (d *DatabaseSnapshot) GetBackwardTopologyPos(
ctx context.Context, ctx context.Context,
events []types.StreamEvent, events []types.StreamEvent,
) (types.TopologyToken, error) { ) (types.TopologyToken, error) {
@ -577,7 +600,7 @@ func (d *Database) GetBackwardTopologyPos(
if len(events) == 0 { if len(events) == 0 {
return zeroToken, nil return zeroToken, nil
} }
pos, spos, err := d.Topology.SelectPositionInTopology(ctx, nil, events[0].EventID()) pos, spos, err := d.Topology.SelectPositionInTopology(ctx, d.Tx, events[0].EventID())
if err != nil { if err != nil {
return zeroToken, err return zeroToken, err
} }
@ -682,7 +705,7 @@ func (d *Database) fetchMissingStateEvents(
// exclusive of oldPos, inclusive of newPos, for the rooms in which // exclusive of oldPos, inclusive of newPos, for the rooms in which
// the user has new membership events. // the user has new membership events.
// A list of joined room IDs is also returned in case the caller needs it. // A list of joined room IDs is also returned in case the caller needs it.
func (d *Database) GetStateDeltas( func (d *DatabaseSnapshot) GetStateDeltas(
ctx context.Context, device *userapi.Device, ctx context.Context, device *userapi.Device,
r types.Range, userID string, r types.Range, userID string,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
@ -695,16 +718,10 @@ func (d *Database) GetStateDeltas(
// * Check if user is still CURRENTLY invited to the room. If so, add room to 'invited' block. // * Check if user is still CURRENTLY invited to the room. If so, add room to 'invited' block.
// * Check if the user is CURRENTLY (TODO) left/banned. If so, add room to 'archived' block. // * Check if the user is CURRENTLY (TODO) left/banned. If so, add room to 'archived' block.
// - Get all CURRENTLY joined rooms, and add them to 'joined' block. // - Get all CURRENTLY joined rooms, and add them to 'joined' block.
txn, err := d.readOnlySnapshot(ctx)
if err != nil {
return nil, nil, fmt.Errorf("d.readOnlySnapshot: %w", err)
}
var succeeded bool
defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err)
// Look up all memberships for the user. We only care about rooms that a // Look up all memberships for the user. We only care about rooms that a
// user has ever interacted with — joined to, kicked/banned from, left. // user has ever interacted with — joined to, kicked/banned from, left.
memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID) memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, d.Tx, userID)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil, nil return nil, nil, nil
@ -722,14 +739,14 @@ func (d *Database) GetStateDeltas(
} }
// get all the state events ever (i.e. for all available rooms) between these two positions // get all the state events ever (i.e. for all available rooms) between these two positions
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs) stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.Tx, r, stateFilter, allRoomIDs)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil, nil return nil, nil, nil
} }
return nil, nil, err return nil, nil, err
} }
state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) state, err := d.fetchStateEvents(ctx, d.Tx, stateNeeded, eventMap)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil, nil return nil, nil, nil
@ -739,7 +756,7 @@ func (d *Database) GetStateDeltas(
// find out which rooms this user is peeking, if any. // find out which rooms this user is peeking, if any.
// We do this before joins so any peeks get overwritten // We do this before joins so any peeks get overwritten
peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r) peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.Tx, userID, device.ID, r)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return nil, nil, err return nil, nil, err
} }
@ -749,7 +766,7 @@ func (d *Database) GetStateDeltas(
if peek.New { if peek.New {
// send full room state down instead of a delta // send full room state down instead of a delta
var s []types.StreamEvent var s []types.StreamEvent
s, err = d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter) s, err = d.currentStateStreamEventsForRoom(ctx, peek.RoomID, stateFilter)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
continue continue
@ -775,7 +792,7 @@ func (d *Database) GetStateDeltas(
if membership == gomatrixserverlib.Join && prevMembership != membership { if membership == gomatrixserverlib.Join && prevMembership != membership {
// send full room state down instead of a delta // send full room state down instead of a delta
var s []types.StreamEvent var s []types.StreamEvent
s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter) s, err = d.currentStateStreamEventsForRoom(ctx, roomID, stateFilter)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
continue continue
@ -808,7 +825,6 @@ func (d *Database) GetStateDeltas(
}) })
} }
succeeded = true
return deltas, joinedRoomIDs, nil return deltas, joinedRoomIDs, nil
} }
@ -816,21 +832,14 @@ func (d *Database) GetStateDeltas(
// requests with full_state=true. // requests with full_state=true.
// Fetches full state for all joined rooms and uses selectStateInRange to get // Fetches full state for all joined rooms and uses selectStateInRange to get
// updates for other rooms. // updates for other rooms.
func (d *Database) GetStateDeltasForFullStateSync( func (d *DatabaseSnapshot) GetStateDeltasForFullStateSync(
ctx context.Context, device *userapi.Device, ctx context.Context, device *userapi.Device,
r types.Range, userID string, r types.Range, userID string,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
) ([]types.StateDelta, []string, error) { ) ([]types.StateDelta, []string, error) {
txn, err := d.readOnlySnapshot(ctx)
if err != nil {
return nil, nil, fmt.Errorf("d.readOnlySnapshot: %w", err)
}
var succeeded bool
defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err)
// Look up all memberships for the user. We only care about rooms that a // Look up all memberships for the user. We only care about rooms that a
// user has ever interacted with — joined to, kicked/banned from, left. // user has ever interacted with — joined to, kicked/banned from, left.
memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID) memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, d.Tx, userID)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil, nil return nil, nil, nil
@ -850,7 +859,7 @@ func (d *Database) GetStateDeltasForFullStateSync(
// Use a reasonable initial capacity // Use a reasonable initial capacity
deltas := make(map[string]types.StateDelta) deltas := make(map[string]types.StateDelta)
peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r) peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.Tx, userID, device.ID, r)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return nil, nil, err return nil, nil, err
} }
@ -858,7 +867,7 @@ func (d *Database) GetStateDeltasForFullStateSync(
// Add full states for all peeking rooms // Add full states for all peeking rooms
for _, peek := range peeks { for _, peek := range peeks {
if !peek.Deleted { if !peek.Deleted {
s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter) s, stateErr := d.currentStateStreamEventsForRoom(ctx, peek.RoomID, stateFilter)
if stateErr != nil { if stateErr != nil {
if stateErr == sql.ErrNoRows { if stateErr == sql.ErrNoRows {
continue continue
@ -874,14 +883,14 @@ func (d *Database) GetStateDeltasForFullStateSync(
} }
// Get all the state events ever between these two positions // Get all the state events ever between these two positions
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs) stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.Tx, r, stateFilter, allRoomIDs)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil, nil return nil, nil, nil
} }
return nil, nil, err return nil, nil, err
} }
state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) state, err := d.fetchStateEvents(ctx, d.Tx, stateNeeded, eventMap)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil, nil return nil, nil, nil
@ -908,7 +917,7 @@ func (d *Database) GetStateDeltasForFullStateSync(
// Add full states for all joined rooms // Add full states for all joined rooms
for _, joinedRoomID := range joinedRoomIDs { for _, joinedRoomID := range joinedRoomIDs {
s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilter) s, stateErr := d.currentStateStreamEventsForRoom(ctx, joinedRoomID, stateFilter)
if stateErr != nil { if stateErr != nil {
if stateErr == sql.ErrNoRows { if stateErr == sql.ErrNoRows {
continue continue
@ -930,15 +939,14 @@ func (d *Database) GetStateDeltasForFullStateSync(
i++ i++
} }
succeeded = true
return result, joinedRoomIDs, nil return result, joinedRoomIDs, nil
} }
func (d *Database) currentStateStreamEventsForRoom( func (d *DatabaseSnapshot) currentStateStreamEventsForRoom(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, roomID string,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
allState, err := d.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, stateFilter, nil) allState, err := d.CurrentRoomState.SelectCurrentState(ctx, d.Tx, roomID, stateFilter, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -970,13 +978,13 @@ func (d *Database) StoreNewSendForDeviceMessage(
return newPos, nil return newPos, nil
} }
func (d *Database) SendToDeviceUpdatesForSync( func (d *DatabaseSnapshot) SendToDeviceUpdatesForSync(
ctx context.Context, ctx context.Context,
userID, deviceID string, userID, deviceID string,
from, to types.StreamPosition, from, to types.StreamPosition,
) (types.StreamPosition, []types.SendToDeviceEvent, error) { ) (types.StreamPosition, []types.SendToDeviceEvent, error) {
// First of all, get our send-to-device updates for this user. // First of all, get our send-to-device updates for this user.
lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID, from, to) lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, d.Tx, userID, deviceID, from, to)
if err != nil { if err != nil {
return from, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err) return from, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
} }
@ -1023,8 +1031,8 @@ func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId
return return
} }
func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) { func (d *DatabaseSnapshot) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) {
_, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, nil, roomIDs, streamPos) _, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, d.Tx, roomIDs, streamPos)
return receipts, err return receipts, err
} }
@ -1036,7 +1044,7 @@ func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userI
return return
} }
func (d *Database) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, rooms map[string]string) (map[string]*eventutil.NotificationData, error) { func (d *DatabaseSnapshot) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, rooms map[string]string) (map[string]*eventutil.NotificationData, error) {
roomIDs := make([]string, 0, len(rooms)) roomIDs := make([]string, 0, len(rooms))
for roomID, membership := range rooms { for roomID, membership := range rooms {
if membership != gomatrixserverlib.Join { if membership != gomatrixserverlib.Join {
@ -1044,7 +1052,7 @@ func (d *Database) GetUserUnreadNotificationCountsForRooms(ctx context.Context,
} }
roomIDs = append(roomIDs, roomID) roomIDs = append(roomIDs, roomID)
} }
return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, nil, userID, roomIDs) return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, d.Tx, userID, roomIDs)
} }
func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) {

View file

@ -79,7 +79,13 @@ func TestRecentEventsPDU(t *testing.T) {
// dummy room to make sure SQL queries are filtering on room ID // dummy room to make sure SQL queries are filtering on room ID
MustWriteEvents(t, db, test.NewRoom(t, alice).Events()) MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
latest, err := db.MaxStreamPositionForPDUs(ctx) snapshot, err := db.NewDatabaseSnapshot(ctx)
if err != nil {
t.Fatal(err)
}
defer snapshot.Rollback() // nolint:errcheck
latest, err := snapshot.MaxStreamPositionForPDUs(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err) t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err)
} }
@ -141,7 +147,7 @@ func TestRecentEventsPDU(t *testing.T) {
t.Run(tc.Name, func(st *testing.T) { t.Run(tc.Name, func(st *testing.T) {
var filter gomatrixserverlib.RoomEventFilter var filter gomatrixserverlib.RoomEventFilter
filter.Limit = tc.Limit filter.Limit = tc.Limit
gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{ gotEvents, limited, err := snapshot.RecentEvents(ctx, r.ID, types.Range{
From: tc.From, From: tc.From,
To: tc.To, To: tc.To,
}, &filter, !tc.ReverseOrder, true) }, &filter, !tc.ReverseOrder, true)
@ -178,7 +184,13 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
events := r.Events() events := r.Events()
_ = MustWriteEvents(t, db, events) _ = MustWriteEvents(t, db, events)
from, err := db.MaxTopologicalPosition(ctx, r.ID) snapshot, err := db.NewDatabaseSnapshot(ctx)
if err != nil {
t.Fatal(err)
}
defer snapshot.Rollback() // nolint:errcheck
from, err := snapshot.MaxTopologicalPosition(ctx, r.ID)
if err != nil { if err != nil {
t.Fatalf("failed to get MaxTopologicalPosition: %s", err) t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
} }
@ -188,11 +200,11 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
// backpaginate 5 messages starting at the latest position. // backpaginate 5 messages starting at the latest position.
filter := &gomatrixserverlib.RoomEventFilter{Limit: 5} filter := &gomatrixserverlib.RoomEventFilter{Limit: 5}
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true) paginatedEvents, err := snapshot.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
if err != nil { if err != nil {
t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err) t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
} }
gots := db.StreamEventsToEvents(nil, paginatedEvents) gots := snapshot.StreamEventsToEvents(nil, paginatedEvents)
test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:])) test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
}) })
} }
@ -414,7 +426,12 @@ func TestSendToDeviceBehaviour(t *testing.T) {
defer closeBase() defer closeBase()
// At this point there should be no messages. We haven't sent anything // At this point there should be no messages. We haven't sent anything
// yet. // yet.
_, events, err := db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100) snapshot, err := db.NewDatabaseSnapshot(ctx)
if err != nil {
t.Fatal(err)
}
defer snapshot.Rollback() // nolint:errcheck
_, events, err := snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -435,7 +452,7 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point we should get exactly one message. We're sending the sync position // At this point we should get exactly one message. We're sending the sync position
// that we were given from the update and the send-to-device update will be updated // that we were given from the update and the send-to-device update will be updated
// in the database to reflect that this was the sync position we sent the message at. // in the database to reflect that this was the sync position we sent the message at.
streamPos, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos) streamPos, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -446,7 +463,7 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point we should still have one message because we haven't progressed the // At this point we should still have one message because we haven't progressed the
// sync position yet. This is equivalent to the client failing to /sync and retrying // sync position yet. This is equivalent to the client failing to /sync and retrying
// with the same position. // with the same position.
streamPos, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos) streamPos, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -460,7 +477,7 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point we should now have no updates, because we've progressed the sync // At this point we should now have no updates, because we've progressed the sync
// position. Therefore the update from before will not be sent again. // position. Therefore the update from before will not be sent again.
_, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10) _, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -470,7 +487,7 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point we should still have no updates, because no new updates have been // At this point we should still have no updates, because no new updates have been
// sent. // sent.
_, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10) _, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -492,7 +509,7 @@ func TestSendToDeviceBehaviour(t *testing.T) {
lastPos = streamPos lastPos = streamPos
} }
_, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, lastPos) _, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, lastPos)
if err != nil { if err != nil {
t.Fatalf("unable to get events: %v", err) t.Fatalf("unable to get events: %v", err)
} }

View file

@ -5,22 +5,25 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
) )
type AccountDataStreamProvider struct { type AccountDataStreamProvider struct {
StreamProvider DefaultStreamProvider
userAPI userapi.SyncUserAPI userAPI userapi.SyncUserAPI
} }
func (p *AccountDataStreamProvider) Setup() { func (p *AccountDataStreamProvider) Setup(
p.StreamProvider.Setup() ctx context.Context, snapshot storage.DatabaseSnapshot,
) {
p.DefaultStreamProvider.Setup(ctx, snapshot)
p.latestMutex.Lock() p.latestMutex.Lock()
defer p.latestMutex.Unlock() defer p.latestMutex.Unlock()
id, err := p.DB.MaxStreamPositionForAccountData(context.Background()) id, err := snapshot.MaxStreamPositionForAccountData(context.Background())
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -29,13 +32,15 @@ func (p *AccountDataStreamProvider) Setup() {
func (p *AccountDataStreamProvider) CompleteSync( func (p *AccountDataStreamProvider) CompleteSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot,
req *types.SyncRequest, req *types.SyncRequest,
) types.StreamPosition { ) types.StreamPosition {
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
} }
func (p *AccountDataStreamProvider) IncrementalSync( func (p *AccountDataStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot,
req *types.SyncRequest, req *types.SyncRequest,
from, to types.StreamPosition, from, to types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {
@ -44,7 +49,7 @@ func (p *AccountDataStreamProvider) IncrementalSync(
To: to, To: to,
} }
dataTypes, pos, err := p.DB.GetAccountDataInRange( dataTypes, pos, err := snapshot.GetAccountDataInRange(
ctx, req.Device.UserID, r, &req.Filter.AccountData, ctx, req.Device.UserID, r, &req.Filter.AccountData,
) )
if err != nil { if err != nil {

View file

@ -6,17 +6,19 @@ import (
keyapi "github.com/matrix-org/dendrite/keyserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
) )
type DeviceListStreamProvider struct { type DeviceListStreamProvider struct {
StreamProvider DefaultStreamProvider
rsAPI api.SyncRoomserverAPI rsAPI api.SyncRoomserverAPI
keyAPI keyapi.SyncKeyAPI keyAPI keyapi.SyncKeyAPI
} }
func (p *DeviceListStreamProvider) CompleteSync( func (p *DeviceListStreamProvider) CompleteSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot,
req *types.SyncRequest, req *types.SyncRequest,
) types.StreamPosition { ) types.StreamPosition {
return p.LatestPosition(ctx) return p.LatestPosition(ctx)
@ -24,11 +26,12 @@ func (p *DeviceListStreamProvider) CompleteSync(
func (p *DeviceListStreamProvider) IncrementalSync( func (p *DeviceListStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot,
req *types.SyncRequest, req *types.SyncRequest,
from, to types.StreamPosition, from, to types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {
var err error var err error
to, _, err = internal.DeviceListCatchup(context.Background(), p.DB, p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to) to, _, err = internal.DeviceListCatchup(context.Background(), snapshot, p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to)
if err != nil { if err != nil {
req.Log.WithError(err).Error("internal.DeviceListCatchup failed") req.Log.WithError(err).Error("internal.DeviceListCatchup failed")
return from return from

View file

@ -9,20 +9,23 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
) )
type InviteStreamProvider struct { type InviteStreamProvider struct {
StreamProvider DefaultStreamProvider
} }
func (p *InviteStreamProvider) Setup() { func (p *InviteStreamProvider) Setup(
p.StreamProvider.Setup() ctx context.Context, snapshot storage.DatabaseSnapshot,
) {
p.DefaultStreamProvider.Setup(ctx, snapshot)
p.latestMutex.Lock() p.latestMutex.Lock()
defer p.latestMutex.Unlock() defer p.latestMutex.Unlock()
id, err := p.DB.MaxStreamPositionForInvites(context.Background()) id, err := snapshot.MaxStreamPositionForInvites(context.Background())
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -31,13 +34,15 @@ func (p *InviteStreamProvider) Setup() {
func (p *InviteStreamProvider) CompleteSync( func (p *InviteStreamProvider) CompleteSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot,
req *types.SyncRequest, req *types.SyncRequest,
) types.StreamPosition { ) types.StreamPosition {
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
} }
func (p *InviteStreamProvider) IncrementalSync( func (p *InviteStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot,
req *types.SyncRequest, req *types.SyncRequest,
from, to types.StreamPosition, from, to types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {
@ -46,7 +51,7 @@ func (p *InviteStreamProvider) IncrementalSync(
To: to, To: to,
} }
invites, retiredInvites, err := p.DB.InviteEventsInRange( invites, retiredInvites, err := snapshot.InviteEventsInRange(
ctx, req.Device.UserID, r, ctx, req.Device.UserID, r,
) )
if err != nil { if err != nil {

View file

@ -3,17 +3,20 @@ package streams
import ( import (
"context" "context"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
) )
type NotificationDataStreamProvider struct { type NotificationDataStreamProvider struct {
StreamProvider DefaultStreamProvider
} }
func (p *NotificationDataStreamProvider) Setup() { func (p *NotificationDataStreamProvider) Setup(
p.StreamProvider.Setup() ctx context.Context, snapshot storage.DatabaseSnapshot,
) {
p.DefaultStreamProvider.Setup(ctx, snapshot)
id, err := p.DB.MaxStreamPositionForNotificationData(context.Background()) id, err := snapshot.MaxStreamPositionForNotificationData(context.Background())
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -22,20 +25,22 @@ func (p *NotificationDataStreamProvider) Setup() {
func (p *NotificationDataStreamProvider) CompleteSync( func (p *NotificationDataStreamProvider) CompleteSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot,
req *types.SyncRequest, req *types.SyncRequest,
) types.StreamPosition { ) types.StreamPosition {
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
} }
func (p *NotificationDataStreamProvider) IncrementalSync( func (p *NotificationDataStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot,
req *types.SyncRequest, req *types.SyncRequest,
from, _ types.StreamPosition, from, _ types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {
// Get the unread notifications for rooms in our join response. // Get the unread notifications for rooms in our join response.
// This is to ensure clients always have an unread notification section // This is to ensure clients always have an unread notification section
// and can display the correct numbers. // and can display the correct numbers.
countsByRoom, err := p.DB.GetUserUnreadNotificationCountsForRooms(ctx, req.Device.UserID, req.Rooms) countsByRoom, err := snapshot.GetUserUnreadNotificationCountsForRooms(ctx, req.Device.UserID, req.Rooms)
if err != nil { if err != nil {
req.Log.WithError(err).Error("GetUserUnreadNotificationCountsForRooms failed") req.Log.WithError(err).Error("GetUserUnreadNotificationCountsForRooms failed")
return from return from

View file

@ -33,7 +33,7 @@ const PDU_STREAM_WORKERS = 256
const PDU_STREAM_QUEUESIZE = PDU_STREAM_WORKERS * 8 const PDU_STREAM_QUEUESIZE = PDU_STREAM_WORKERS * 8
type PDUStreamProvider struct { type PDUStreamProvider struct {
StreamProvider DefaultStreamProvider
tasks chan func() tasks chan func()
workers atomic.Int32 workers atomic.Int32
@ -63,14 +63,16 @@ func (p *PDUStreamProvider) queue(f func()) {
p.tasks <- f p.tasks <- f
} }
func (p *PDUStreamProvider) Setup() { func (p *PDUStreamProvider) Setup(
p.StreamProvider.Setup() ctx context.Context, snapshot storage.DatabaseSnapshot,
) {
p.DefaultStreamProvider.Setup(ctx, snapshot)
p.tasks = make(chan func(), PDU_STREAM_QUEUESIZE) p.tasks = make(chan func(), PDU_STREAM_QUEUESIZE)
p.latestMutex.Lock() p.latestMutex.Lock()
defer p.latestMutex.Unlock() defer p.latestMutex.Unlock()
id, err := p.DB.MaxStreamPositionForPDUs(context.Background()) id, err := snapshot.MaxStreamPositionForPDUs(context.Background())
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -79,6 +81,7 @@ func (p *PDUStreamProvider) Setup() {
func (p *PDUStreamProvider) CompleteSync( func (p *PDUStreamProvider) CompleteSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot,
req *types.SyncRequest, req *types.SyncRequest,
) types.StreamPosition { ) types.StreamPosition {
from := types.StreamPosition(0) from := types.StreamPosition(0)
@ -94,7 +97,7 @@ func (p *PDUStreamProvider) CompleteSync(
} }
// Extract room state and recent events for all rooms the user is joined to. // Extract room state and recent events for all rooms the user is joined to.
joinedRoomIDs, err := p.DB.RoomIDsWithMembership(ctx, req.Device.UserID, gomatrixserverlib.Join) joinedRoomIDs, err := snapshot.RoomIDsWithMembership(ctx, req.Device.UserID, gomatrixserverlib.Join)
if err != nil { if err != nil {
req.Log.WithError(err).Error("p.DB.RoomIDsWithMembership failed") req.Log.WithError(err).Error("p.DB.RoomIDsWithMembership failed")
return from return from
@ -103,7 +106,7 @@ func (p *PDUStreamProvider) CompleteSync(
stateFilter := req.Filter.Room.State stateFilter := req.Filter.Room.State
eventFilter := req.Filter.Room.Timeline eventFilter := req.Filter.Room.Timeline
if err = p.addIgnoredUsersToFilter(ctx, req, &eventFilter); err != nil { if err = p.addIgnoredUsersToFilter(ctx, snapshot, req, &eventFilter); err != nil {
req.Log.WithError(err).Error("unable to update event filter with ignored users") req.Log.WithError(err).Error("unable to update event filter with ignored users")
} }
@ -126,7 +129,7 @@ func (p *PDUStreamProvider) CompleteSync(
defer reqWaitGroup.Done() defer reqWaitGroup.Done()
jr, jerr := p.getJoinResponseForCompleteSync( jr, jerr := p.getJoinResponseForCompleteSync(
ctx, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, false, ctx, snapshot, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, false,
) )
if jerr != nil { if jerr != nil {
req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed") req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed")
@ -143,7 +146,7 @@ func (p *PDUStreamProvider) CompleteSync(
reqWaitGroup.Wait() reqWaitGroup.Wait()
// Add peeked rooms. // Add peeked rooms.
peeks, err := p.DB.PeeksInRange(ctx, req.Device.UserID, req.Device.ID, r) peeks, err := snapshot.PeeksInRange(ctx, req.Device.UserID, req.Device.ID, r)
if err != nil { if err != nil {
req.Log.WithError(err).Error("p.DB.PeeksInRange failed") req.Log.WithError(err).Error("p.DB.PeeksInRange failed")
return from return from
@ -152,7 +155,7 @@ func (p *PDUStreamProvider) CompleteSync(
if !peek.Deleted { if !peek.Deleted {
var jr *types.JoinResponse var jr *types.JoinResponse
jr, err = p.getJoinResponseForCompleteSync( jr, err = p.getJoinResponseForCompleteSync(
ctx, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, true, ctx, snapshot, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, true,
) )
if err != nil { if err != nil {
req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed") req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed")
@ -167,6 +170,7 @@ func (p *PDUStreamProvider) CompleteSync(
func (p *PDUStreamProvider) IncrementalSync( func (p *PDUStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot,
req *types.SyncRequest, req *types.SyncRequest,
from, to types.StreamPosition, from, to types.StreamPosition,
) (newPos types.StreamPosition) { ) (newPos types.StreamPosition) {
@ -184,12 +188,12 @@ func (p *PDUStreamProvider) IncrementalSync(
eventFilter := req.Filter.Room.Timeline eventFilter := req.Filter.Room.Timeline
if req.WantFullState { if req.WantFullState {
if stateDeltas, syncJoinedRooms, err = p.DB.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed") req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed")
return return
} }
} else { } else {
if stateDeltas, syncJoinedRooms, err = p.DB.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
req.Log.WithError(err).Error("p.DB.GetStateDeltas failed") req.Log.WithError(err).Error("p.DB.GetStateDeltas failed")
return return
} }
@ -203,7 +207,7 @@ func (p *PDUStreamProvider) IncrementalSync(
return to return to
} }
if err = p.addIgnoredUsersToFilter(ctx, req, &eventFilter); err != nil { if err = p.addIgnoredUsersToFilter(ctx, snapshot, req, &eventFilter); err != nil {
req.Log.WithError(err).Error("unable to update event filter with ignored users") req.Log.WithError(err).Error("unable to update event filter with ignored users")
} }
@ -222,7 +226,7 @@ func (p *PDUStreamProvider) IncrementalSync(
} }
} }
var pos types.StreamPosition var pos types.StreamPosition
if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, newRange, delta, &eventFilter, &stateFilter, req.Response); err != nil { if pos, err = p.addRoomDeltaToResponse(ctx, snapshot, req.Device, newRange, delta, &eventFilter, &stateFilter, req.Response); err != nil {
req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed") req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed")
return to return to
} }
@ -244,6 +248,7 @@ func (p *PDUStreamProvider) IncrementalSync(
// nolint:gocyclo // nolint:gocyclo
func (p *PDUStreamProvider) addRoomDeltaToResponse( func (p *PDUStreamProvider) addRoomDeltaToResponse(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot,
device *userapi.Device, device *userapi.Device,
r types.Range, r types.Range,
delta types.StateDelta, delta types.StateDelta,
@ -260,7 +265,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
// This is all "okay" assuming history_visibility == "shared" which it is by default. // This is all "okay" assuming history_visibility == "shared" which it is by default.
r.To = delta.MembershipPos r.To = delta.MembershipPos
} }
recentStreamEvents, limited, err := p.DB.RecentEvents( recentStreamEvents, limited, err := snapshot.RecentEvents(
ctx, delta.RoomID, r, ctx, delta.RoomID, r,
eventFilter, true, true, eventFilter, true, true,
) )
@ -270,9 +275,9 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
} }
return r.From, fmt.Errorf("p.DB.RecentEvents: %w", err) return r.From, fmt.Errorf("p.DB.RecentEvents: %w", err)
} }
recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents) recentEvents := snapshot.StreamEventsToEvents(device, recentStreamEvents)
delta.StateEvents = removeDuplicates(delta.StateEvents, recentEvents) // roll back delta.StateEvents = removeDuplicates(delta.StateEvents, recentEvents) // roll back
prevBatch, err := p.DB.GetBackwardTopologyPos(ctx, recentStreamEvents) prevBatch, err := snapshot.GetBackwardTopologyPos(ctx, recentStreamEvents)
if err != nil { if err != nil {
return r.From, fmt.Errorf("p.DB.GetBackwardTopologyPos: %w", err) return r.From, fmt.Errorf("p.DB.GetBackwardTopologyPos: %w", err)
} }
@ -291,7 +296,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
latestPosition := r.To latestPosition := r.To
updateLatestPosition := func(mostRecentEventID string) { updateLatestPosition := func(mostRecentEventID string) {
var pos types.StreamPosition var pos types.StreamPosition
if _, pos, err = p.DB.PositionInTopology(ctx, mostRecentEventID); err == nil { if _, pos, err = snapshot.PositionInTopology(ctx, mostRecentEventID); err == nil {
switch { switch {
case r.Backwards && pos < latestPosition: case r.Backwards && pos < latestPosition:
fallthrough fallthrough
@ -303,7 +308,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
if stateFilter.LazyLoadMembers { if stateFilter.LazyLoadMembers {
delta.StateEvents, err = p.lazyLoadMembers( delta.StateEvents, err = p.lazyLoadMembers(
ctx, delta.RoomID, true, limited, stateFilter, ctx, snapshot, delta.RoomID, true, limited, stateFilter,
device, recentEvents, delta.StateEvents, device, recentEvents, delta.StateEvents,
) )
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
@ -320,7 +325,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
} }
// Applies the history visibility rules // Applies the history visibility rules
events, err := applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, delta.RoomID, device.UserID, eventFilter.Limit, recentEvents) events, err := applyHistoryVisibilityFilter(ctx, snapshot, p.rsAPI, delta.RoomID, device.UserID, eventFilter.Limit, recentEvents)
if err != nil { if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter") logrus.WithError(err).Error("unable to apply history visibility filter")
} }
@ -336,7 +341,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
case gomatrixserverlib.Join: case gomatrixserverlib.Join:
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
if hasMembershipChange { if hasMembershipChange {
p.addRoomSummary(ctx, jr, delta.RoomID, device.UserID, latestPosition) p.addRoomSummary(ctx, snapshot, jr, delta.RoomID, device.UserID, latestPosition)
} }
jr.Timeline.PrevBatch = &prevBatch jr.Timeline.PrevBatch = &prevBatch
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
@ -376,7 +381,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
// sure we always return the required events in the timeline. // sure we always return the required events in the timeline.
func applyHistoryVisibilityFilter( func applyHistoryVisibilityFilter(
ctx context.Context, ctx context.Context,
db storage.Database, snapshot storage.DatabaseSnapshot,
rsAPI roomserverAPI.SyncRoomserverAPI, rsAPI roomserverAPI.SyncRoomserverAPI,
roomID, userID string, roomID, userID string,
limit int, limit int,
@ -384,7 +389,7 @@ func applyHistoryVisibilityFilter(
) ([]*gomatrixserverlib.HeaderedEvent, error) { ) ([]*gomatrixserverlib.HeaderedEvent, error) {
// We need to make sure we always include the latest states events, if they are in the timeline. // We need to make sure we always include the latest states events, if they are in the timeline.
// We grep at least limit * 2 events, to ensure we really get the needed events. // We grep at least limit * 2 events, to ensure we really get the needed events.
stateEvents, err := db.CurrentState(ctx, roomID, &gomatrixserverlib.StateFilter{Limit: limit * 2}, nil) stateEvents, err := snapshot.CurrentState(ctx, roomID, &gomatrixserverlib.StateFilter{Limit: limit * 2}, nil)
if err != nil { if err != nil {
// Not a fatal error, we can continue without the stateEvents, // Not a fatal error, we can continue without the stateEvents,
// they are only needed if there are state events in the timeline. // they are only needed if there are state events in the timeline.
@ -395,7 +400,7 @@ func applyHistoryVisibilityFilter(
alwaysIncludeIDs[ev.EventID()] = struct{}{} alwaysIncludeIDs[ev.EventID()] = struct{}{}
} }
startTime := time.Now() startTime := time.Now()
events, err := internal.ApplyHistoryVisibilityFilter(ctx, db, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync") events, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -408,10 +413,10 @@ func applyHistoryVisibilityFilter(
return events, nil return events, nil
} }
func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) { func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, snapshot storage.DatabaseSnapshot, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) {
// Work out how many members are in the room. // Work out how many members are in the room.
joinedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition) joinedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition)
invitedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, latestPosition) invitedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, latestPosition)
jr.Summary.JoinedMemberCount = &joinedCount jr.Summary.JoinedMemberCount = &joinedCount
jr.Summary.InvitedMemberCount = &invitedCount jr.Summary.InvitedMemberCount = &invitedCount
@ -439,7 +444,7 @@ func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinRe
} }
} }
} }
heroes, err := p.DB.GetRoomHeroes(ctx, roomID, userID, []string{"join", "invite"}) heroes, err := snapshot.GetRoomHeroes(ctx, roomID, userID, []string{"join", "invite"})
if err != nil { if err != nil {
return return
} }
@ -449,6 +454,7 @@ func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinRe
func (p *PDUStreamProvider) getJoinResponseForCompleteSync( func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot,
roomID string, roomID string,
r types.Range, r types.Range,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
@ -460,7 +466,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
jr = types.NewJoinResponse() jr = types.NewJoinResponse()
// TODO: When filters are added, we may need to call this multiple times to get enough events. // TODO: When filters are added, we may need to call this multiple times to get enough events.
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
recentStreamEvents, limited, err := p.DB.RecentEvents( recentStreamEvents, limited, err := snapshot.RecentEvents(
ctx, roomID, r, eventFilter, true, true, ctx, roomID, r, eventFilter, true, true,
) )
if err != nil { if err != nil {
@ -484,7 +490,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
} }
} }
stateEvents, err := p.DB.CurrentState(ctx, roomID, stateFilter, excludingEventIDs) stateEvents, err := snapshot.CurrentState(ctx, roomID, stateFilter, excludingEventIDs)
if err != nil { if err != nil {
return return
} }
@ -494,7 +500,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
var prevBatch *types.TopologyToken var prevBatch *types.TopologyToken
if len(recentStreamEvents) > 0 { if len(recentStreamEvents) > 0 {
var backwardTopologyPos, backwardStreamPos types.StreamPosition var backwardTopologyPos, backwardStreamPos types.StreamPosition
backwardTopologyPos, backwardStreamPos, err = p.DB.PositionInTopology(ctx, recentStreamEvents[0].EventID()) backwardTopologyPos, backwardStreamPos, err = snapshot.PositionInTopology(ctx, recentStreamEvents[0].EventID())
if err != nil { if err != nil {
return return
} }
@ -505,18 +511,18 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
prevBatch.Decrement() prevBatch.Decrement()
} }
p.addRoomSummary(ctx, jr, roomID, device.UserID, r.From) p.addRoomSummary(ctx, snapshot, jr, roomID, device.UserID, r.From)
// We don't include a device here as we don't need to send down // We don't include a device here as we don't need to send down
// transaction IDs for complete syncs, but we do it anyway because Sytest demands it for: // transaction IDs for complete syncs, but we do it anyway because Sytest demands it for:
// "Can sync a room with a message with a transaction id" - which does a complete sync to check. // "Can sync a room with a message with a transaction id" - which does a complete sync to check.
recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents) recentEvents := snapshot.StreamEventsToEvents(device, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents) stateEvents = removeDuplicates(stateEvents, recentEvents)
events := recentEvents events := recentEvents
// Only apply history visibility checks if the response is for joined rooms // Only apply history visibility checks if the response is for joined rooms
if !isPeek { if !isPeek {
events, err = applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, roomID, device.UserID, eventFilter.Limit, recentEvents) events, err = applyHistoryVisibilityFilter(ctx, snapshot, p.rsAPI, roomID, device.UserID, eventFilter.Limit, recentEvents)
if err != nil { if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter") logrus.WithError(err).Error("unable to apply history visibility filter")
} }
@ -530,7 +536,8 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
if err != nil { if err != nil {
return nil, err return nil, err
} }
stateEvents, err = p.lazyLoadMembers(ctx, roomID, stateEvents, err = p.lazyLoadMembers(
ctx, snapshot, roomID,
false, limited, stateFilter, false, limited, stateFilter,
device, recentEvents, stateEvents, device, recentEvents, stateEvents,
) )
@ -549,7 +556,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
} }
func (p *PDUStreamProvider) lazyLoadMembers( func (p *PDUStreamProvider) lazyLoadMembers(
ctx context.Context, roomID string, ctx context.Context, snapshot storage.DatabaseSnapshot, roomID string,
incremental, limited bool, stateFilter *gomatrixserverlib.StateFilter, incremental, limited bool, stateFilter *gomatrixserverlib.StateFilter,
device *userapi.Device, device *userapi.Device,
timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent, timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent,
@ -598,7 +605,7 @@ func (p *PDUStreamProvider) lazyLoadMembers(
filter.Limit = stateFilter.Limit filter.Limit = stateFilter.Limit
filter.Senders = &wantUsers filter.Senders = &wantUsers
filter.Types = &[]string{gomatrixserverlib.MRoomMember} filter.Types = &[]string{gomatrixserverlib.MRoomMember}
memberships, err := p.DB.GetStateEventsForRoom(ctx, roomID, &filter) memberships, err := snapshot.GetStateEventsForRoom(ctx, roomID, &filter)
if err != nil { if err != nil {
return stateEvents, err return stateEvents, err
} }
@ -612,8 +619,8 @@ func (p *PDUStreamProvider) lazyLoadMembers(
// addIgnoredUsersToFilter adds ignored users to the eventfilter and // addIgnoredUsersToFilter adds ignored users to the eventfilter and
// the syncreq itself for further use in streams. // the syncreq itself for further use in streams.
func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error { func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, snapshot storage.DatabaseSnapshot, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error {
ignores, err := p.DB.IgnoresForUser(ctx, req.Device.UserID) ignores, err := snapshot.IgnoresForUser(ctx, req.Device.UserID)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil return nil

View file

@ -23,20 +23,23 @@ import (
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
) )
type PresenceStreamProvider struct { type PresenceStreamProvider struct {
StreamProvider DefaultStreamProvider
// cache contains previously sent presence updates to avoid unneeded updates // cache contains previously sent presence updates to avoid unneeded updates
cache sync.Map cache sync.Map
notifier *notifier.Notifier notifier *notifier.Notifier
} }
func (p *PresenceStreamProvider) Setup() { func (p *PresenceStreamProvider) Setup(
p.StreamProvider.Setup() ctx context.Context, snapshot storage.DatabaseSnapshot,
) {
p.DefaultStreamProvider.Setup(ctx, snapshot)
id, err := p.DB.MaxStreamPositionForPresence(context.Background()) id, err := snapshot.MaxStreamPositionForPresence(context.Background())
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -45,13 +48,15 @@ func (p *PresenceStreamProvider) Setup() {
func (p *PresenceStreamProvider) CompleteSync( func (p *PresenceStreamProvider) CompleteSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot,
req *types.SyncRequest, req *types.SyncRequest,
) types.StreamPosition { ) types.StreamPosition {
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
} }
func (p *PresenceStreamProvider) IncrementalSync( func (p *PresenceStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot,
req *types.SyncRequest, req *types.SyncRequest,
from, to types.StreamPosition, from, to types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {

View file

@ -4,18 +4,21 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
type ReceiptStreamProvider struct { type ReceiptStreamProvider struct {
StreamProvider DefaultStreamProvider
} }
func (p *ReceiptStreamProvider) Setup() { func (p *ReceiptStreamProvider) Setup(
p.StreamProvider.Setup() ctx context.Context, snapshot storage.DatabaseSnapshot,
) {
p.DefaultStreamProvider.Setup(ctx, snapshot)
id, err := p.DB.MaxStreamPositionForReceipts(context.Background()) id, err := snapshot.MaxStreamPositionForReceipts(context.Background())
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -24,13 +27,15 @@ func (p *ReceiptStreamProvider) Setup() {
func (p *ReceiptStreamProvider) CompleteSync( func (p *ReceiptStreamProvider) CompleteSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot,
req *types.SyncRequest, req *types.SyncRequest,
) types.StreamPosition { ) types.StreamPosition {
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
} }
func (p *ReceiptStreamProvider) IncrementalSync( func (p *ReceiptStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot,
req *types.SyncRequest, req *types.SyncRequest,
from, to types.StreamPosition, from, to types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {
@ -41,7 +46,7 @@ func (p *ReceiptStreamProvider) IncrementalSync(
} }
} }
lastPos, receipts, err := p.DB.RoomReceiptsAfter(ctx, joinedRooms, from) lastPos, receipts, err := snapshot.RoomReceiptsAfter(ctx, joinedRooms, from)
if err != nil { if err != nil {
req.Log.WithError(err).Error("p.DB.RoomReceiptsAfter failed") req.Log.WithError(err).Error("p.DB.RoomReceiptsAfter failed")
return from return from

View file

@ -3,17 +3,20 @@ package streams
import ( import (
"context" "context"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
) )
type SendToDeviceStreamProvider struct { type SendToDeviceStreamProvider struct {
StreamProvider DefaultStreamProvider
} }
func (p *SendToDeviceStreamProvider) Setup() { func (p *SendToDeviceStreamProvider) Setup(
p.StreamProvider.Setup() ctx context.Context, snapshot storage.DatabaseSnapshot,
) {
p.DefaultStreamProvider.Setup(ctx, snapshot)
id, err := p.DB.MaxStreamPositionForSendToDeviceMessages(context.Background()) id, err := snapshot.MaxStreamPositionForSendToDeviceMessages(context.Background())
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -22,18 +25,20 @@ func (p *SendToDeviceStreamProvider) Setup() {
func (p *SendToDeviceStreamProvider) CompleteSync( func (p *SendToDeviceStreamProvider) CompleteSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot,
req *types.SyncRequest, req *types.SyncRequest,
) types.StreamPosition { ) types.StreamPosition {
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
} }
func (p *SendToDeviceStreamProvider) IncrementalSync( func (p *SendToDeviceStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot,
req *types.SyncRequest, req *types.SyncRequest,
from, to types.StreamPosition, from, to types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {
// See if we have any new tasks to do for the send-to-device messaging. // See if we have any new tasks to do for the send-to-device messaging.
lastPos, events, err := p.DB.SendToDeviceUpdatesForSync(req.Context, req.Device.UserID, req.Device.ID, from, to) lastPos, events, err := snapshot.SendToDeviceUpdatesForSync(req.Context, req.Device.UserID, req.Device.ID, from, to)
if err != nil { if err != nil {
req.Log.WithError(err).Error("p.DB.SendToDeviceUpdatesForSync failed") req.Log.WithError(err).Error("p.DB.SendToDeviceUpdatesForSync failed")
return from return from

View file

@ -5,24 +5,27 @@ import (
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
type TypingStreamProvider struct { type TypingStreamProvider struct {
StreamProvider DefaultStreamProvider
EDUCache *caching.EDUCache EDUCache *caching.EDUCache
} }
func (p *TypingStreamProvider) CompleteSync( func (p *TypingStreamProvider) CompleteSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot,
req *types.SyncRequest, req *types.SyncRequest,
) types.StreamPosition { ) types.StreamPosition {
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
} }
func (p *TypingStreamProvider) IncrementalSync( func (p *TypingStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot,
req *types.SyncRequest, req *types.SyncRequest,
from, to types.StreamPosition, from, to types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {

View file

@ -0,0 +1,28 @@
package streams
import (
"context"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
)
type StreamProvider interface {
Setup(ctx context.Context, snapshot storage.DatabaseSnapshot)
// Advance will update the latest position of the stream based on
// an update and will wake callers waiting on StreamNotifyAfter.
Advance(latest types.StreamPosition)
// CompleteSync will update the response to include all updates as needed
// for a complete sync. It will always return immediately.
CompleteSync(ctx context.Context, snapshot storage.DatabaseSnapshot, req *types.SyncRequest) types.StreamPosition
// IncrementalSync will update the response to include all updates between
// the from and to sync positions. It will always return immediately,
// making no changes if the range contains no updates.
IncrementalSync(ctx context.Context, snapshot storage.DatabaseSnapshot, req *types.SyncRequest, from, to types.StreamPosition) types.StreamPosition
// LatestPosition returns the latest stream position for this stream.
LatestPosition(ctx context.Context) types.StreamPosition
}

View file

@ -13,15 +13,15 @@ import (
) )
type Streams struct { type Streams struct {
PDUStreamProvider types.StreamProvider PDUStreamProvider StreamProvider
TypingStreamProvider types.StreamProvider TypingStreamProvider StreamProvider
ReceiptStreamProvider types.StreamProvider ReceiptStreamProvider StreamProvider
InviteStreamProvider types.StreamProvider InviteStreamProvider StreamProvider
SendToDeviceStreamProvider types.StreamProvider SendToDeviceStreamProvider StreamProvider
AccountDataStreamProvider types.StreamProvider AccountDataStreamProvider StreamProvider
DeviceListStreamProvider types.StreamProvider DeviceListStreamProvider StreamProvider
NotificationDataStreamProvider types.StreamProvider NotificationDataStreamProvider StreamProvider
PresenceStreamProvider types.StreamProvider PresenceStreamProvider StreamProvider
} }
func NewSyncStreamProviders( func NewSyncStreamProviders(
@ -31,51 +31,58 @@ func NewSyncStreamProviders(
) *Streams { ) *Streams {
streams := &Streams{ streams := &Streams{
PDUStreamProvider: &PDUStreamProvider{ PDUStreamProvider: &PDUStreamProvider{
StreamProvider: StreamProvider{DB: d}, DefaultStreamProvider: DefaultStreamProvider{DB: d},
lazyLoadCache: lazyLoadCache, lazyLoadCache: lazyLoadCache,
rsAPI: rsAPI, rsAPI: rsAPI,
notifier: notifier, notifier: notifier,
}, },
TypingStreamProvider: &TypingStreamProvider{ TypingStreamProvider: &TypingStreamProvider{
StreamProvider: StreamProvider{DB: d}, DefaultStreamProvider: DefaultStreamProvider{DB: d},
EDUCache: eduCache, EDUCache: eduCache,
}, },
ReceiptStreamProvider: &ReceiptStreamProvider{ ReceiptStreamProvider: &ReceiptStreamProvider{
StreamProvider: StreamProvider{DB: d}, DefaultStreamProvider: DefaultStreamProvider{DB: d},
}, },
InviteStreamProvider: &InviteStreamProvider{ InviteStreamProvider: &InviteStreamProvider{
StreamProvider: StreamProvider{DB: d}, DefaultStreamProvider: DefaultStreamProvider{DB: d},
}, },
SendToDeviceStreamProvider: &SendToDeviceStreamProvider{ SendToDeviceStreamProvider: &SendToDeviceStreamProvider{
StreamProvider: StreamProvider{DB: d}, DefaultStreamProvider: DefaultStreamProvider{DB: d},
}, },
AccountDataStreamProvider: &AccountDataStreamProvider{ AccountDataStreamProvider: &AccountDataStreamProvider{
StreamProvider: StreamProvider{DB: d}, DefaultStreamProvider: DefaultStreamProvider{DB: d},
userAPI: userAPI, userAPI: userAPI,
}, },
NotificationDataStreamProvider: &NotificationDataStreamProvider{ NotificationDataStreamProvider: &NotificationDataStreamProvider{
StreamProvider: StreamProvider{DB: d}, DefaultStreamProvider: DefaultStreamProvider{DB: d},
}, },
DeviceListStreamProvider: &DeviceListStreamProvider{ DeviceListStreamProvider: &DeviceListStreamProvider{
StreamProvider: StreamProvider{DB: d}, DefaultStreamProvider: DefaultStreamProvider{DB: d},
rsAPI: rsAPI, rsAPI: rsAPI,
keyAPI: keyAPI, keyAPI: keyAPI,
}, },
PresenceStreamProvider: &PresenceStreamProvider{ PresenceStreamProvider: &PresenceStreamProvider{
StreamProvider: StreamProvider{DB: d}, DefaultStreamProvider: DefaultStreamProvider{DB: d},
notifier: notifier, notifier: notifier,
}, },
} }
streams.PDUStreamProvider.Setup() ctx := context.TODO()
streams.TypingStreamProvider.Setup() snapshot, err := d.NewDatabaseSnapshot(ctx)
streams.ReceiptStreamProvider.Setup() if err != nil {
streams.InviteStreamProvider.Setup() panic(err)
streams.SendToDeviceStreamProvider.Setup() }
streams.AccountDataStreamProvider.Setup() defer snapshot.Rollback() // nolint:errcheck
streams.NotificationDataStreamProvider.Setup()
streams.DeviceListStreamProvider.Setup() streams.PDUStreamProvider.Setup(ctx, snapshot)
streams.PresenceStreamProvider.Setup() streams.TypingStreamProvider.Setup(ctx, snapshot)
streams.ReceiptStreamProvider.Setup(ctx, snapshot)
streams.InviteStreamProvider.Setup(ctx, snapshot)
streams.SendToDeviceStreamProvider.Setup(ctx, snapshot)
streams.AccountDataStreamProvider.Setup(ctx, snapshot)
streams.NotificationDataStreamProvider.Setup(ctx, snapshot)
streams.DeviceListStreamProvider.Setup(ctx, snapshot)
streams.PresenceStreamProvider.Setup(ctx, snapshot)
return streams return streams
} }

View file

@ -8,16 +8,18 @@ import (
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
) )
type StreamProvider struct { type DefaultStreamProvider struct {
DB storage.Database DB storage.Database
latest types.StreamPosition latest types.StreamPosition
latestMutex sync.RWMutex latestMutex sync.RWMutex
} }
func (p *StreamProvider) Setup() { func (p *DefaultStreamProvider) Setup(
ctx context.Context, snapshot storage.DatabaseSnapshot,
) {
} }
func (p *StreamProvider) Advance( func (p *DefaultStreamProvider) Advance(
latest types.StreamPosition, latest types.StreamPosition,
) { ) {
p.latestMutex.Lock() p.latestMutex.Lock()
@ -28,7 +30,7 @@ func (p *StreamProvider) Advance(
} }
} }
func (p *StreamProvider) LatestPosition( func (p *DefaultStreamProvider) LatestPosition(
ctx context.Context, ctx context.Context,
) types.StreamPosition { ) types.StreamPosition {
p.latestMutex.RLock() p.latestMutex.RLock()

View file

@ -48,6 +48,13 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
} }
} }
snapshot, err := syncDB.NewDatabaseSnapshot(req.Context())
if err != nil {
logrus.WithError(err).Error("Failed to acquire database snapshot for sync request")
return nil, err
}
defer snapshot.Rollback() // nolint:errcheck
// Create a default filter and apply a stored filter on top of it (if specified) // Create a default filter and apply a stored filter on top of it (if specified)
filter := gomatrixserverlib.DefaultFilter() filter := gomatrixserverlib.DefaultFilter()
filterQuery := req.URL.Query().Get("filter") filterQuery := req.URL.Query().Get("filter")
@ -64,7 +71,7 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err) return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
} }
if err := syncDB.GetFilter(req.Context(), &filter, localpart, filterQuery); err != nil && err != sql.ErrNoRows { if err := snapshot.GetFilter(req.Context(), &filter, localpart, filterQuery); err != nil && err != sql.ErrNoRows {
util.GetLogger(req.Context()).WithError(err).Error("syncDB.GetFilter failed") util.GetLogger(req.Context()).WithError(err).Error("syncDB.GetFilter failed")
return nil, fmt.Errorf("syncDB.GetFilter: %w", err) return nil, fmt.Errorf("syncDB.GetFilter: %w", err)
} }

View file

@ -305,6 +305,13 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync immediately") syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync immediately")
} }
snapshot, err := rp.db.NewDatabaseSnapshot(req.Context())
if err != nil {
logrus.WithError(err).Error("Failed to acquire database snapshot for sync request")
return jsonerror.InternalServerError()
}
defer snapshot.Rollback() // nolint:errcheck
if syncReq.Since.IsEmpty() { if syncReq.Since.IsEmpty() {
// Complete sync // Complete sync
syncReq.Response.NextBatch = types.StreamingToken{ syncReq.Response.NextBatch = types.StreamingToken{
@ -312,70 +319,70 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
// might advance while processing other streams, resulting in flakey // might advance while processing other streams, resulting in flakey
// tests. // tests.
DeviceListPosition: rp.streams.DeviceListStreamProvider.CompleteSync( DeviceListPosition: rp.streams.DeviceListStreamProvider.CompleteSync(
syncReq.Context, syncReq, syncReq.Context, snapshot, syncReq,
), ),
PDUPosition: rp.streams.PDUStreamProvider.CompleteSync( PDUPosition: rp.streams.PDUStreamProvider.CompleteSync(
syncReq.Context, syncReq, syncReq.Context, snapshot, syncReq,
), ),
TypingPosition: rp.streams.TypingStreamProvider.CompleteSync( TypingPosition: rp.streams.TypingStreamProvider.CompleteSync(
syncReq.Context, syncReq, syncReq.Context, snapshot, syncReq,
), ),
ReceiptPosition: rp.streams.ReceiptStreamProvider.CompleteSync( ReceiptPosition: rp.streams.ReceiptStreamProvider.CompleteSync(
syncReq.Context, syncReq, syncReq.Context, snapshot, syncReq,
), ),
InvitePosition: rp.streams.InviteStreamProvider.CompleteSync( InvitePosition: rp.streams.InviteStreamProvider.CompleteSync(
syncReq.Context, syncReq, syncReq.Context, snapshot, syncReq,
), ),
SendToDevicePosition: rp.streams.SendToDeviceStreamProvider.CompleteSync( SendToDevicePosition: rp.streams.SendToDeviceStreamProvider.CompleteSync(
syncReq.Context, syncReq, syncReq.Context, snapshot, syncReq,
), ),
AccountDataPosition: rp.streams.AccountDataStreamProvider.CompleteSync( AccountDataPosition: rp.streams.AccountDataStreamProvider.CompleteSync(
syncReq.Context, syncReq, syncReq.Context, snapshot, syncReq,
), ),
NotificationDataPosition: rp.streams.NotificationDataStreamProvider.CompleteSync( NotificationDataPosition: rp.streams.NotificationDataStreamProvider.CompleteSync(
syncReq.Context, syncReq, syncReq.Context, snapshot, syncReq,
), ),
PresencePosition: rp.streams.PresenceStreamProvider.CompleteSync( PresencePosition: rp.streams.PresenceStreamProvider.CompleteSync(
syncReq.Context, syncReq, syncReq.Context, snapshot, syncReq,
), ),
} }
} else { } else {
// Incremental sync // Incremental sync
syncReq.Response.NextBatch = types.StreamingToken{ syncReq.Response.NextBatch = types.StreamingToken{
PDUPosition: rp.streams.PDUStreamProvider.IncrementalSync( PDUPosition: rp.streams.PDUStreamProvider.IncrementalSync(
syncReq.Context, syncReq, syncReq.Context, snapshot, syncReq,
syncReq.Since.PDUPosition, currentPos.PDUPosition, syncReq.Since.PDUPosition, currentPos.PDUPosition,
), ),
TypingPosition: rp.streams.TypingStreamProvider.IncrementalSync( TypingPosition: rp.streams.TypingStreamProvider.IncrementalSync(
syncReq.Context, syncReq, syncReq.Context, snapshot, syncReq,
syncReq.Since.TypingPosition, currentPos.TypingPosition, syncReq.Since.TypingPosition, currentPos.TypingPosition,
), ),
ReceiptPosition: rp.streams.ReceiptStreamProvider.IncrementalSync( ReceiptPosition: rp.streams.ReceiptStreamProvider.IncrementalSync(
syncReq.Context, syncReq, syncReq.Context, snapshot, syncReq,
syncReq.Since.ReceiptPosition, currentPos.ReceiptPosition, syncReq.Since.ReceiptPosition, currentPos.ReceiptPosition,
), ),
InvitePosition: rp.streams.InviteStreamProvider.IncrementalSync( InvitePosition: rp.streams.InviteStreamProvider.IncrementalSync(
syncReq.Context, syncReq, syncReq.Context, snapshot, syncReq,
syncReq.Since.InvitePosition, currentPos.InvitePosition, syncReq.Since.InvitePosition, currentPos.InvitePosition,
), ),
SendToDevicePosition: rp.streams.SendToDeviceStreamProvider.IncrementalSync( SendToDevicePosition: rp.streams.SendToDeviceStreamProvider.IncrementalSync(
syncReq.Context, syncReq, syncReq.Context, snapshot, syncReq,
syncReq.Since.SendToDevicePosition, currentPos.SendToDevicePosition, syncReq.Since.SendToDevicePosition, currentPos.SendToDevicePosition,
), ),
AccountDataPosition: rp.streams.AccountDataStreamProvider.IncrementalSync( AccountDataPosition: rp.streams.AccountDataStreamProvider.IncrementalSync(
syncReq.Context, syncReq, syncReq.Context, snapshot, syncReq,
syncReq.Since.AccountDataPosition, currentPos.AccountDataPosition, syncReq.Since.AccountDataPosition, currentPos.AccountDataPosition,
), ),
NotificationDataPosition: rp.streams.NotificationDataStreamProvider.IncrementalSync( NotificationDataPosition: rp.streams.NotificationDataStreamProvider.IncrementalSync(
syncReq.Context, syncReq, syncReq.Context, snapshot, syncReq,
syncReq.Since.NotificationDataPosition, currentPos.NotificationDataPosition, syncReq.Since.NotificationDataPosition, currentPos.NotificationDataPosition,
), ),
DeviceListPosition: rp.streams.DeviceListStreamProvider.IncrementalSync( DeviceListPosition: rp.streams.DeviceListStreamProvider.IncrementalSync(
syncReq.Context, syncReq, syncReq.Context, snapshot, syncReq,
syncReq.Since.DeviceListPosition, currentPos.DeviceListPosition, syncReq.Since.DeviceListPosition, currentPos.DeviceListPosition,
), ),
PresencePosition: rp.streams.PresenceStreamProvider.IncrementalSync( PresencePosition: rp.streams.PresenceStreamProvider.IncrementalSync(
syncReq.Context, syncReq, syncReq.Context, snapshot, syncReq,
syncReq.Since.PresencePosition, currentPos.PresencePosition, syncReq.Since.PresencePosition, currentPos.PresencePosition,
), ),
} }
@ -437,9 +444,15 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use
util.GetLogger(req.Context()).WithError(err).Error("newSyncRequest failed") util.GetLogger(req.Context()).WithError(err).Error("newSyncRequest failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), syncReq, fromToken.PDUPosition, toToken.PDUPosition) snapshot, err := rp.db.NewDatabaseSnapshot(req.Context())
if err != nil {
logrus.WithError(err).Error("Failed to acquire database snapshot for key change")
return jsonerror.InternalServerError()
}
defer snapshot.Rollback() // nolint:errcheck
rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), snapshot, syncReq, fromToken.PDUPosition, toToken.PDUPosition)
_, _, err = internal.DeviceListCatchup( _, _, err = internal.DeviceListCatchup(
req.Context(), rp.db, rp.keyAPI, rp.rsAPI, syncReq.Device.UserID, req.Context(), snapshot, rp.keyAPI, rp.rsAPI, syncReq.Device.UserID,
syncReq.Response, fromToken.DeviceListPosition, toToken.DeviceListPosition, syncReq.Response, fromToken.DeviceListPosition, toToken.DeviceListPosition,
) )
if err != nil { if err != nil {

View file

@ -41,23 +41,3 @@ func (r *SyncRequest) IsRoomPresent(roomID string) bool {
return false return false
} }
} }
type StreamProvider interface {
Setup()
// Advance will update the latest position of the stream based on
// an update and will wake callers waiting on StreamNotifyAfter.
Advance(latest StreamPosition)
// CompleteSync will update the response to include all updates as needed
// for a complete sync. It will always return immediately.
CompleteSync(ctx context.Context, req *SyncRequest) StreamPosition
// IncrementalSync will update the response to include all updates between
// the from and to sync positions. It will always return immediately,
// making no changes if the range contains no updates.
IncrementalSync(ctx context.Context, req *SyncRequest, from, to StreamPosition) StreamPosition
// LatestPosition returns the latest stream position for this stream.
LatestPosition(ctx context.Context) StreamPosition
}