Transactional isolation for /sync
(#2745)
This should transactional snapshot isolation for `/sync` etc requests. For now we don't use repeatable read due to some odd test failures with invites.
This commit is contained in:
parent
8a82f10046
commit
6348486a13
|
@ -34,6 +34,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/process"
|
"github.com/matrix-org/dendrite/setup/process"
|
||||||
"github.com/matrix-org/dendrite/syncapi/notifier"
|
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/streams"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -46,7 +47,7 @@ type OutputClientDataConsumer struct {
|
||||||
topic string
|
topic string
|
||||||
topicReIndex string
|
topicReIndex string
|
||||||
db storage.Database
|
db storage.Database
|
||||||
stream types.StreamProvider
|
stream streams.StreamProvider
|
||||||
notifier *notifier.Notifier
|
notifier *notifier.Notifier
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
fts *fulltext.Search
|
fts *fulltext.Search
|
||||||
|
@ -61,7 +62,7 @@ func NewOutputClientDataConsumer(
|
||||||
nats *nats.Conn,
|
nats *nats.Conn,
|
||||||
store storage.Database,
|
store storage.Database,
|
||||||
notifier *notifier.Notifier,
|
notifier *notifier.Notifier,
|
||||||
stream types.StreamProvider,
|
stream streams.StreamProvider,
|
||||||
fts *fulltext.Search,
|
fts *fulltext.Search,
|
||||||
) *OutputClientDataConsumer {
|
) *OutputClientDataConsumer {
|
||||||
return &OutputClientDataConsumer{
|
return &OutputClientDataConsumer{
|
||||||
|
|
|
@ -26,6 +26,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/process"
|
"github.com/matrix-org/dendrite/setup/process"
|
||||||
"github.com/matrix-org/dendrite/syncapi/notifier"
|
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/streams"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/nats-io/nats.go"
|
"github.com/nats-io/nats.go"
|
||||||
|
@ -40,7 +41,7 @@ type OutputKeyChangeEventConsumer struct {
|
||||||
topic string
|
topic string
|
||||||
db storage.Database
|
db storage.Database
|
||||||
notifier *notifier.Notifier
|
notifier *notifier.Notifier
|
||||||
stream types.StreamProvider
|
stream streams.StreamProvider
|
||||||
serverName gomatrixserverlib.ServerName // our server name
|
serverName gomatrixserverlib.ServerName // our server name
|
||||||
rsAPI roomserverAPI.SyncRoomserverAPI
|
rsAPI roomserverAPI.SyncRoomserverAPI
|
||||||
}
|
}
|
||||||
|
@ -55,7 +56,7 @@ func NewOutputKeyChangeEventConsumer(
|
||||||
rsAPI roomserverAPI.SyncRoomserverAPI,
|
rsAPI roomserverAPI.SyncRoomserverAPI,
|
||||||
store storage.Database,
|
store storage.Database,
|
||||||
notifier *notifier.Notifier,
|
notifier *notifier.Notifier,
|
||||||
stream types.StreamProvider,
|
stream streams.StreamProvider,
|
||||||
) *OutputKeyChangeEventConsumer {
|
) *OutputKeyChangeEventConsumer {
|
||||||
s := &OutputKeyChangeEventConsumer{
|
s := &OutputKeyChangeEventConsumer{
|
||||||
ctx: process.Context(),
|
ctx: process.Context(),
|
||||||
|
|
|
@ -23,6 +23,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/process"
|
"github.com/matrix-org/dendrite/setup/process"
|
||||||
"github.com/matrix-org/dendrite/syncapi/notifier"
|
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/streams"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
@ -39,7 +40,7 @@ type PresenceConsumer struct {
|
||||||
requestTopic string
|
requestTopic string
|
||||||
presenceTopic string
|
presenceTopic string
|
||||||
db storage.Database
|
db storage.Database
|
||||||
stream types.StreamProvider
|
stream streams.StreamProvider
|
||||||
notifier *notifier.Notifier
|
notifier *notifier.Notifier
|
||||||
deviceAPI api.SyncUserAPI
|
deviceAPI api.SyncUserAPI
|
||||||
cfg *config.SyncAPI
|
cfg *config.SyncAPI
|
||||||
|
@ -54,7 +55,7 @@ func NewPresenceConsumer(
|
||||||
nats *nats.Conn,
|
nats *nats.Conn,
|
||||||
db storage.Database,
|
db storage.Database,
|
||||||
notifier *notifier.Notifier,
|
notifier *notifier.Notifier,
|
||||||
stream types.StreamProvider,
|
stream streams.StreamProvider,
|
||||||
deviceAPI api.SyncUserAPI,
|
deviceAPI api.SyncUserAPI,
|
||||||
) *PresenceConsumer {
|
) *PresenceConsumer {
|
||||||
return &PresenceConsumer{
|
return &PresenceConsumer{
|
||||||
|
|
|
@ -28,6 +28,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/process"
|
"github.com/matrix-org/dendrite/setup/process"
|
||||||
"github.com/matrix-org/dendrite/syncapi/notifier"
|
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/streams"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -38,7 +39,7 @@ type OutputReceiptEventConsumer struct {
|
||||||
durable string
|
durable string
|
||||||
topic string
|
topic string
|
||||||
db storage.Database
|
db storage.Database
|
||||||
stream types.StreamProvider
|
stream streams.StreamProvider
|
||||||
notifier *notifier.Notifier
|
notifier *notifier.Notifier
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
@ -51,7 +52,7 @@ func NewOutputReceiptEventConsumer(
|
||||||
js nats.JetStreamContext,
|
js nats.JetStreamContext,
|
||||||
store storage.Database,
|
store storage.Database,
|
||||||
notifier *notifier.Notifier,
|
notifier *notifier.Notifier,
|
||||||
stream types.StreamProvider,
|
stream streams.StreamProvider,
|
||||||
) *OutputReceiptEventConsumer {
|
) *OutputReceiptEventConsumer {
|
||||||
return &OutputReceiptEventConsumer{
|
return &OutputReceiptEventConsumer{
|
||||||
ctx: process.Context(),
|
ctx: process.Context(),
|
||||||
|
|
|
@ -33,6 +33,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/process"
|
"github.com/matrix-org/dendrite/setup/process"
|
||||||
"github.com/matrix-org/dendrite/syncapi/notifier"
|
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/streams"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -45,8 +46,8 @@ type OutputRoomEventConsumer struct {
|
||||||
durable string
|
durable string
|
||||||
topic string
|
topic string
|
||||||
db storage.Database
|
db storage.Database
|
||||||
pduStream types.StreamProvider
|
pduStream streams.StreamProvider
|
||||||
inviteStream types.StreamProvider
|
inviteStream streams.StreamProvider
|
||||||
notifier *notifier.Notifier
|
notifier *notifier.Notifier
|
||||||
fts *fulltext.Search
|
fts *fulltext.Search
|
||||||
}
|
}
|
||||||
|
@ -58,8 +59,8 @@ func NewOutputRoomEventConsumer(
|
||||||
js nats.JetStreamContext,
|
js nats.JetStreamContext,
|
||||||
store storage.Database,
|
store storage.Database,
|
||||||
notifier *notifier.Notifier,
|
notifier *notifier.Notifier,
|
||||||
pduStream types.StreamProvider,
|
pduStream streams.StreamProvider,
|
||||||
inviteStream types.StreamProvider,
|
inviteStream streams.StreamProvider,
|
||||||
rsAPI api.SyncRoomserverAPI,
|
rsAPI api.SyncRoomserverAPI,
|
||||||
fts *fulltext.Search,
|
fts *fulltext.Search,
|
||||||
) *OutputRoomEventConsumer {
|
) *OutputRoomEventConsumer {
|
||||||
|
@ -449,8 +450,14 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *gomatrixserverlib.Head
|
||||||
}
|
}
|
||||||
stateKey := *event.StateKey()
|
stateKey := *event.StateKey()
|
||||||
|
|
||||||
prevEvent, err := s.db.GetStateEvent(
|
snapshot, err := s.db.NewDatabaseSnapshot(s.ctx)
|
||||||
context.TODO(), event.RoomID(), event.Type(), stateKey,
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer snapshot.Rollback() // nolint:errcheck
|
||||||
|
|
||||||
|
prevEvent, err := snapshot.GetStateEvent(
|
||||||
|
s.ctx, event.RoomID(), event.Type(), stateKey,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return event, err
|
return event, err
|
||||||
|
|
|
@ -31,6 +31,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/process"
|
"github.com/matrix-org/dendrite/setup/process"
|
||||||
"github.com/matrix-org/dendrite/syncapi/notifier"
|
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/streams"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -43,7 +44,7 @@ type OutputSendToDeviceEventConsumer struct {
|
||||||
db storage.Database
|
db storage.Database
|
||||||
keyAPI keyapi.SyncKeyAPI
|
keyAPI keyapi.SyncKeyAPI
|
||||||
serverName gomatrixserverlib.ServerName // our server name
|
serverName gomatrixserverlib.ServerName // our server name
|
||||||
stream types.StreamProvider
|
stream streams.StreamProvider
|
||||||
notifier *notifier.Notifier
|
notifier *notifier.Notifier
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -56,7 +57,7 @@ func NewOutputSendToDeviceEventConsumer(
|
||||||
store storage.Database,
|
store storage.Database,
|
||||||
keyAPI keyapi.SyncKeyAPI,
|
keyAPI keyapi.SyncKeyAPI,
|
||||||
notifier *notifier.Notifier,
|
notifier *notifier.Notifier,
|
||||||
stream types.StreamProvider,
|
stream streams.StreamProvider,
|
||||||
) *OutputSendToDeviceEventConsumer {
|
) *OutputSendToDeviceEventConsumer {
|
||||||
return &OutputSendToDeviceEventConsumer{
|
return &OutputSendToDeviceEventConsumer{
|
||||||
ctx: process.Context(),
|
ctx: process.Context(),
|
||||||
|
|
|
@ -24,6 +24,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
"github.com/matrix-org/dendrite/setup/process"
|
"github.com/matrix-org/dendrite/setup/process"
|
||||||
"github.com/matrix-org/dendrite/syncapi/notifier"
|
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/streams"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/nats-io/nats.go"
|
"github.com/nats-io/nats.go"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
@ -36,7 +37,7 @@ type OutputTypingEventConsumer struct {
|
||||||
durable string
|
durable string
|
||||||
topic string
|
topic string
|
||||||
eduCache *caching.EDUCache
|
eduCache *caching.EDUCache
|
||||||
stream types.StreamProvider
|
stream streams.StreamProvider
|
||||||
notifier *notifier.Notifier
|
notifier *notifier.Notifier
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,7 +49,7 @@ func NewOutputTypingEventConsumer(
|
||||||
js nats.JetStreamContext,
|
js nats.JetStreamContext,
|
||||||
eduCache *caching.EDUCache,
|
eduCache *caching.EDUCache,
|
||||||
notifier *notifier.Notifier,
|
notifier *notifier.Notifier,
|
||||||
stream types.StreamProvider,
|
stream streams.StreamProvider,
|
||||||
) *OutputTypingEventConsumer {
|
) *OutputTypingEventConsumer {
|
||||||
return &OutputTypingEventConsumer{
|
return &OutputTypingEventConsumer{
|
||||||
ctx: process.Context(),
|
ctx: process.Context(),
|
||||||
|
|
|
@ -28,6 +28,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/process"
|
"github.com/matrix-org/dendrite/setup/process"
|
||||||
"github.com/matrix-org/dendrite/syncapi/notifier"
|
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/streams"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -40,7 +41,7 @@ type OutputNotificationDataConsumer struct {
|
||||||
topic string
|
topic string
|
||||||
db storage.Database
|
db storage.Database
|
||||||
notifier *notifier.Notifier
|
notifier *notifier.Notifier
|
||||||
stream types.StreamProvider
|
stream streams.StreamProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOutputNotificationDataConsumer creates a new consumer. Call
|
// NewOutputNotificationDataConsumer creates a new consumer. Call
|
||||||
|
@ -51,7 +52,7 @@ func NewOutputNotificationDataConsumer(
|
||||||
js nats.JetStreamContext,
|
js nats.JetStreamContext,
|
||||||
store storage.Database,
|
store storage.Database,
|
||||||
notifier *notifier.Notifier,
|
notifier *notifier.Notifier,
|
||||||
stream types.StreamProvider,
|
stream streams.StreamProvider,
|
||||||
) *OutputNotificationDataConsumer {
|
) *OutputNotificationDataConsumer {
|
||||||
s := &OutputNotificationDataConsumer{
|
s := &OutputNotificationDataConsumer{
|
||||||
ctx: process.Context(),
|
ctx: process.Context(),
|
||||||
|
|
|
@ -100,7 +100,7 @@ func (ev eventVisibility) allowed() (allowed bool) {
|
||||||
// Returns the filtered events and an error, if any.
|
// Returns the filtered events and an error, if any.
|
||||||
func ApplyHistoryVisibilityFilter(
|
func ApplyHistoryVisibilityFilter(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
syncDB storage.Database,
|
syncDB storage.DatabaseTransaction,
|
||||||
rsAPI api.SyncRoomserverAPI,
|
rsAPI api.SyncRoomserverAPI,
|
||||||
events []*gomatrixserverlib.HeaderedEvent,
|
events []*gomatrixserverlib.HeaderedEvent,
|
||||||
alwaysIncludeEventIDs map[string]struct{},
|
alwaysIncludeEventIDs map[string]struct{},
|
||||||
|
|
|
@ -318,13 +318,20 @@ func (n *Notifier) GetListener(req types.SyncRequest) UserDeviceStreamListener {
|
||||||
func (n *Notifier) Load(ctx context.Context, db storage.Database) error {
|
func (n *Notifier) Load(ctx context.Context, db storage.Database) error {
|
||||||
n.lock.Lock()
|
n.lock.Lock()
|
||||||
defer n.lock.Unlock()
|
defer n.lock.Unlock()
|
||||||
roomToUsers, err := db.AllJoinedUsersInRooms(ctx)
|
|
||||||
|
snapshot, err := db.NewDatabaseSnapshot(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer snapshot.Rollback() // nolint:errcheck
|
||||||
|
|
||||||
|
roomToUsers, err := snapshot.AllJoinedUsersInRooms(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
n.setUsersJoinedToRooms(roomToUsers)
|
n.setUsersJoinedToRooms(roomToUsers)
|
||||||
|
|
||||||
roomToPeekingDevices, err := db.AllPeekingDevicesInRooms(ctx)
|
roomToPeekingDevices, err := snapshot.AllPeekingDevicesInRooms(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -338,7 +345,13 @@ func (n *Notifier) LoadRooms(ctx context.Context, db storage.Database, roomIDs [
|
||||||
n.lock.Lock()
|
n.lock.Lock()
|
||||||
defer n.lock.Unlock()
|
defer n.lock.Unlock()
|
||||||
|
|
||||||
roomToUsers, err := db.AllJoinedUsersInRoom(ctx, roomIDs)
|
snapshot, err := db.NewDatabaseSnapshot(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer snapshot.Rollback() // nolint:errcheck
|
||||||
|
|
||||||
|
roomToUsers, err := snapshot.AllJoinedUsersInRoom(ctx, roomIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -51,6 +51,12 @@ func Context(
|
||||||
roomID, eventID string,
|
roomID, eventID string,
|
||||||
lazyLoadCache caching.LazyLoadCache,
|
lazyLoadCache caching.LazyLoadCache,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
snapshot, err := syncDB.NewDatabaseSnapshot(req.Context())
|
||||||
|
if err != nil {
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
|
defer snapshot.Rollback() // nolint:errcheck
|
||||||
|
|
||||||
filter, err := parseRoomEventFilter(req)
|
filter, err := parseRoomEventFilter(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := ""
|
errMsg := ""
|
||||||
|
@ -97,7 +103,7 @@ func Context(
|
||||||
ContainsURL: filter.ContainsURL,
|
ContainsURL: filter.ContainsURL,
|
||||||
}
|
}
|
||||||
|
|
||||||
id, requestedEvent, err := syncDB.SelectContextEvent(ctx, roomID, eventID)
|
id, requestedEvent, err := snapshot.SelectContextEvent(ctx, roomID, eventID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
@ -111,7 +117,7 @@ func Context(
|
||||||
|
|
||||||
// verify the user is allowed to see the context for this room/event
|
// verify the user is allowed to see the context for this room/event
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, []*gomatrixserverlib.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context")
|
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, []*gomatrixserverlib.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Error("unable to apply history visibility filter")
|
logrus.WithError(err).Error("unable to apply history visibility filter")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
|
@ -127,20 +133,20 @@ func Context(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
eventsBefore, err := syncDB.SelectContextBeforeEvent(ctx, id, roomID, filter)
|
eventsBefore, err := snapshot.SelectContextBeforeEvent(ctx, id, roomID, filter)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
logrus.WithError(err).Error("unable to fetch before events")
|
logrus.WithError(err).Error("unable to fetch before events")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
_, eventsAfter, err := syncDB.SelectContextAfterEvent(ctx, id, roomID, filter)
|
_, eventsAfter, err := snapshot.SelectContextAfterEvent(ctx, id, roomID, filter)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
logrus.WithError(err).Error("unable to fetch after events")
|
logrus.WithError(err).Error("unable to fetch after events")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
startTime = time.Now()
|
startTime = time.Now()
|
||||||
eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, syncDB, rsAPI, eventsBefore, eventsAfter, device.UserID)
|
eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, snapshot, rsAPI, eventsBefore, eventsAfter, device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Error("unable to apply history visibility filter")
|
logrus.WithError(err).Error("unable to apply history visibility filter")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
|
@ -152,7 +158,7 @@ func Context(
|
||||||
}).Debug("applied history visibility (context eventsBefore/eventsAfter)")
|
}).Debug("applied history visibility (context eventsBefore/eventsAfter)")
|
||||||
|
|
||||||
// TODO: Get the actual state at the last event returned by SelectContextAfterEvent
|
// TODO: Get the actual state at the last event returned by SelectContextAfterEvent
|
||||||
state, err := syncDB.CurrentState(ctx, roomID, &stateFilter, nil)
|
state, err := snapshot.CurrentState(ctx, roomID, &stateFilter, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Error("unable to fetch current room state")
|
logrus.WithError(err).Error("unable to fetch current room state")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
|
@ -173,7 +179,7 @@ func Context(
|
||||||
if len(response.State) > filter.Limit {
|
if len(response.State) > filter.Limit {
|
||||||
response.State = response.State[len(response.State)-filter.Limit:]
|
response.State = response.State[len(response.State)-filter.Limit:]
|
||||||
}
|
}
|
||||||
start, end, err := getStartEnd(ctx, syncDB, eventsBefore, eventsAfter)
|
start, end, err := getStartEnd(ctx, snapshot, eventsBefore, eventsAfter)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
response.End = end.String()
|
response.End = end.String()
|
||||||
response.Start = start.String()
|
response.Start = start.String()
|
||||||
|
@ -188,7 +194,7 @@ func Context(
|
||||||
// by combining the events before and after the context event. Returns the filtered events,
|
// by combining the events before and after the context event. Returns the filtered events,
|
||||||
// and an error, if any.
|
// and an error, if any.
|
||||||
func applyHistoryVisibilityOnContextEvents(
|
func applyHistoryVisibilityOnContextEvents(
|
||||||
ctx context.Context, syncDB storage.Database, rsAPI roomserver.SyncRoomserverAPI,
|
ctx context.Context, snapshot storage.DatabaseTransaction, rsAPI roomserver.SyncRoomserverAPI,
|
||||||
eventsBefore, eventsAfter []*gomatrixserverlib.HeaderedEvent,
|
eventsBefore, eventsAfter []*gomatrixserverlib.HeaderedEvent,
|
||||||
userID string,
|
userID string,
|
||||||
) (filteredBefore, filteredAfter []*gomatrixserverlib.HeaderedEvent, err error) {
|
) (filteredBefore, filteredAfter []*gomatrixserverlib.HeaderedEvent, err error) {
|
||||||
|
@ -205,7 +211,7 @@ func applyHistoryVisibilityOnContextEvents(
|
||||||
}
|
}
|
||||||
|
|
||||||
allEvents := append(eventsBefore, eventsAfter...)
|
allEvents := append(eventsBefore, eventsAfter...)
|
||||||
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, allEvents, nil, userID, "context")
|
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, allEvents, nil, userID, "context")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -222,15 +228,15 @@ func applyHistoryVisibilityOnContextEvents(
|
||||||
return filteredBefore, filteredAfter, nil
|
return filteredBefore, filteredAfter, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getStartEnd(ctx context.Context, syncDB storage.Database, startEvents, endEvents []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) {
|
func getStartEnd(ctx context.Context, snapshot storage.DatabaseTransaction, startEvents, endEvents []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) {
|
||||||
if len(startEvents) > 0 {
|
if len(startEvents) > 0 {
|
||||||
start, err = syncDB.EventPositionInTopology(ctx, startEvents[0].EventID())
|
start, err = snapshot.EventPositionInTopology(ctx, startEvents[0].EventID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(endEvents) > 0 {
|
if len(endEvents) > 0 {
|
||||||
end, err = syncDB.EventPositionInTopology(ctx, endEvents[0].EventID())
|
end, err = snapshot.EventPositionInTopology(ctx, endEvents[0].EventID())
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,6 +27,7 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/syncapi/internal"
|
"github.com/matrix-org/dendrite/syncapi/internal"
|
||||||
|
@ -39,6 +40,7 @@ import (
|
||||||
type messagesReq struct {
|
type messagesReq struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
db storage.Database
|
db storage.Database
|
||||||
|
snapshot storage.DatabaseTransaction
|
||||||
rsAPI api.SyncRoomserverAPI
|
rsAPI api.SyncRoomserverAPI
|
||||||
cfg *config.SyncAPI
|
cfg *config.SyncAPI
|
||||||
roomID string
|
roomID string
|
||||||
|
@ -70,6 +72,16 @@ func OnIncomingMessagesRequest(
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
// NewDatabaseTransaction is used here instead of NewDatabaseSnapshot as we
|
||||||
|
// expect to be able to write to the database in response to a /messages
|
||||||
|
// request that requires backfilling from the roomserver or federation.
|
||||||
|
snapshot, err := db.NewDatabaseTransaction(req.Context())
|
||||||
|
if err != nil {
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
|
var succeeded bool
|
||||||
|
defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err)
|
||||||
|
|
||||||
// check if the user has already forgotten about this room
|
// check if the user has already forgotten about this room
|
||||||
isForgotten, roomExists, err := checkIsRoomForgotten(req.Context(), roomID, device.UserID, rsAPI)
|
isForgotten, roomExists, err := checkIsRoomForgotten(req.Context(), roomID, device.UserID, rsAPI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -132,7 +144,7 @@ func OnIncomingMessagesRequest(
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
fromStream = &streamToken
|
fromStream = &streamToken
|
||||||
from, err = db.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, backwardOrdering)
|
from, err = snapshot.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, backwardOrdering)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken)
|
logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken)
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
|
@ -154,7 +166,7 @@ func OnIncomingMessagesRequest(
|
||||||
JSON: jsonerror.InvalidArgumentValue("Invalid to parameter: " + err.Error()),
|
JSON: jsonerror.InvalidArgumentValue("Invalid to parameter: " + err.Error()),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
to, err = db.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, !backwardOrdering)
|
to, err = snapshot.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, !backwardOrdering)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken)
|
logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken)
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
|
@ -165,7 +177,7 @@ func OnIncomingMessagesRequest(
|
||||||
// If "to" isn't provided, it defaults to either the earliest stream
|
// If "to" isn't provided, it defaults to either the earliest stream
|
||||||
// position (if we're going backward) or to the latest one (if we're
|
// position (if we're going backward) or to the latest one (if we're
|
||||||
// going forward).
|
// going forward).
|
||||||
to, err = setToDefault(req.Context(), db, backwardOrdering, roomID)
|
to, err = setToDefault(req.Context(), snapshot, backwardOrdering, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("setToDefault failed")
|
util.GetLogger(req.Context()).WithError(err).Error("setToDefault failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
|
@ -186,6 +198,7 @@ func OnIncomingMessagesRequest(
|
||||||
mReq := messagesReq{
|
mReq := messagesReq{
|
||||||
ctx: req.Context(),
|
ctx: req.Context(),
|
||||||
db: db,
|
db: db,
|
||||||
|
snapshot: snapshot,
|
||||||
rsAPI: rsAPI,
|
rsAPI: rsAPI,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
roomID: roomID,
|
roomID: roomID,
|
||||||
|
@ -217,7 +230,7 @@ func OnIncomingMessagesRequest(
|
||||||
Start: start.String(),
|
Start: start.String(),
|
||||||
End: end.String(),
|
End: end.String(),
|
||||||
}
|
}
|
||||||
res.applyLazyLoadMembers(req.Context(), db, roomID, device, filter.LazyLoadMembers, lazyLoadCache)
|
res.applyLazyLoadMembers(req.Context(), snapshot, roomID, device, filter.LazyLoadMembers, lazyLoadCache)
|
||||||
|
|
||||||
// If we didn't return any events, set the end to an empty string, so it will be omitted
|
// If we didn't return any events, set the end to an empty string, so it will be omitted
|
||||||
// in the response JSON.
|
// in the response JSON.
|
||||||
|
@ -229,6 +242,7 @@ func OnIncomingMessagesRequest(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Respond with the events.
|
// Respond with the events.
|
||||||
|
succeeded = true
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
JSON: res,
|
JSON: res,
|
||||||
|
@ -239,7 +253,7 @@ func OnIncomingMessagesRequest(
|
||||||
// LazyLoadMembers enabled.
|
// LazyLoadMembers enabled.
|
||||||
func (m *messagesResp) applyLazyLoadMembers(
|
func (m *messagesResp) applyLazyLoadMembers(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
db storage.Database,
|
db storage.DatabaseTransaction,
|
||||||
roomID string,
|
roomID string,
|
||||||
device *userapi.Device,
|
device *userapi.Device,
|
||||||
lazyLoad bool,
|
lazyLoad bool,
|
||||||
|
@ -292,7 +306,7 @@ func (r *messagesReq) retrieveEvents() (
|
||||||
end types.TopologyToken, err error,
|
end types.TopologyToken, err error,
|
||||||
) {
|
) {
|
||||||
// Retrieve the events from the local database.
|
// Retrieve the events from the local database.
|
||||||
streamEvents, err := r.db.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering)
|
streamEvents, err := r.snapshot.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("GetEventsInRange: %w", err)
|
err = fmt.Errorf("GetEventsInRange: %w", err)
|
||||||
return
|
return
|
||||||
|
@ -348,7 +362,7 @@ func (r *messagesReq) retrieveEvents() (
|
||||||
|
|
||||||
// Apply room history visibility filter
|
// Apply room history visibility filter
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.db, r.rsAPI, events, nil, r.device.UserID, "messages")
|
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.snapshot, r.rsAPI, events, nil, r.device.UserID, "messages")
|
||||||
logrus.WithFields(logrus.Fields{
|
logrus.WithFields(logrus.Fields{
|
||||||
"duration": time.Since(startTime),
|
"duration": time.Since(startTime),
|
||||||
"room_id": r.roomID,
|
"room_id": r.roomID,
|
||||||
|
@ -366,7 +380,7 @@ func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (st
|
||||||
// else to go. This seems to fix Element iOS from looping on /messages endlessly.
|
// else to go. This seems to fix Element iOS from looping on /messages endlessly.
|
||||||
end = types.TopologyToken{}
|
end = types.TopologyToken{}
|
||||||
} else {
|
} else {
|
||||||
end, err = r.db.EventPositionInTopology(
|
end, err = r.snapshot.EventPositionInTopology(
|
||||||
r.ctx, events[0].EventID(),
|
r.ctx, events[0].EventID(),
|
||||||
)
|
)
|
||||||
// A stream/topological position is a cursor located between two events.
|
// A stream/topological position is a cursor located between two events.
|
||||||
|
@ -378,7 +392,7 @@ func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (st
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
start = *r.from
|
start = *r.from
|
||||||
end, err = r.db.EventPositionInTopology(
|
end, err = r.snapshot.EventPositionInTopology(
|
||||||
r.ctx, events[len(events)-1].EventID(),
|
r.ctx, events[len(events)-1].EventID(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -399,7 +413,7 @@ func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (st
|
||||||
func (r *messagesReq) handleEmptyEventsSlice() (
|
func (r *messagesReq) handleEmptyEventsSlice() (
|
||||||
events []*gomatrixserverlib.HeaderedEvent, err error,
|
events []*gomatrixserverlib.HeaderedEvent, err error,
|
||||||
) {
|
) {
|
||||||
backwardExtremities, err := r.db.BackwardExtremitiesForRoom(r.ctx, r.roomID)
|
backwardExtremities, err := r.snapshot.BackwardExtremitiesForRoom(r.ctx, r.roomID)
|
||||||
|
|
||||||
// Check if we have backward extremities for this room.
|
// Check if we have backward extremities for this room.
|
||||||
if len(backwardExtremities) > 0 {
|
if len(backwardExtremities) > 0 {
|
||||||
|
@ -443,7 +457,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the slice contains a backward extremity.
|
// Check if the slice contains a backward extremity.
|
||||||
backwardExtremities, err := r.db.BackwardExtremitiesForRoom(r.ctx, r.roomID)
|
backwardExtremities, err := r.snapshot.BackwardExtremitiesForRoom(r.ctx, r.roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -463,7 +477,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
|
||||||
}
|
}
|
||||||
|
|
||||||
// Append the events ve previously retrieved locally.
|
// Append the events ve previously retrieved locally.
|
||||||
events = append(events, r.db.StreamEventsToEvents(nil, streamEvents)...)
|
events = append(events, r.snapshot.StreamEventsToEvents(nil, streamEvents)...)
|
||||||
sort.Sort(eventsByDepth(events))
|
sort.Sort(eventsByDepth(events))
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -553,7 +567,7 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]
|
||||||
// Returns an error if there was an issue with retrieving the latest position
|
// Returns an error if there was an issue with retrieving the latest position
|
||||||
// from the database
|
// from the database
|
||||||
func setToDefault(
|
func setToDefault(
|
||||||
ctx context.Context, db storage.Database, backwardOrdering bool,
|
ctx context.Context, snapshot storage.DatabaseTransaction, backwardOrdering bool,
|
||||||
roomID string,
|
roomID string,
|
||||||
) (to types.TopologyToken, err error) {
|
) (to types.TopologyToken, err error) {
|
||||||
if backwardOrdering {
|
if backwardOrdering {
|
||||||
|
@ -561,7 +575,7 @@ func setToDefault(
|
||||||
// this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound.
|
// this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound.
|
||||||
to = types.TopologyToken{}
|
to = types.TopologyToken{}
|
||||||
} else {
|
} else {
|
||||||
to, err = db.MaxTopologicalPosition(ctx, roomID)
|
to, err = snapshot.MaxTopologicalPosition(ctx, roomID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
|
@ -61,8 +61,14 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
|
||||||
searchReq.SearchCategories.RoomEvents.Filter.Limit = 5
|
searchReq.SearchCategories.RoomEvents.Filter.Limit = 5
|
||||||
}
|
}
|
||||||
|
|
||||||
|
snapshot, err := syncDB.NewDatabaseSnapshot(req.Context())
|
||||||
|
if err != nil {
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
|
defer snapshot.Rollback() // nolint:errcheck
|
||||||
|
|
||||||
// only search rooms the user is actually joined to
|
// only search rooms the user is actually joined to
|
||||||
joinedRooms, err := syncDB.RoomIDsWithMembership(ctx, device.UserID, "join")
|
joinedRooms, err := snapshot.RoomIDsWithMembership(ctx, device.UserID, "join")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
@ -161,12 +167,12 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
|
||||||
|
|
||||||
stateForRooms := make(map[string][]gomatrixserverlib.ClientEvent)
|
stateForRooms := make(map[string][]gomatrixserverlib.ClientEvent)
|
||||||
for _, event := range evs {
|
for _, event := range evs {
|
||||||
eventsBefore, eventsAfter, err := contextEvents(ctx, syncDB, event, roomFilter, searchReq)
|
eventsBefore, eventsAfter, err := contextEvents(ctx, snapshot, event, roomFilter, searchReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Error("failed to get context events")
|
logrus.WithError(err).Error("failed to get context events")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
startToken, endToken, err := getStartEnd(ctx, syncDB, eventsBefore, eventsAfter)
|
startToken, endToken, err := getStartEnd(ctx, snapshot, eventsBefore, eventsAfter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Error("failed to get start/end")
|
logrus.WithError(err).Error("failed to get start/end")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
|
@ -176,7 +182,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
|
||||||
for _, ev := range append(eventsBefore, eventsAfter...) {
|
for _, ev := range append(eventsBefore, eventsAfter...) {
|
||||||
profile, ok := knownUsersProfiles[event.Sender()]
|
profile, ok := knownUsersProfiles[event.Sender()]
|
||||||
if !ok {
|
if !ok {
|
||||||
stateEvent, err := syncDB.GetStateEvent(ctx, ev.RoomID(), gomatrixserverlib.MRoomMember, ev.Sender())
|
stateEvent, err := snapshot.GetStateEvent(ctx, ev.RoomID(), gomatrixserverlib.MRoomMember, ev.Sender())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).WithField("user_id", event.Sender()).Warn("failed to query userprofile")
|
logrus.WithError(err).WithField("user_id", event.Sender()).Warn("failed to query userprofile")
|
||||||
continue
|
continue
|
||||||
|
@ -209,7 +215,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
|
||||||
groups[event.RoomID()] = roomGroup
|
groups[event.RoomID()] = roomGroup
|
||||||
if _, ok := stateForRooms[event.RoomID()]; searchReq.SearchCategories.RoomEvents.IncludeState && !ok {
|
if _, ok := stateForRooms[event.RoomID()]; searchReq.SearchCategories.RoomEvents.IncludeState && !ok {
|
||||||
stateFilter := gomatrixserverlib.DefaultStateFilter()
|
stateFilter := gomatrixserverlib.DefaultStateFilter()
|
||||||
state, err := syncDB.CurrentState(ctx, event.RoomID(), &stateFilter, nil)
|
state, err := snapshot.CurrentState(ctx, event.RoomID(), &stateFilter, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Error("unable to get current state")
|
logrus.WithError(err).Error("unable to get current state")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
|
@ -252,24 +258,24 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
|
||||||
// contextEvents returns the events around a given eventID
|
// contextEvents returns the events around a given eventID
|
||||||
func contextEvents(
|
func contextEvents(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
syncDB storage.Database,
|
snapshot storage.DatabaseTransaction,
|
||||||
event *gomatrixserverlib.HeaderedEvent,
|
event *gomatrixserverlib.HeaderedEvent,
|
||||||
roomFilter *gomatrixserverlib.RoomEventFilter,
|
roomFilter *gomatrixserverlib.RoomEventFilter,
|
||||||
searchReq SearchRequest,
|
searchReq SearchRequest,
|
||||||
) ([]*gomatrixserverlib.HeaderedEvent, []*gomatrixserverlib.HeaderedEvent, error) {
|
) ([]*gomatrixserverlib.HeaderedEvent, []*gomatrixserverlib.HeaderedEvent, error) {
|
||||||
id, _, err := syncDB.SelectContextEvent(ctx, event.RoomID(), event.EventID())
|
id, _, err := snapshot.SelectContextEvent(ctx, event.RoomID(), event.EventID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Error("failed to query context event")
|
logrus.WithError(err).Error("failed to query context event")
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.BeforeLimit
|
roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.BeforeLimit
|
||||||
eventsBefore, err := syncDB.SelectContextBeforeEvent(ctx, id, event.RoomID(), roomFilter)
|
eventsBefore, err := snapshot.SelectContextBeforeEvent(ctx, id, event.RoomID(), roomFilter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Error("failed to query before context event")
|
logrus.WithError(err).Error("failed to query before context event")
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.AfterLimit
|
roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.AfterLimit
|
||||||
_, eventsAfter, err := syncDB.SelectContextAfterEvent(ctx, id, event.RoomID(), roomFilter)
|
_, eventsAfter, err := snapshot.SelectContextAfterEvent(ctx, id, event.RoomID(), roomFilter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Error("failed to query after context event")
|
logrus.WithError(err).Error("failed to query after context event")
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
|
|
@ -17,19 +17,17 @@ package storage
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Database interface {
|
type DatabaseTransaction interface {
|
||||||
Presence
|
|
||||||
SharedUsers
|
SharedUsers
|
||||||
Notifications
|
|
||||||
|
|
||||||
MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error)
|
MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error)
|
||||||
MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error)
|
MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error)
|
||||||
|
@ -37,6 +35,7 @@ type Database interface {
|
||||||
MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error)
|
MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error)
|
||||||
MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error)
|
MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error)
|
||||||
MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error)
|
MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error)
|
||||||
|
MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error)
|
||||||
|
|
||||||
CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error)
|
CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error)
|
||||||
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
|
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
|
||||||
|
@ -44,23 +43,77 @@ type Database interface {
|
||||||
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
|
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
|
||||||
MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error)
|
MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error)
|
||||||
GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error)
|
GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error)
|
||||||
|
|
||||||
RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
|
RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
|
||||||
|
|
||||||
GetBackwardTopologyPos(ctx context.Context, events []types.StreamEvent) (types.TopologyToken, error)
|
GetBackwardTopologyPos(ctx context.Context, events []types.StreamEvent) (types.TopologyToken, error)
|
||||||
PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error)
|
PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error)
|
||||||
|
InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error)
|
||||||
InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error)
|
|
||||||
PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error)
|
PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error)
|
||||||
RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error)
|
RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error)
|
||||||
|
|
||||||
// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs.
|
// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs.
|
||||||
AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error)
|
AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error)
|
||||||
// AllJoinedUsersInRoom returns a map of room ID to a list of all joined user IDs for a given room.
|
// AllJoinedUsersInRoom returns a map of room ID to a list of all joined user IDs for a given room.
|
||||||
AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error)
|
AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error)
|
||||||
|
|
||||||
// AllPeekingDevicesInRooms returns a map of room ID to a list of all peeking devices.
|
// AllPeekingDevicesInRooms returns a map of room ID to a list of all peeking devices.
|
||||||
AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error)
|
AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error)
|
||||||
|
// Events lookups a list of event by their event ID.
|
||||||
|
// Returns a list of events matching the requested IDs found in the database.
|
||||||
|
// If an event is not found in the database then it will be omitted from the list.
|
||||||
|
// Returns an error if there was a problem talking with the database.
|
||||||
|
// Does not include any transaction IDs in the returned events.
|
||||||
|
Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error)
|
||||||
|
// GetStateEvent returns the Matrix state event of a given type for a given room with a given state key
|
||||||
|
// If no event could be found, returns nil
|
||||||
|
// If there was an issue during the retrieval, returns an error
|
||||||
|
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
|
||||||
|
// GetStateEventsForRoom fetches the state events for a given room.
|
||||||
|
// Returns an empty slice if no state events could be found for this room.
|
||||||
|
// Returns an error if there was an issue with the retrieval.
|
||||||
|
GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error)
|
||||||
|
// GetAccountDataInRange returns all account data for a given user inserted or
|
||||||
|
// updated between two given positions
|
||||||
|
// Returns a map following the format data[roomID] = []dataTypes
|
||||||
|
// If no data is retrieved, returns an empty map
|
||||||
|
// If there was an issue with the retrieval, returns an error
|
||||||
|
GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, types.StreamPosition, error)
|
||||||
|
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last.
|
||||||
|
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error)
|
||||||
|
// EventPositionInTopology returns the depth and stream position of the given event.
|
||||||
|
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
|
||||||
|
// BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events.
|
||||||
|
BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities map[string][]string, err error)
|
||||||
|
// MaxTopologicalPosition returns the highest topological position for a given room.
|
||||||
|
MaxTopologicalPosition(ctx context.Context, roomID string) (types.TopologyToken, error)
|
||||||
|
// StreamEventsToEvents converts streamEvent to Event. If device is non-nil and
|
||||||
|
// matches the streamevent.transactionID device then the transaction ID gets
|
||||||
|
// added to the unsigned section of the output event.
|
||||||
|
StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent
|
||||||
|
// SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the
|
||||||
|
// relevant events within the given ranges for the supplied user ID and device ID.
|
||||||
|
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error)
|
||||||
|
// GetRoomReceipts gets all receipts for a given roomID
|
||||||
|
GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error)
|
||||||
|
SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
|
||||||
|
SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
|
||||||
|
SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)
|
||||||
|
StreamToTopologicalPosition(ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool) (types.TopologyToken, error)
|
||||||
|
IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error)
|
||||||
|
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
|
||||||
|
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
|
||||||
|
// string as the membership.
|
||||||
|
SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
|
||||||
|
// getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms
|
||||||
|
GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error)
|
||||||
|
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
|
||||||
|
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Database interface {
|
||||||
|
Presence
|
||||||
|
Notifications
|
||||||
|
|
||||||
|
NewDatabaseSnapshot(ctx context.Context) (*shared.DatabaseTransaction, error)
|
||||||
|
NewDatabaseTransaction(ctx context.Context) (*shared.DatabaseTransaction, error)
|
||||||
|
|
||||||
// Events lookups a list of event by their event ID.
|
// Events lookups a list of event by their event ID.
|
||||||
// Returns a list of events matching the requested IDs found in the database.
|
// Returns a list of events matching the requested IDs found in the database.
|
||||||
// If an event is not found in the database then it will be omitted from the list.
|
// If an event is not found in the database then it will be omitted from the list.
|
||||||
|
@ -77,20 +130,6 @@ type Database interface {
|
||||||
// PurgeRoomState completely purges room state from the sync API. This is done when
|
// PurgeRoomState completely purges room state from the sync API. This is done when
|
||||||
// receiving an output event that completely resets the state.
|
// receiving an output event that completely resets the state.
|
||||||
PurgeRoomState(ctx context.Context, roomID string) error
|
PurgeRoomState(ctx context.Context, roomID string) error
|
||||||
// GetStateEvent returns the Matrix state event of a given type for a given room with a given state key
|
|
||||||
// If no event could be found, returns nil
|
|
||||||
// If there was an issue during the retrieval, returns an error
|
|
||||||
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
|
|
||||||
// GetStateEventsForRoom fetches the state events for a given room.
|
|
||||||
// Returns an empty slice if no state events could be found for this room.
|
|
||||||
// Returns an error if there was an issue with the retrieval.
|
|
||||||
GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error)
|
|
||||||
// GetAccountDataInRange returns all account data for a given user inserted or
|
|
||||||
// updated between two given positions
|
|
||||||
// Returns a map following the format data[roomID] = []dataTypes
|
|
||||||
// If no data is retrieved, returns an empty map
|
|
||||||
// If there was an issue with the retrieval, returns an error
|
|
||||||
GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, types.StreamPosition, error)
|
|
||||||
// UpsertAccountData keeps track of new or updated account data, by saving the type
|
// UpsertAccountData keeps track of new or updated account data, by saving the type
|
||||||
// of the new/updated data, and the user ID and room ID the data is related to (empty)
|
// of the new/updated data, and the user ID and room ID the data is related to (empty)
|
||||||
// room ID means the data isn't specific to any room)
|
// room ID means the data isn't specific to any room)
|
||||||
|
@ -114,21 +153,6 @@ type Database interface {
|
||||||
// DeletePeek deletes all peeks for a given room by a given user
|
// DeletePeek deletes all peeks for a given room by a given user
|
||||||
// Returns an error if there was a problem communicating with the database.
|
// Returns an error if there was a problem communicating with the database.
|
||||||
DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
|
DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
|
||||||
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last.
|
|
||||||
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error)
|
|
||||||
// EventPositionInTopology returns the depth and stream position of the given event.
|
|
||||||
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
|
|
||||||
// BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events.
|
|
||||||
BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities map[string][]string, err error)
|
|
||||||
// MaxTopologicalPosition returns the highest topological position for a given room.
|
|
||||||
MaxTopologicalPosition(ctx context.Context, roomID string) (types.TopologyToken, error)
|
|
||||||
// StreamEventsToEvents converts streamEvent to Event. If device is non-nil and
|
|
||||||
// matches the streamevent.transactionID device then the transaction ID gets
|
|
||||||
// added to the unsigned section of the output event.
|
|
||||||
StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent
|
|
||||||
// SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the
|
|
||||||
// relevant events within the given ranges for the supplied user ID and device ID.
|
|
||||||
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error)
|
|
||||||
// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
|
// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
|
||||||
StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
|
StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
|
||||||
// CleanSendToDeviceUpdates removes all send-to-device messages BEFORE the specified
|
// CleanSendToDeviceUpdates removes all send-to-device messages BEFORE the specified
|
||||||
|
@ -146,29 +170,13 @@ type Database interface {
|
||||||
RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error
|
RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error
|
||||||
// StoreReceipt stores new receipt events
|
// StoreReceipt stores new receipt events
|
||||||
StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error)
|
StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error)
|
||||||
// GetRoomReceipts gets all receipts for a given roomID
|
|
||||||
GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error)
|
|
||||||
|
|
||||||
SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
|
|
||||||
SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
|
|
||||||
SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)
|
|
||||||
|
|
||||||
StreamToTopologicalPosition(ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool) (types.TopologyToken, error)
|
|
||||||
|
|
||||||
IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error)
|
|
||||||
UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error
|
UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error
|
||||||
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
|
|
||||||
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
|
|
||||||
// string as the membership.
|
|
||||||
SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
|
|
||||||
ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error)
|
ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Presence interface {
|
type Presence interface {
|
||||||
UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error)
|
|
||||||
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
|
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
|
||||||
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
|
UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error)
|
||||||
MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type SharedUsers interface {
|
type SharedUsers interface {
|
||||||
|
@ -179,7 +187,4 @@ type SharedUsers interface {
|
||||||
type Notifications interface {
|
type Notifications interface {
|
||||||
// UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key.
|
// UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key.
|
||||||
UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
|
UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
|
||||||
|
|
||||||
// getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms
|
|
||||||
GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -55,7 +55,7 @@ const deleteInviteEventSQL = "" +
|
||||||
"UPDATE syncapi_invite_events SET deleted=TRUE, id=nextval('syncapi_stream_id') WHERE event_id = $1 AND deleted=FALSE RETURNING id"
|
"UPDATE syncapi_invite_events SET deleted=TRUE, id=nextval('syncapi_stream_id') WHERE event_id = $1 AND deleted=FALSE RETURNING id"
|
||||||
|
|
||||||
const selectInviteEventsInRangeSQL = "" +
|
const selectInviteEventsInRangeSQL = "" +
|
||||||
"SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" +
|
"SELECT id, room_id, headered_event_json, deleted FROM syncapi_invite_events" +
|
||||||
" WHERE target_user_id = $1 AND id > $2 AND id <= $3" +
|
" WHERE target_user_id = $1 AND id > $2 AND id <= $3" +
|
||||||
" ORDER BY id DESC"
|
" ORDER BY id DESC"
|
||||||
|
|
||||||
|
@ -121,23 +121,28 @@ func (s *inviteEventsStatements) DeleteInviteEvent(
|
||||||
// active invites for the target user ID in the supplied range.
|
// active invites for the target user ID in the supplied range.
|
||||||
func (s *inviteEventsStatements) SelectInviteEventsInRange(
|
func (s *inviteEventsStatements) SelectInviteEventsInRange(
|
||||||
ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range,
|
ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range,
|
||||||
) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) {
|
) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) {
|
||||||
|
var lastPos types.StreamPosition
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt)
|
stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt)
|
||||||
rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High())
|
rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, lastPos, err
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed")
|
||||||
result := map[string]*gomatrixserverlib.HeaderedEvent{}
|
result := map[string]*gomatrixserverlib.HeaderedEvent{}
|
||||||
retired := map[string]*gomatrixserverlib.HeaderedEvent{}
|
retired := map[string]*gomatrixserverlib.HeaderedEvent{}
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var (
|
var (
|
||||||
|
id types.StreamPosition
|
||||||
roomID string
|
roomID string
|
||||||
eventJSON []byte
|
eventJSON []byte
|
||||||
deleted bool
|
deleted bool
|
||||||
)
|
)
|
||||||
if err = rows.Scan(&roomID, &eventJSON, &deleted); err != nil {
|
if err = rows.Scan(&id, &roomID, &eventJSON, &deleted); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, lastPos, err
|
||||||
|
}
|
||||||
|
if id > lastPos {
|
||||||
|
lastPos = id
|
||||||
}
|
}
|
||||||
|
|
||||||
// if we have seen this room before, it has a higher stream position and hence takes priority
|
// if we have seen this room before, it has a higher stream position and hence takes priority
|
||||||
|
@ -150,7 +155,7 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange(
|
||||||
|
|
||||||
var event *gomatrixserverlib.HeaderedEvent
|
var event *gomatrixserverlib.HeaderedEvent
|
||||||
if err := json.Unmarshal(eventJSON, &event); err != nil {
|
if err := json.Unmarshal(eventJSON, &event); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, lastPos, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if deleted {
|
if deleted {
|
||||||
|
@ -159,7 +164,10 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange(
|
||||||
result[roomID] = event
|
result[roomID] = event
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result, retired, rows.Err()
|
if lastPos == 0 {
|
||||||
|
lastPos = r.To
|
||||||
|
}
|
||||||
|
return result, retired, lastPos, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *inviteEventsStatements) SelectMaxInviteID(
|
func (s *inviteEventsStatements) SelectMaxInviteID(
|
||||||
|
|
586
syncapi/storage/shared/storage_consumer.go
Normal file
586
syncapi/storage/shared/storage_consumer.go
Normal file
|
@ -0,0 +1,586 @@
|
||||||
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package shared
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite
|
||||||
|
// For now this contains the shared functions
|
||||||
|
type Database struct {
|
||||||
|
DB *sql.DB
|
||||||
|
Writer sqlutil.Writer
|
||||||
|
Invites tables.Invites
|
||||||
|
Peeks tables.Peeks
|
||||||
|
AccountData tables.AccountData
|
||||||
|
OutputEvents tables.Events
|
||||||
|
Topology tables.Topology
|
||||||
|
CurrentRoomState tables.CurrentRoomState
|
||||||
|
BackwardExtremities tables.BackwardsExtremities
|
||||||
|
SendToDevice tables.SendToDevice
|
||||||
|
Filter tables.Filter
|
||||||
|
Receipts tables.Receipts
|
||||||
|
Memberships tables.Memberships
|
||||||
|
NotificationData tables.NotificationData
|
||||||
|
Ignores tables.Ignores
|
||||||
|
Presence tables.Presence
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) NewDatabaseSnapshot(ctx context.Context) (*DatabaseTransaction, error) {
|
||||||
|
return d.NewDatabaseTransaction(ctx)
|
||||||
|
|
||||||
|
/*
|
||||||
|
TODO: Repeatable read is probably the right thing to do here,
|
||||||
|
but it seems to cause some problems with the invite tests, so
|
||||||
|
need to investigate that further.
|
||||||
|
|
||||||
|
txn, err := d.DB.BeginTx(ctx, &sql.TxOptions{
|
||||||
|
// Set the isolation level so that we see a snapshot of the database.
|
||||||
|
// In PostgreSQL repeatable read transactions will see a snapshot taken
|
||||||
|
// at the first query, and since the transaction is read-only it can't
|
||||||
|
// run into any serialisation errors.
|
||||||
|
// https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ
|
||||||
|
Isolation: sql.LevelRepeatableRead,
|
||||||
|
ReadOnly: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &DatabaseTransaction{
|
||||||
|
Database: d,
|
||||||
|
txn: txn,
|
||||||
|
}, nil
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) NewDatabaseTransaction(ctx context.Context) (*DatabaseTransaction, error) {
|
||||||
|
txn, err := d.DB.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &DatabaseTransaction{
|
||||||
|
Database: d,
|
||||||
|
txn: txn,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||||
|
streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, nil, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// We don't include a device here as we only include transaction IDs in
|
||||||
|
// incremental syncs.
|
||||||
|
return d.StreamEventsToEvents(nil, streamEvents), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddInviteEvent stores a new invite event for a user.
|
||||||
|
// If the invite was successfully stored this returns the stream ID it was stored at.
|
||||||
|
// Returns an error if there was a problem communicating with the database.
|
||||||
|
func (d *Database) AddInviteEvent(
|
||||||
|
ctx context.Context, inviteEvent *gomatrixserverlib.HeaderedEvent,
|
||||||
|
) (sp types.StreamPosition, err error) {
|
||||||
|
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetireInviteEvent removes an old invite event from the database.
|
||||||
|
// Returns an error if there was a problem communicating with the database.
|
||||||
|
func (d *Database) RetireInviteEvent(
|
||||||
|
ctx context.Context, inviteEventID string,
|
||||||
|
) (sp types.StreamPosition, err error) {
|
||||||
|
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
sp, err = d.Invites.DeleteInviteEvent(ctx, txn, inviteEventID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddPeek tracks the fact that a user has started peeking.
|
||||||
|
// If the peek was successfully stored this returns the stream ID it was stored at.
|
||||||
|
// Returns an error if there was a problem communicating with the database.
|
||||||
|
func (d *Database) AddPeek(
|
||||||
|
ctx context.Context, roomID, userID, deviceID string,
|
||||||
|
) (sp types.StreamPosition, err error) {
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
sp, err = d.Peeks.InsertPeek(ctx, txn, roomID, userID, deviceID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePeek tracks the fact that a user has stopped peeking from the specified
|
||||||
|
// device. If the peeks was successfully deleted this returns the stream ID it was
|
||||||
|
// stored at. Returns an error if there was a problem communicating with the database.
|
||||||
|
func (d *Database) DeletePeek(
|
||||||
|
ctx context.Context, roomID, userID, deviceID string,
|
||||||
|
) (sp types.StreamPosition, err error) {
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
sp, err = d.Peeks.DeletePeek(ctx, txn, roomID, userID, deviceID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
sp = 0
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePeeks tracks the fact that a user has stopped peeking from all devices
|
||||||
|
// If the peeks was successfully deleted this returns the stream ID it was stored at.
|
||||||
|
// Returns an error if there was a problem communicating with the database.
|
||||||
|
func (d *Database) DeletePeeks(
|
||||||
|
ctx context.Context, roomID, userID string,
|
||||||
|
) (sp types.StreamPosition, err error) {
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
sp, err = d.Peeks.DeletePeeks(ctx, txn, roomID, userID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
sp = 0
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpsertAccountData keeps track of new or updated account data, by saving the type
|
||||||
|
// of the new/updated data, and the user ID and room ID the data is related to (empty)
|
||||||
|
// room ID means the data isn't specific to any room)
|
||||||
|
// If no data with the given type, user ID and room ID exists in the database,
|
||||||
|
// creates a new row, else update the existing one
|
||||||
|
// Returns an error if there was an issue with the upsert
|
||||||
|
func (d *Database) UpsertAccountData(
|
||||||
|
ctx context.Context, userID, roomID, dataType string,
|
||||||
|
) (sp types.StreamPosition, err error) {
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent {
|
||||||
|
out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
|
||||||
|
for i := 0; i < len(in); i++ {
|
||||||
|
out[i] = in[i].HeaderedEvent
|
||||||
|
if device != nil && in[i].TransactionID != nil {
|
||||||
|
if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID {
|
||||||
|
err := out[i].SetUnsignedField(
|
||||||
|
"transaction_id", in[i].TransactionID.TransactionID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"event_id": out[i].EventID(),
|
||||||
|
}).WithError(err).Warnf("Failed to add transaction ID to event")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of
|
||||||
|
// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table
|
||||||
|
// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such.
|
||||||
|
// This function should always be called within a sqlutil.Writer for safety in SQLite.
|
||||||
|
func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error {
|
||||||
|
if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we have all of the event's previous events. If an event is
|
||||||
|
// missing, add it to the room's backward extremities.
|
||||||
|
prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), nil, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var found bool
|
||||||
|
for _, eID := range ev.PrevEventIDs() {
|
||||||
|
found = false
|
||||||
|
for _, prevEv := range prevEvents {
|
||||||
|
if eID == prevEv.EventID() {
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the event is missing, consider it a backward extremity.
|
||||||
|
if !found {
|
||||||
|
if err = d.BackwardExtremities.InsertsBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID(), eID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) PurgeRoomState(
|
||||||
|
ctx context.Context, roomID string,
|
||||||
|
) error {
|
||||||
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
// If the event is a create event then we'll delete all of the existing
|
||||||
|
// data for the room. The only reason that a create event would be replayed
|
||||||
|
// to us in this way is if we're about to receive the entire room state.
|
||||||
|
if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil {
|
||||||
|
return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) WriteEvent(
|
||||||
|
ctx context.Context,
|
||||||
|
ev *gomatrixserverlib.HeaderedEvent,
|
||||||
|
addStateEvents []*gomatrixserverlib.HeaderedEvent,
|
||||||
|
addStateEventIDs, removeStateEventIDs []string,
|
||||||
|
transactionID *api.TransactionID, excludeFromSync bool,
|
||||||
|
historyVisibility gomatrixserverlib.HistoryVisibility,
|
||||||
|
) (pduPosition types.StreamPosition, returnErr error) {
|
||||||
|
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
var err error
|
||||||
|
ev.Visibility = historyVisibility
|
||||||
|
pos, err := d.OutputEvents.InsertEvent(
|
||||||
|
ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, historyVisibility,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("d.OutputEvents.InsertEvent: %w", err)
|
||||||
|
}
|
||||||
|
pduPosition = pos
|
||||||
|
var topoPosition types.StreamPosition
|
||||||
|
if topoPosition, err = d.Topology.InsertEventInTopology(ctx, txn, ev, pos); err != nil {
|
||||||
|
return fmt.Errorf("d.Topology.InsertEventInTopology: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil {
|
||||||
|
return fmt.Errorf("d.handleBackwardExtremities: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 {
|
||||||
|
// Nothing to do, the event may have just been a message event.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for i := range addStateEvents {
|
||||||
|
addStateEvents[i].Visibility = historyVisibility
|
||||||
|
}
|
||||||
|
return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition, topoPosition)
|
||||||
|
})
|
||||||
|
|
||||||
|
return pduPosition, returnErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// This function should always be called within a sqlutil.Writer for safety in SQLite.
|
||||||
|
func (d *Database) updateRoomState(
|
||||||
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
removedEventIDs []string,
|
||||||
|
addedEvents []*gomatrixserverlib.HeaderedEvent,
|
||||||
|
pduPosition types.StreamPosition,
|
||||||
|
topoPosition types.StreamPosition,
|
||||||
|
) error {
|
||||||
|
// remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
|
||||||
|
for _, eventID := range removedEventIDs {
|
||||||
|
if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil {
|
||||||
|
return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateByEventID: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, event := range addedEvents {
|
||||||
|
if event.StateKey() == nil {
|
||||||
|
// ignore non state events
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var membership *string
|
||||||
|
if event.Type() == "m.room.member" {
|
||||||
|
value, err := event.Membership()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("event.Membership: %w", err)
|
||||||
|
}
|
||||||
|
membership = &value
|
||||||
|
if err = d.Memberships.UpsertMembership(ctx, txn, event, pduPosition, topoPosition); err != nil {
|
||||||
|
return fmt.Errorf("d.Memberships.UpsertMembership: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil {
|
||||||
|
return fmt.Errorf("d.CurrentRoomState.UpsertRoomState: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetFilter(
|
||||||
|
ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string,
|
||||||
|
) error {
|
||||||
|
return d.Filter.SelectFilter(ctx, nil, target, localpart, filterID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) PutFilter(
|
||||||
|
ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
|
||||||
|
) (string, error) {
|
||||||
|
var filterID string
|
||||||
|
var err error
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
filterID, err = d.Filter.InsertFilter(ctx, txn, filter, localpart)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return filterID, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error {
|
||||||
|
redactedEvents, err := d.Events(ctx, []string{redactedEventID})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(redactedEvents) == 0 {
|
||||||
|
logrus.WithField("event_id", redactedEventID).WithField("redaction_event", redactedBecause.EventID()).Warnf("missing redacted event for redaction")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
eventToRedact := redactedEvents[0].Unwrap()
|
||||||
|
redactionEvent := redactedBecause.Unwrap()
|
||||||
|
if err = eventutil.RedactEvent(redactionEvent, eventToRedact); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
newEvent := eventToRedact.Headered(redactedBecause.RoomVersion)
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.OutputEvents.UpdateEventJSON(ctx, txn, newEvent)
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database.
|
||||||
|
// Returns a map of room ID to list of events.
|
||||||
|
func (d *Database) fetchStateEvents(
|
||||||
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
roomIDToEventIDSet map[string]map[string]bool,
|
||||||
|
eventIDToEvent map[string]types.StreamEvent,
|
||||||
|
) (map[string][]types.StreamEvent, error) {
|
||||||
|
stateBetween := make(map[string][]types.StreamEvent)
|
||||||
|
missingEvents := make(map[string][]string)
|
||||||
|
for roomID, ids := range roomIDToEventIDSet {
|
||||||
|
events := stateBetween[roomID]
|
||||||
|
for id, need := range ids {
|
||||||
|
if !need {
|
||||||
|
continue // deleted state
|
||||||
|
}
|
||||||
|
e, ok := eventIDToEvent[id]
|
||||||
|
if ok {
|
||||||
|
events = append(events, e)
|
||||||
|
} else {
|
||||||
|
m := missingEvents[roomID]
|
||||||
|
m = append(m, id)
|
||||||
|
missingEvents[roomID] = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stateBetween[roomID] = events
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(missingEvents) > 0 {
|
||||||
|
// This happens when add_state_ids has an event ID which is not in the provided range.
|
||||||
|
// We need to explicitly fetch them.
|
||||||
|
allMissingEventIDs := []string{}
|
||||||
|
for _, missingEvIDs := range missingEvents {
|
||||||
|
allMissingEventIDs = append(allMissingEventIDs, missingEvIDs...)
|
||||||
|
}
|
||||||
|
evs, err := d.fetchMissingStateEvents(ctx, txn, allMissingEventIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// we know we got them all otherwise an error would've been returned, so just loop the events
|
||||||
|
for _, ev := range evs {
|
||||||
|
roomID := ev.RoomID()
|
||||||
|
stateBetween[roomID] = append(stateBetween[roomID], ev)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return stateBetween, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) fetchMissingStateEvents(
|
||||||
|
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||||
|
) ([]types.StreamEvent, error) {
|
||||||
|
// Fetch from the events table first so we pick up the stream ID for the
|
||||||
|
// event.
|
||||||
|
events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, nil, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
have := map[string]bool{}
|
||||||
|
for _, event := range events {
|
||||||
|
have[event.EventID()] = true
|
||||||
|
}
|
||||||
|
var missing []string
|
||||||
|
for _, eventID := range eventIDs {
|
||||||
|
if !have[eventID] {
|
||||||
|
missing = append(missing, eventID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(missing) == 0 {
|
||||||
|
return events, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If they are missing from the events table then they should be state
|
||||||
|
// events that we received from outside the main event stream.
|
||||||
|
// These should be in the room state table.
|
||||||
|
stateEvents, err := d.CurrentRoomState.SelectEventsWithEventIDs(ctx, txn, missing)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(stateEvents) != len(missing) {
|
||||||
|
logrus.WithContext(ctx).Warnf("Failed to map all event IDs to events (got %d, wanted %d)", len(stateEvents), len(missing))
|
||||||
|
|
||||||
|
// TODO: Why is this happening? It's probably the roomserver. Uncomment
|
||||||
|
// this error again when we work out what it is and fix it, otherwise we
|
||||||
|
// just end up returning lots of 500s to the client and that breaks
|
||||||
|
// pretty much everything, rather than just sending what we have.
|
||||||
|
//return nil, fmt.Errorf("failed to map all event IDs to events: (got %d, wanted %d)", len(stateEvents), len(missing))
|
||||||
|
}
|
||||||
|
events = append(events, stateEvents...)
|
||||||
|
return events, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) StoreNewSendForDeviceMessage(
|
||||||
|
ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent,
|
||||||
|
) (newPos types.StreamPosition, err error) {
|
||||||
|
j, err := json.Marshal(event)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
// Delegate the database write task to the SendToDeviceWriter. It'll guarantee
|
||||||
|
// that we don't lock the table for writes in more than one place.
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
newPos, err = d.SendToDevice.InsertSendToDeviceMessage(
|
||||||
|
ctx, txn, userID, deviceID, string(j),
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return newPos, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) CleanSendToDeviceUpdates(
|
||||||
|
ctx context.Context,
|
||||||
|
userID, deviceID string, before types.StreamPosition,
|
||||||
|
) (err error) {
|
||||||
|
if err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, userID, deviceID, before)
|
||||||
|
}); err != nil {
|
||||||
|
logrus.WithError(err).Errorf("Failed to clean up old send-to-device messages for user %q device %q", userID, deviceID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getMembershipFromEvent returns the value of content.membership iff the event is a state event
|
||||||
|
// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned.
|
||||||
|
func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) (string, string) {
|
||||||
|
if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
membership, err := ev.Membership()
|
||||||
|
if err != nil {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
prevMembership := gjson.GetBytes(ev.Unsigned(), "prev_content.membership").Str
|
||||||
|
return membership, prevMembership
|
||||||
|
}
|
||||||
|
|
||||||
|
// StoreReceipt stores user receipts
|
||||||
|
func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) {
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
pos, err = d.Receipts.UpsertReceipt(ctx, txn, roomId, receiptType, userId, eventId, timestamp)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
pos, err = d.NotificationData.UpsertRoomUnreadCounts(ctx, txn, userID, roomID, notificationCount, highlightCount)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) {
|
||||||
|
return d.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||||
|
return d.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter)
|
||||||
|
}
|
||||||
|
func (d *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) {
|
||||||
|
return d.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) {
|
||||||
|
return d.Ignores.SelectIgnores(ctx, nil, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error {
|
||||||
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.Ignores.UpsertIgnores(ctx, txn, userID, ignores)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) {
|
||||||
|
var pos types.StreamPosition
|
||||||
|
var err error
|
||||||
|
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
pos, err = d.Presence.UpsertPresence(ctx, txn, userID, statusMsg, presence, lastActiveTS, fromSync)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return pos, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
|
||||||
|
return d.Presence.GetPresenceForUser(ctx, nil, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) {
|
||||||
|
return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Database) ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) {
|
||||||
|
return s.OutputEvents.ReIndex(ctx, nil, limit, afterID, []string{
|
||||||
|
gomatrixserverlib.MRoomName,
|
||||||
|
gomatrixserverlib.MRoomTopic,
|
||||||
|
"m.room.message",
|
||||||
|
})
|
||||||
|
}
|
574
syncapi/storage/shared/storage_sync.go
Normal file
574
syncapi/storage/shared/storage_sync.go
Normal file
|
@ -0,0 +1,574 @@
|
||||||
|
package shared
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DatabaseTransaction struct {
|
||||||
|
*Database
|
||||||
|
txn *sql.Tx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) Commit() error {
|
||||||
|
if d.txn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return d.txn.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) Rollback() error {
|
||||||
|
if d.txn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return d.txn.Rollback()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) {
|
||||||
|
id, err := d.OutputEvents.SelectMaxEventID(ctx, d.txn)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("d.OutputEvents.SelectMaxEventID: %w", err)
|
||||||
|
}
|
||||||
|
return types.StreamPosition(id), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) {
|
||||||
|
id, err := d.Receipts.SelectMaxReceiptID(ctx, d.txn)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("d.Receipts.SelectMaxReceiptID: %w", err)
|
||||||
|
}
|
||||||
|
return types.StreamPosition(id), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) {
|
||||||
|
id, err := d.Invites.SelectMaxInviteID(ctx, d.txn)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("d.Invites.SelectMaxInviteID: %w", err)
|
||||||
|
}
|
||||||
|
return types.StreamPosition(id), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) {
|
||||||
|
id, err := d.SendToDevice.SelectMaxSendToDeviceMessageID(ctx, d.txn)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("d.SendToDevice.SelectMaxSendToDeviceMessageID: %w", err)
|
||||||
|
}
|
||||||
|
return types.StreamPosition(id), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) {
|
||||||
|
id, err := d.AccountData.SelectMaxAccountDataID(ctx, d.txn)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("d.Invites.SelectMaxAccountDataID: %w", err)
|
||||||
|
}
|
||||||
|
return types.StreamPosition(id), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) {
|
||||||
|
id, err := d.NotificationData.SelectMaxID(ctx, d.txn)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("d.NotificationData.SelectMaxID: %w", err)
|
||||||
|
}
|
||||||
|
return types.StreamPosition(id), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||||
|
return d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilterPart, excludeEventIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) {
|
||||||
|
return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, d.txn, userID, membership)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) {
|
||||||
|
return d.Memberships.SelectMembershipCount(ctx, d.txn, roomID, membership, pos)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) {
|
||||||
|
return d.Memberships.SelectHeroes(ctx, d.txn, roomID, userID, memberships)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) {
|
||||||
|
return d.OutputEvents.SelectRecentEvents(ctx, d.txn, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) {
|
||||||
|
return d.Topology.SelectPositionInTopology(ctx, d.txn, eventID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) {
|
||||||
|
return d.Invites.SelectInviteEventsInRange(ctx, d.txn, targetUserID, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) {
|
||||||
|
return d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, deviceID, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) {
|
||||||
|
return d.Receipts.SelectRoomReceiptsAfter(ctx, d.txn, roomIDs, streamPos)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Events lookups a list of event by their event ID.
|
||||||
|
// Returns a list of events matching the requested IDs found in the database.
|
||||||
|
// If an event is not found in the database then it will be omitted from the list.
|
||||||
|
// Returns an error if there was a problem talking with the database.
|
||||||
|
// Does not include any transaction IDs in the returned events.
|
||||||
|
func (d *DatabaseTransaction) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||||
|
streamEvents, err := d.OutputEvents.SelectEvents(ctx, d.txn, eventIDs, nil, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// We don't include a device here as we only include transaction IDs in
|
||||||
|
// incremental syncs.
|
||||||
|
return d.StreamEventsToEvents(nil, streamEvents), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
|
||||||
|
return d.CurrentRoomState.SelectJoinedUsers(ctx, d.txn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) {
|
||||||
|
return d.CurrentRoomState.SelectJoinedUsersInRoom(ctx, d.txn, roomIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) {
|
||||||
|
return d.Peeks.SelectPeekingDevices(ctx, d.txn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) {
|
||||||
|
return d.CurrentRoomState.SelectSharedUsers(ctx, d.txn, userID, otherUserIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) GetStateEvent(
|
||||||
|
ctx context.Context, roomID, evType, stateKey string,
|
||||||
|
) (*gomatrixserverlib.HeaderedEvent, error) {
|
||||||
|
return d.CurrentRoomState.SelectStateEvent(ctx, d.txn, roomID, evType, stateKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) GetStateEventsForRoom(
|
||||||
|
ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter,
|
||||||
|
) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) {
|
||||||
|
stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilter, nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountDataInRange returns all account data for a given user inserted or
|
||||||
|
// updated between two given positions
|
||||||
|
// Returns a map following the format data[roomID] = []dataTypes
|
||||||
|
// If no data is retrieved, returns an empty map
|
||||||
|
// If there was an issue with the retrieval, returns an error
|
||||||
|
func (d *DatabaseTransaction) GetAccountDataInRange(
|
||||||
|
ctx context.Context, userID string, r types.Range,
|
||||||
|
accountDataFilterPart *gomatrixserverlib.EventFilter,
|
||||||
|
) (map[string][]string, types.StreamPosition, error) {
|
||||||
|
return d.AccountData.SelectAccountDataInRange(ctx, d.txn, userID, r, accountDataFilterPart)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) GetEventsInTopologicalRange(
|
||||||
|
ctx context.Context,
|
||||||
|
from, to *types.TopologyToken,
|
||||||
|
roomID string,
|
||||||
|
filter *gomatrixserverlib.RoomEventFilter,
|
||||||
|
backwardOrdering bool,
|
||||||
|
) (events []types.StreamEvent, err error) {
|
||||||
|
var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition
|
||||||
|
if backwardOrdering {
|
||||||
|
// Backward ordering means the 'from' token has a higher depth than the 'to' token
|
||||||
|
minDepth = to.Depth
|
||||||
|
maxDepth = from.Depth
|
||||||
|
// for cases where we have say 5 events with the same depth, the TopologyToken needs to
|
||||||
|
// know which of the 5 the client has seen. This is done by using the PDU position.
|
||||||
|
// Events with the same maxDepth but less than this PDU position will be returned.
|
||||||
|
maxStreamPosForMaxDepth = from.PDUPosition
|
||||||
|
} else {
|
||||||
|
// Forward ordering means the 'from' token has a lower depth than the 'to' token.
|
||||||
|
minDepth = from.Depth
|
||||||
|
maxDepth = to.Depth
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select the event IDs from the defined range.
|
||||||
|
var eIDs []string
|
||||||
|
eIDs, err = d.Topology.SelectEventIDsInRange(
|
||||||
|
ctx, d.txn, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve the events' contents using their IDs.
|
||||||
|
events, err = d.OutputEvents.SelectEvents(ctx, d.txn, eIDs, filter, true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) BackwardExtremitiesForRoom(
|
||||||
|
ctx context.Context, roomID string,
|
||||||
|
) (backwardExtremities map[string][]string, err error) {
|
||||||
|
return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, d.txn, roomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) MaxTopologicalPosition(
|
||||||
|
ctx context.Context, roomID string,
|
||||||
|
) (types.TopologyToken, error) {
|
||||||
|
depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return types.TopologyToken{}, err
|
||||||
|
}
|
||||||
|
return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) EventPositionInTopology(
|
||||||
|
ctx context.Context, eventID string,
|
||||||
|
) (types.TopologyToken, error) {
|
||||||
|
depth, stream, err := d.Topology.SelectPositionInTopology(ctx, d.txn, eventID)
|
||||||
|
if err != nil {
|
||||||
|
return types.TopologyToken{}, err
|
||||||
|
}
|
||||||
|
return types.TopologyToken{Depth: depth, PDUPosition: stream}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) StreamToTopologicalPosition(
|
||||||
|
ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool,
|
||||||
|
) (types.TopologyToken, error) {
|
||||||
|
topoPos, err := d.Topology.SelectStreamToTopologicalPosition(ctx, d.txn, roomID, streamPos, backwardOrdering)
|
||||||
|
switch {
|
||||||
|
case err == sql.ErrNoRows && backwardOrdering: // no events in range, going backward
|
||||||
|
return types.TopologyToken{PDUPosition: streamPos}, nil
|
||||||
|
case err == sql.ErrNoRows && !backwardOrdering: // no events in range, going forward
|
||||||
|
topoPos, streamPos, err = d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectMaxPositionInTopology: %w", err)
|
||||||
|
}
|
||||||
|
return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil
|
||||||
|
case err != nil: // some other error happened
|
||||||
|
return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectStreamToTopologicalPosition: %w", err)
|
||||||
|
default:
|
||||||
|
return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBackwardTopologyPos retrieves the backward topology position, i.e. the position of the
|
||||||
|
// oldest event in the room's topology.
|
||||||
|
func (d *DatabaseTransaction) GetBackwardTopologyPos(
|
||||||
|
ctx context.Context,
|
||||||
|
events []types.StreamEvent,
|
||||||
|
) (types.TopologyToken, error) {
|
||||||
|
zeroToken := types.TopologyToken{}
|
||||||
|
if len(events) == 0 {
|
||||||
|
return zeroToken, nil
|
||||||
|
}
|
||||||
|
pos, spos, err := d.Topology.SelectPositionInTopology(ctx, d.txn, events[0].EventID())
|
||||||
|
if err != nil {
|
||||||
|
return zeroToken, err
|
||||||
|
}
|
||||||
|
tok := types.TopologyToken{Depth: pos, PDUPosition: spos}
|
||||||
|
tok.Decrement()
|
||||||
|
return tok, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStateDeltas returns the state deltas between fromPos and toPos,
|
||||||
|
// exclusive of oldPos, inclusive of newPos, for the rooms in which
|
||||||
|
// the user has new membership events.
|
||||||
|
// A list of joined room IDs is also returned in case the caller needs it.
|
||||||
|
func (d *DatabaseTransaction) GetStateDeltas(
|
||||||
|
ctx context.Context, device *userapi.Device,
|
||||||
|
r types.Range, userID string,
|
||||||
|
stateFilter *gomatrixserverlib.StateFilter,
|
||||||
|
) (deltas []types.StateDelta, joinedRoomsIDs []string, err error) {
|
||||||
|
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
|
||||||
|
// - Get membership list changes for this user in this sync response
|
||||||
|
// - For each room which has membership list changes:
|
||||||
|
// * Check if the room is 'newly joined' (insufficient to just check for a join event because we allow dupe joins TODO).
|
||||||
|
// If it is, then we need to send the full room state down (and 'limited' is always true).
|
||||||
|
// * Check if user is still CURRENTLY invited to the room. If so, add room to 'invited' block.
|
||||||
|
// * Check if the user is CURRENTLY (TODO) left/banned. If so, add room to 'archived' block.
|
||||||
|
// - Get all CURRENTLY joined rooms, and add them to 'joined' block.
|
||||||
|
|
||||||
|
// Look up all memberships for the user. We only care about rooms that a
|
||||||
|
// user has ever interacted with — joined to, kicked/banned from, left.
|
||||||
|
memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, d.txn, userID)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
allRoomIDs := make([]string, 0, len(memberships))
|
||||||
|
joinedRoomIDs := make([]string, 0, len(memberships))
|
||||||
|
for roomID, membership := range memberships {
|
||||||
|
allRoomIDs = append(allRoomIDs, roomID)
|
||||||
|
if membership == gomatrixserverlib.Join {
|
||||||
|
joinedRoomIDs = append(joinedRoomIDs, roomID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// get all the state events ever (i.e. for all available rooms) between these two positions
|
||||||
|
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
state, err := d.fetchStateEvents(ctx, d.txn, stateNeeded, eventMap)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// find out which rooms this user is peeking, if any.
|
||||||
|
// We do this before joins so any peeks get overwritten
|
||||||
|
peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, device.ID, r)
|
||||||
|
if err != nil && err != sql.ErrNoRows {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// add peek blocks
|
||||||
|
for _, peek := range peeks {
|
||||||
|
if peek.New {
|
||||||
|
// send full room state down instead of a delta
|
||||||
|
var s []types.StreamEvent
|
||||||
|
s, err = d.currentStateStreamEventsForRoom(ctx, peek.RoomID, stateFilter)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
state[peek.RoomID] = s
|
||||||
|
}
|
||||||
|
if !peek.Deleted {
|
||||||
|
deltas = append(deltas, types.StateDelta{
|
||||||
|
Membership: gomatrixserverlib.Peek,
|
||||||
|
StateEvents: d.StreamEventsToEvents(device, state[peek.RoomID]),
|
||||||
|
RoomID: peek.RoomID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handle newly joined rooms and non-joined rooms
|
||||||
|
newlyJoinedRooms := make(map[string]bool, len(state))
|
||||||
|
for roomID, stateStreamEvents := range state {
|
||||||
|
for _, ev := range stateStreamEvents {
|
||||||
|
if membership, prevMembership := getMembershipFromEvent(ev.Event, userID); membership != "" {
|
||||||
|
if membership == gomatrixserverlib.Join && prevMembership != membership {
|
||||||
|
// send full room state down instead of a delta
|
||||||
|
var s []types.StreamEvent
|
||||||
|
s, err = d.currentStateStreamEventsForRoom(ctx, roomID, stateFilter)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
state[roomID] = s
|
||||||
|
newlyJoinedRooms[roomID] = true
|
||||||
|
continue // we'll add this room in when we do joined rooms
|
||||||
|
}
|
||||||
|
|
||||||
|
deltas = append(deltas, types.StateDelta{
|
||||||
|
Membership: membership,
|
||||||
|
MembershipPos: ev.StreamPosition,
|
||||||
|
StateEvents: d.StreamEventsToEvents(device, stateStreamEvents),
|
||||||
|
RoomID: roomID,
|
||||||
|
})
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add in currently joined rooms
|
||||||
|
for _, joinedRoomID := range joinedRoomIDs {
|
||||||
|
deltas = append(deltas, types.StateDelta{
|
||||||
|
Membership: gomatrixserverlib.Join,
|
||||||
|
StateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]),
|
||||||
|
RoomID: joinedRoomID,
|
||||||
|
NewlyJoined: newlyJoinedRooms[joinedRoomID],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return deltas, joinedRoomIDs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync
|
||||||
|
// requests with full_state=true.
|
||||||
|
// Fetches full state for all joined rooms and uses selectStateInRange to get
|
||||||
|
// updates for other rooms.
|
||||||
|
func (d *DatabaseTransaction) GetStateDeltasForFullStateSync(
|
||||||
|
ctx context.Context, device *userapi.Device,
|
||||||
|
r types.Range, userID string,
|
||||||
|
stateFilter *gomatrixserverlib.StateFilter,
|
||||||
|
) ([]types.StateDelta, []string, error) {
|
||||||
|
// Look up all memberships for the user. We only care about rooms that a
|
||||||
|
// user has ever interacted with — joined to, kicked/banned from, left.
|
||||||
|
memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, d.txn, userID)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
allRoomIDs := make([]string, 0, len(memberships))
|
||||||
|
joinedRoomIDs := make([]string, 0, len(memberships))
|
||||||
|
for roomID, membership := range memberships {
|
||||||
|
allRoomIDs = append(allRoomIDs, roomID)
|
||||||
|
if membership == gomatrixserverlib.Join {
|
||||||
|
joinedRoomIDs = append(joinedRoomIDs, roomID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use a reasonable initial capacity
|
||||||
|
deltas := make(map[string]types.StateDelta)
|
||||||
|
|
||||||
|
peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, device.ID, r)
|
||||||
|
if err != nil && err != sql.ErrNoRows {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add full states for all peeking rooms
|
||||||
|
for _, peek := range peeks {
|
||||||
|
if !peek.Deleted {
|
||||||
|
s, stateErr := d.currentStateStreamEventsForRoom(ctx, peek.RoomID, stateFilter)
|
||||||
|
if stateErr != nil {
|
||||||
|
if stateErr == sql.ErrNoRows {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil, nil, stateErr
|
||||||
|
}
|
||||||
|
deltas[peek.RoomID] = types.StateDelta{
|
||||||
|
Membership: gomatrixserverlib.Peek,
|
||||||
|
StateEvents: d.StreamEventsToEvents(device, s),
|
||||||
|
RoomID: peek.RoomID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get all the state events ever between these two positions
|
||||||
|
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
state, err := d.fetchStateEvents(ctx, d.txn, stateNeeded, eventMap)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for roomID, stateStreamEvents := range state {
|
||||||
|
for _, ev := range stateStreamEvents {
|
||||||
|
if membership, _ := getMembershipFromEvent(ev.Event, userID); membership != "" {
|
||||||
|
if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above.
|
||||||
|
deltas[roomID] = types.StateDelta{
|
||||||
|
Membership: membership,
|
||||||
|
MembershipPos: ev.StreamPosition,
|
||||||
|
StateEvents: d.StreamEventsToEvents(device, stateStreamEvents),
|
||||||
|
RoomID: roomID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add full states for all joined rooms
|
||||||
|
for _, joinedRoomID := range joinedRoomIDs {
|
||||||
|
s, stateErr := d.currentStateStreamEventsForRoom(ctx, joinedRoomID, stateFilter)
|
||||||
|
if stateErr != nil {
|
||||||
|
if stateErr == sql.ErrNoRows {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil, nil, stateErr
|
||||||
|
}
|
||||||
|
deltas[joinedRoomID] = types.StateDelta{
|
||||||
|
Membership: gomatrixserverlib.Join,
|
||||||
|
StateEvents: d.StreamEventsToEvents(device, s),
|
||||||
|
RoomID: joinedRoomID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a response array.
|
||||||
|
result := make([]types.StateDelta, len(deltas))
|
||||||
|
i := 0
|
||||||
|
for _, delta := range deltas {
|
||||||
|
result[i] = delta
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, joinedRoomIDs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) currentStateStreamEventsForRoom(
|
||||||
|
ctx context.Context, roomID string,
|
||||||
|
stateFilter *gomatrixserverlib.StateFilter,
|
||||||
|
) ([]types.StreamEvent, error) {
|
||||||
|
allState, err := d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilter, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s := make([]types.StreamEvent, len(allState))
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
s[i] = types.StreamEvent{HeaderedEvent: allState[i], StreamPosition: 0}
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) SendToDeviceUpdatesForSync(
|
||||||
|
ctx context.Context,
|
||||||
|
userID, deviceID string,
|
||||||
|
from, to types.StreamPosition,
|
||||||
|
) (types.StreamPosition, []types.SendToDeviceEvent, error) {
|
||||||
|
// First of all, get our send-to-device updates for this user.
|
||||||
|
lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, d.txn, userID, deviceID, from, to)
|
||||||
|
if err != nil {
|
||||||
|
return from, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
|
||||||
|
}
|
||||||
|
// If there's nothing to do then stop here.
|
||||||
|
if len(events) == 0 {
|
||||||
|
return to, nil, nil
|
||||||
|
}
|
||||||
|
return lastPos, events, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) {
|
||||||
|
_, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, d.txn, roomIDs, streamPos)
|
||||||
|
return receipts, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, rooms map[string]string) (map[string]*eventutil.NotificationData, error) {
|
||||||
|
roomIDs := make([]string, 0, len(rooms))
|
||||||
|
for roomID, membership := range rooms {
|
||||||
|
if membership != gomatrixserverlib.Join {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
roomIDs = append(roomIDs, roomID)
|
||||||
|
}
|
||||||
|
return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, d.txn, userID, roomIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
|
||||||
|
return d.Presence.GetPresenceForUser(ctx, d.txn, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {
|
||||||
|
return d.Presence.GetPresenceAfter(ctx, d.txn, after, filter)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) {
|
||||||
|
return d.Presence.GetMaxPresenceID(ctx, d.txn)
|
||||||
|
}
|
File diff suppressed because it is too large
Load diff
|
@ -367,7 +367,13 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
|
||||||
for start < len(eventIDs) {
|
for start < len(eventIDs) {
|
||||||
n := minOfInts(len(eventIDs)-start, 999)
|
n := minOfInts(len(eventIDs)-start, 999)
|
||||||
query := strings.Replace(selectEventsWithEventIDsSQL, "($1)", sqlutil.QueryVariadic(n), 1)
|
query := strings.Replace(selectEventsWithEventIDsSQL, "($1)", sqlutil.QueryVariadic(n), 1)
|
||||||
rows, err := txn.QueryContext(ctx, query, iEventIDs[start:start+n]...)
|
var rows *sql.Rows
|
||||||
|
var err error
|
||||||
|
if txn == nil {
|
||||||
|
rows, err = s.db.QueryContext(ctx, query, iEventIDs[start:start+n]...)
|
||||||
|
} else {
|
||||||
|
rows, err = txn.QueryContext(ctx, query, iEventIDs[start:start+n]...)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,7 +50,7 @@ const deleteInviteEventSQL = "" +
|
||||||
"UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2 AND deleted=false"
|
"UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2 AND deleted=false"
|
||||||
|
|
||||||
const selectInviteEventsInRangeSQL = "" +
|
const selectInviteEventsInRangeSQL = "" +
|
||||||
"SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" +
|
"SELECT id, room_id, headered_event_json, deleted FROM syncapi_invite_events" +
|
||||||
" WHERE target_user_id = $1 AND id > $2 AND id <= $3" +
|
" WHERE target_user_id = $1 AND id > $2 AND id <= $3" +
|
||||||
" ORDER BY id DESC"
|
" ORDER BY id DESC"
|
||||||
|
|
||||||
|
@ -132,23 +132,28 @@ func (s *inviteEventsStatements) DeleteInviteEvent(
|
||||||
// active invites for the target user ID in the supplied range.
|
// active invites for the target user ID in the supplied range.
|
||||||
func (s *inviteEventsStatements) SelectInviteEventsInRange(
|
func (s *inviteEventsStatements) SelectInviteEventsInRange(
|
||||||
ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range,
|
ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range,
|
||||||
) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) {
|
) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) {
|
||||||
|
var lastPos types.StreamPosition
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt)
|
stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt)
|
||||||
rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High())
|
rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, lastPos, err
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed")
|
||||||
result := map[string]*gomatrixserverlib.HeaderedEvent{}
|
result := map[string]*gomatrixserverlib.HeaderedEvent{}
|
||||||
retired := map[string]*gomatrixserverlib.HeaderedEvent{}
|
retired := map[string]*gomatrixserverlib.HeaderedEvent{}
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var (
|
var (
|
||||||
|
id types.StreamPosition
|
||||||
roomID string
|
roomID string
|
||||||
eventJSON []byte
|
eventJSON []byte
|
||||||
deleted bool
|
deleted bool
|
||||||
)
|
)
|
||||||
if err = rows.Scan(&roomID, &eventJSON, &deleted); err != nil {
|
if err = rows.Scan(&id, &roomID, &eventJSON, &deleted); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, lastPos, err
|
||||||
|
}
|
||||||
|
if id > lastPos {
|
||||||
|
lastPos = id
|
||||||
}
|
}
|
||||||
|
|
||||||
// if we have seen this room before, it has a higher stream position and hence takes priority
|
// if we have seen this room before, it has a higher stream position and hence takes priority
|
||||||
|
@ -161,15 +166,19 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange(
|
||||||
|
|
||||||
var event *gomatrixserverlib.HeaderedEvent
|
var event *gomatrixserverlib.HeaderedEvent
|
||||||
if err := json.Unmarshal(eventJSON, &event); err != nil {
|
if err := json.Unmarshal(eventJSON, &event); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, lastPos, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if deleted {
|
if deleted {
|
||||||
retired[roomID] = event
|
retired[roomID] = event
|
||||||
} else {
|
} else {
|
||||||
result[roomID] = event
|
result[roomID] = event
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result, retired, nil
|
if lastPos == 0 {
|
||||||
|
lastPos = r.To
|
||||||
|
}
|
||||||
|
return result, retired, lastPos, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *inviteEventsStatements) SelectMaxInviteID(
|
func (s *inviteEventsStatements) SelectMaxInviteID(
|
||||||
|
|
|
@ -49,6 +49,20 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
|
||||||
return &d, nil
|
return &d, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *SyncServerDatasource) NewDatabaseSnapshot(ctx context.Context) (*shared.DatabaseTransaction, error) {
|
||||||
|
return &shared.DatabaseTransaction{
|
||||||
|
Database: &d.Database,
|
||||||
|
// not setting a transaction because SQLite doesn't support it
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *SyncServerDatasource) NewDatabaseTransaction(ctx context.Context) (*shared.DatabaseTransaction, error) {
|
||||||
|
return &shared.DatabaseTransaction{
|
||||||
|
Database: &d.Database,
|
||||||
|
// not setting a transaction because SQLite doesn't support it
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (d *SyncServerDatasource) prepare(ctx context.Context) (err error) {
|
func (d *SyncServerDatasource) prepare(ctx context.Context) (err error) {
|
||||||
if err = d.streamID.Prepare(d.db); err != nil {
|
if err = d.streamID.Prepare(d.db); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -60,6 +60,17 @@ func TestWriteEvents(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithSnapshot(t *testing.T, db storage.Database, f func(snapshot storage.DatabaseTransaction)) {
|
||||||
|
snapshot, err := db.NewDatabaseSnapshot(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
f(snapshot)
|
||||||
|
if err := snapshot.Rollback(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// These tests assert basic functionality of RecentEvents for PDUs
|
// These tests assert basic functionality of RecentEvents for PDUs
|
||||||
func TestRecentEventsPDU(t *testing.T) {
|
func TestRecentEventsPDU(t *testing.T) {
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
@ -79,10 +90,13 @@ func TestRecentEventsPDU(t *testing.T) {
|
||||||
// dummy room to make sure SQL queries are filtering on room ID
|
// dummy room to make sure SQL queries are filtering on room ID
|
||||||
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
|
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
|
||||||
|
|
||||||
latest, err := db.MaxStreamPositionForPDUs(ctx)
|
var latest types.StreamPosition
|
||||||
if err != nil {
|
WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
|
||||||
t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err)
|
var err error
|
||||||
}
|
if latest, err = snapshot.MaxStreamPositionForPDUs(ctx); err != nil {
|
||||||
|
t.Fatal("failed to get MaxStreamPositionForPDUs: %w", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
Name string
|
Name string
|
||||||
|
@ -140,14 +154,19 @@ func TestRecentEventsPDU(t *testing.T) {
|
||||||
tc := testCases[i]
|
tc := testCases[i]
|
||||||
t.Run(tc.Name, func(st *testing.T) {
|
t.Run(tc.Name, func(st *testing.T) {
|
||||||
var filter gomatrixserverlib.RoomEventFilter
|
var filter gomatrixserverlib.RoomEventFilter
|
||||||
|
var gotEvents []types.StreamEvent
|
||||||
|
var limited bool
|
||||||
filter.Limit = tc.Limit
|
filter.Limit = tc.Limit
|
||||||
gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{
|
WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
|
||||||
From: tc.From,
|
var err error
|
||||||
To: tc.To,
|
gotEvents, limited, err = snapshot.RecentEvents(ctx, r.ID, types.Range{
|
||||||
}, &filter, !tc.ReverseOrder, true)
|
From: tc.From,
|
||||||
if err != nil {
|
To: tc.To,
|
||||||
st.Fatalf("failed to do sync: %s", err)
|
}, &filter, !tc.ReverseOrder, true)
|
||||||
}
|
if err != nil {
|
||||||
|
st.Fatalf("failed to do sync: %s", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
if limited != tc.WantLimited {
|
if limited != tc.WantLimited {
|
||||||
st.Errorf("got limited=%v want %v", limited, tc.WantLimited)
|
st.Errorf("got limited=%v want %v", limited, tc.WantLimited)
|
||||||
}
|
}
|
||||||
|
@ -178,22 +197,24 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
|
||||||
events := r.Events()
|
events := r.Events()
|
||||||
_ = MustWriteEvents(t, db, events)
|
_ = MustWriteEvents(t, db, events)
|
||||||
|
|
||||||
from, err := db.MaxTopologicalPosition(ctx, r.ID)
|
WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
|
||||||
if err != nil {
|
from, err := snapshot.MaxTopologicalPosition(ctx, r.ID)
|
||||||
t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
|
if err != nil {
|
||||||
}
|
t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
|
||||||
t.Logf("max topo pos = %+v", from)
|
}
|
||||||
// head towards the beginning of time
|
t.Logf("max topo pos = %+v", from)
|
||||||
to := types.TopologyToken{}
|
// head towards the beginning of time
|
||||||
|
to := types.TopologyToken{}
|
||||||
|
|
||||||
// backpaginate 5 messages starting at the latest position.
|
// backpaginate 5 messages starting at the latest position.
|
||||||
filter := &gomatrixserverlib.RoomEventFilter{Limit: 5}
|
filter := &gomatrixserverlib.RoomEventFilter{Limit: 5}
|
||||||
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
|
paginatedEvents, err := snapshot.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
|
t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
|
||||||
}
|
}
|
||||||
gots := db.StreamEventsToEvents(nil, paginatedEvents)
|
gots := snapshot.StreamEventsToEvents(nil, paginatedEvents)
|
||||||
test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
|
test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -414,13 +435,16 @@ func TestSendToDeviceBehaviour(t *testing.T) {
|
||||||
defer closeBase()
|
defer closeBase()
|
||||||
// At this point there should be no messages. We haven't sent anything
|
// At this point there should be no messages. We haven't sent anything
|
||||||
// yet.
|
// yet.
|
||||||
_, events, err := db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100)
|
|
||||||
if err != nil {
|
WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
|
||||||
t.Fatal(err)
|
_, events, err := snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100)
|
||||||
}
|
if err != nil {
|
||||||
if len(events) != 0 {
|
t.Fatal(err)
|
||||||
t.Fatal("first call should have no updates")
|
}
|
||||||
}
|
if len(events) != 0 {
|
||||||
|
t.Fatal("first call should have no updates")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// Try sending a message.
|
// Try sending a message.
|
||||||
streamPos, err := db.StoreNewSendForDeviceMessage(ctx, alice.ID, deviceID, gomatrixserverlib.SendToDeviceEvent{
|
streamPos, err := db.StoreNewSendForDeviceMessage(ctx, alice.ID, deviceID, gomatrixserverlib.SendToDeviceEvent{
|
||||||
|
@ -432,51 +456,58 @@ func TestSendToDeviceBehaviour(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// At this point we should get exactly one message. We're sending the sync position
|
WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
|
||||||
// that we were given from the update and the send-to-device update will be updated
|
// At this point we should get exactly one message. We're sending the sync position
|
||||||
// in the database to reflect that this was the sync position we sent the message at.
|
// that we were given from the update and the send-to-device update will be updated
|
||||||
streamPos, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
|
// in the database to reflect that this was the sync position we sent the message at.
|
||||||
if err != nil {
|
var events []types.SendToDeviceEvent
|
||||||
t.Fatal(err)
|
streamPos, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
|
||||||
}
|
if err != nil {
|
||||||
if count := len(events); count != 1 {
|
t.Fatal(err)
|
||||||
t.Fatalf("second call should have one update, got %d", count)
|
}
|
||||||
}
|
if count := len(events); count != 1 {
|
||||||
|
t.Fatalf("second call should have one update, got %d", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// At this point we should still have one message because we haven't progressed the
|
||||||
|
// sync position yet. This is equivalent to the client failing to /sync and retrying
|
||||||
|
// with the same position.
|
||||||
|
streamPos, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(events) != 1 {
|
||||||
|
t.Fatal("third call should have one update still")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// At this point we should still have one message because we haven't progressed the
|
|
||||||
// sync position yet. This is equivalent to the client failing to /sync and retrying
|
|
||||||
// with the same position.
|
|
||||||
streamPos, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if len(events) != 1 {
|
|
||||||
t.Fatal("third call should have one update still")
|
|
||||||
}
|
|
||||||
err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, streamPos)
|
err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, streamPos)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// At this point we should now have no updates, because we've progressed the sync
|
WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
|
||||||
// position. Therefore the update from before will not be sent again.
|
// At this point we should now have no updates, because we've progressed the sync
|
||||||
_, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10)
|
// position. Therefore the update from before will not be sent again.
|
||||||
if err != nil {
|
var events []types.SendToDeviceEvent
|
||||||
t.Fatal(err)
|
_, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10)
|
||||||
}
|
if err != nil {
|
||||||
if len(events) != 0 {
|
t.Fatal(err)
|
||||||
t.Fatal("fourth call should have no updates")
|
}
|
||||||
}
|
if len(events) != 0 {
|
||||||
|
t.Fatal("fourth call should have no updates")
|
||||||
|
}
|
||||||
|
|
||||||
// At this point we should still have no updates, because no new updates have been
|
// At this point we should still have no updates, because no new updates have been
|
||||||
// sent.
|
// sent.
|
||||||
_, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10)
|
_, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if len(events) != 0 {
|
if len(events) != 0 {
|
||||||
t.Fatal("fifth call should have no updates")
|
t.Fatal("fifth call should have no updates")
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// Send some more messages and verify the ordering is correct ("in order of arrival")
|
// Send some more messages and verify the ordering is correct ("in order of arrival")
|
||||||
var lastPos types.StreamPosition = 0
|
var lastPos types.StreamPosition = 0
|
||||||
|
@ -492,18 +523,20 @@ func TestSendToDeviceBehaviour(t *testing.T) {
|
||||||
lastPos = streamPos
|
lastPos = streamPos
|
||||||
}
|
}
|
||||||
|
|
||||||
_, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, lastPos)
|
WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
|
||||||
if err != nil {
|
_, events, err := snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, lastPos)
|
||||||
t.Fatalf("unable to get events: %v", err)
|
if err != nil {
|
||||||
}
|
t.Fatalf("unable to get events: %v", err)
|
||||||
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
want := json.RawMessage(fmt.Sprintf(`{"count":%d}`, i))
|
|
||||||
got := events[i].Content
|
|
||||||
if !bytes.Equal(got, want) {
|
|
||||||
t.Fatalf("messages are out of order\nwant: %s\ngot: %s", string(want), string(got))
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
want := json.RawMessage(fmt.Sprintf(`{"count":%d}`, i))
|
||||||
|
got := events[i].Content
|
||||||
|
if !bytes.Equal(got, want) {
|
||||||
|
t.Fatalf("messages are out of order\nwant: %s\ngot: %s", string(want), string(got))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -37,7 +37,7 @@ type Invites interface {
|
||||||
DeleteInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string) (types.StreamPosition, error)
|
DeleteInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string) (types.StreamPosition, error)
|
||||||
// SelectInviteEventsInRange returns a map of room ID to invite events. If multiple invite/retired invites exist in the given range, return the latest value
|
// SelectInviteEventsInRange returns a map of room ID to invite events. If multiple invite/retired invites exist in the given range, return the latest value
|
||||||
// for the room.
|
// for the room.
|
||||||
SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (invites map[string]*gomatrixserverlib.HeaderedEvent, retired map[string]*gomatrixserverlib.HeaderedEvent, err error)
|
SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (invites map[string]*gomatrixserverlib.HeaderedEvent, retired map[string]*gomatrixserverlib.HeaderedEvent, maxID types.StreamPosition, err error)
|
||||||
SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error)
|
SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,22 +5,25 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
type AccountDataStreamProvider struct {
|
type AccountDataStreamProvider struct {
|
||||||
StreamProvider
|
DefaultStreamProvider
|
||||||
userAPI userapi.SyncUserAPI
|
userAPI userapi.SyncUserAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *AccountDataStreamProvider) Setup() {
|
func (p *AccountDataStreamProvider) Setup(
|
||||||
p.StreamProvider.Setup()
|
ctx context.Context, snapshot storage.DatabaseTransaction,
|
||||||
|
) {
|
||||||
|
p.DefaultStreamProvider.Setup(ctx, snapshot)
|
||||||
|
|
||||||
p.latestMutex.Lock()
|
p.latestMutex.Lock()
|
||||||
defer p.latestMutex.Unlock()
|
defer p.latestMutex.Unlock()
|
||||||
|
|
||||||
id, err := p.DB.MaxStreamPositionForAccountData(context.Background())
|
id, err := snapshot.MaxStreamPositionForAccountData(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -29,13 +32,15 @@ func (p *AccountDataStreamProvider) Setup() {
|
||||||
|
|
||||||
func (p *AccountDataStreamProvider) CompleteSync(
|
func (p *AccountDataStreamProvider) CompleteSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
snapshot storage.DatabaseTransaction,
|
||||||
req *types.SyncRequest,
|
req *types.SyncRequest,
|
||||||
) types.StreamPosition {
|
) types.StreamPosition {
|
||||||
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
|
return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *AccountDataStreamProvider) IncrementalSync(
|
func (p *AccountDataStreamProvider) IncrementalSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
snapshot storage.DatabaseTransaction,
|
||||||
req *types.SyncRequest,
|
req *types.SyncRequest,
|
||||||
from, to types.StreamPosition,
|
from, to types.StreamPosition,
|
||||||
) types.StreamPosition {
|
) types.StreamPosition {
|
||||||
|
@ -44,7 +49,7 @@ func (p *AccountDataStreamProvider) IncrementalSync(
|
||||||
To: to,
|
To: to,
|
||||||
}
|
}
|
||||||
|
|
||||||
dataTypes, pos, err := p.DB.GetAccountDataInRange(
|
dataTypes, pos, err := snapshot.GetAccountDataInRange(
|
||||||
ctx, req.Device.UserID, r, &req.Filter.AccountData,
|
ctx, req.Device.UserID, r, &req.Filter.AccountData,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -6,17 +6,19 @@ import (
|
||||||
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/syncapi/internal"
|
"github.com/matrix-org/dendrite/syncapi/internal"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DeviceListStreamProvider struct {
|
type DeviceListStreamProvider struct {
|
||||||
StreamProvider
|
DefaultStreamProvider
|
||||||
rsAPI api.SyncRoomserverAPI
|
rsAPI api.SyncRoomserverAPI
|
||||||
keyAPI keyapi.SyncKeyAPI
|
keyAPI keyapi.SyncKeyAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *DeviceListStreamProvider) CompleteSync(
|
func (p *DeviceListStreamProvider) CompleteSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
snapshot storage.DatabaseTransaction,
|
||||||
req *types.SyncRequest,
|
req *types.SyncRequest,
|
||||||
) types.StreamPosition {
|
) types.StreamPosition {
|
||||||
return p.LatestPosition(ctx)
|
return p.LatestPosition(ctx)
|
||||||
|
@ -24,11 +26,12 @@ func (p *DeviceListStreamProvider) CompleteSync(
|
||||||
|
|
||||||
func (p *DeviceListStreamProvider) IncrementalSync(
|
func (p *DeviceListStreamProvider) IncrementalSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
snapshot storage.DatabaseTransaction,
|
||||||
req *types.SyncRequest,
|
req *types.SyncRequest,
|
||||||
from, to types.StreamPosition,
|
from, to types.StreamPosition,
|
||||||
) types.StreamPosition {
|
) types.StreamPosition {
|
||||||
var err error
|
var err error
|
||||||
to, _, err = internal.DeviceListCatchup(context.Background(), p.DB, p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to)
|
to, _, err = internal.DeviceListCatchup(context.Background(), snapshot, p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
req.Log.WithError(err).Error("internal.DeviceListCatchup failed")
|
req.Log.WithError(err).Error("internal.DeviceListCatchup failed")
|
||||||
return from
|
return from
|
||||||
|
|
|
@ -9,20 +9,23 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type InviteStreamProvider struct {
|
type InviteStreamProvider struct {
|
||||||
StreamProvider
|
DefaultStreamProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *InviteStreamProvider) Setup() {
|
func (p *InviteStreamProvider) Setup(
|
||||||
p.StreamProvider.Setup()
|
ctx context.Context, snapshot storage.DatabaseTransaction,
|
||||||
|
) {
|
||||||
|
p.DefaultStreamProvider.Setup(ctx, snapshot)
|
||||||
|
|
||||||
p.latestMutex.Lock()
|
p.latestMutex.Lock()
|
||||||
defer p.latestMutex.Unlock()
|
defer p.latestMutex.Unlock()
|
||||||
|
|
||||||
id, err := p.DB.MaxStreamPositionForInvites(context.Background())
|
id, err := snapshot.MaxStreamPositionForInvites(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -31,13 +34,15 @@ func (p *InviteStreamProvider) Setup() {
|
||||||
|
|
||||||
func (p *InviteStreamProvider) CompleteSync(
|
func (p *InviteStreamProvider) CompleteSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
snapshot storage.DatabaseTransaction,
|
||||||
req *types.SyncRequest,
|
req *types.SyncRequest,
|
||||||
) types.StreamPosition {
|
) types.StreamPosition {
|
||||||
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
|
return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *InviteStreamProvider) IncrementalSync(
|
func (p *InviteStreamProvider) IncrementalSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
snapshot storage.DatabaseTransaction,
|
||||||
req *types.SyncRequest,
|
req *types.SyncRequest,
|
||||||
from, to types.StreamPosition,
|
from, to types.StreamPosition,
|
||||||
) types.StreamPosition {
|
) types.StreamPosition {
|
||||||
|
@ -46,7 +51,7 @@ func (p *InviteStreamProvider) IncrementalSync(
|
||||||
To: to,
|
To: to,
|
||||||
}
|
}
|
||||||
|
|
||||||
invites, retiredInvites, err := p.DB.InviteEventsInRange(
|
invites, retiredInvites, maxID, err := snapshot.InviteEventsInRange(
|
||||||
ctx, req.Device.UserID, r,
|
ctx, req.Device.UserID, r,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -86,5 +91,5 @@ func (p *InviteStreamProvider) IncrementalSync(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return to
|
return maxID
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,17 +3,23 @@ package streams
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type NotificationDataStreamProvider struct {
|
type NotificationDataStreamProvider struct {
|
||||||
StreamProvider
|
DefaultStreamProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *NotificationDataStreamProvider) Setup() {
|
func (p *NotificationDataStreamProvider) Setup(
|
||||||
p.StreamProvider.Setup()
|
ctx context.Context, snapshot storage.DatabaseTransaction,
|
||||||
|
) {
|
||||||
|
p.DefaultStreamProvider.Setup(ctx, snapshot)
|
||||||
|
|
||||||
id, err := p.DB.MaxStreamPositionForNotificationData(context.Background())
|
p.latestMutex.Lock()
|
||||||
|
defer p.latestMutex.Unlock()
|
||||||
|
|
||||||
|
id, err := snapshot.MaxStreamPositionForNotificationData(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -22,20 +28,22 @@ func (p *NotificationDataStreamProvider) Setup() {
|
||||||
|
|
||||||
func (p *NotificationDataStreamProvider) CompleteSync(
|
func (p *NotificationDataStreamProvider) CompleteSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
snapshot storage.DatabaseTransaction,
|
||||||
req *types.SyncRequest,
|
req *types.SyncRequest,
|
||||||
) types.StreamPosition {
|
) types.StreamPosition {
|
||||||
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
|
return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *NotificationDataStreamProvider) IncrementalSync(
|
func (p *NotificationDataStreamProvider) IncrementalSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
snapshot storage.DatabaseTransaction,
|
||||||
req *types.SyncRequest,
|
req *types.SyncRequest,
|
||||||
from, _ types.StreamPosition,
|
from, _ types.StreamPosition,
|
||||||
) types.StreamPosition {
|
) types.StreamPosition {
|
||||||
// Get the unread notifications for rooms in our join response.
|
// Get the unread notifications for rooms in our join response.
|
||||||
// This is to ensure clients always have an unread notification section
|
// This is to ensure clients always have an unread notification section
|
||||||
// and can display the correct numbers.
|
// and can display the correct numbers.
|
||||||
countsByRoom, err := p.DB.GetUserUnreadNotificationCountsForRooms(ctx, req.Device.UserID, req.Rooms)
|
countsByRoom, err := snapshot.GetUserUnreadNotificationCountsForRooms(ctx, req.Device.UserID, req.Rooms)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
req.Log.WithError(err).Error("GetUserUnreadNotificationCountsForRooms failed")
|
req.Log.WithError(err).Error("GetUserUnreadNotificationCountsForRooms failed")
|
||||||
return from
|
return from
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
|
@ -18,7 +17,6 @@ import (
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"go.uber.org/atomic"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/syncapi/notifier"
|
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||||
)
|
)
|
||||||
|
@ -33,44 +31,23 @@ const PDU_STREAM_WORKERS = 256
|
||||||
const PDU_STREAM_QUEUESIZE = PDU_STREAM_WORKERS * 8
|
const PDU_STREAM_QUEUESIZE = PDU_STREAM_WORKERS * 8
|
||||||
|
|
||||||
type PDUStreamProvider struct {
|
type PDUStreamProvider struct {
|
||||||
StreamProvider
|
DefaultStreamProvider
|
||||||
|
|
||||||
tasks chan func()
|
|
||||||
workers atomic.Int32
|
|
||||||
// userID+deviceID -> lazy loading cache
|
// userID+deviceID -> lazy loading cache
|
||||||
lazyLoadCache caching.LazyLoadCache
|
lazyLoadCache caching.LazyLoadCache
|
||||||
rsAPI roomserverAPI.SyncRoomserverAPI
|
rsAPI roomserverAPI.SyncRoomserverAPI
|
||||||
notifier *notifier.Notifier
|
notifier *notifier.Notifier
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PDUStreamProvider) worker() {
|
func (p *PDUStreamProvider) Setup(
|
||||||
defer p.workers.Dec()
|
ctx context.Context, snapshot storage.DatabaseTransaction,
|
||||||
for {
|
) {
|
||||||
select {
|
p.DefaultStreamProvider.Setup(ctx, snapshot)
|
||||||
case f := <-p.tasks:
|
|
||||||
f()
|
|
||||||
case <-time.After(time.Second * 10):
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PDUStreamProvider) queue(f func()) {
|
|
||||||
if p.workers.Load() < PDU_STREAM_WORKERS {
|
|
||||||
p.workers.Inc()
|
|
||||||
go p.worker()
|
|
||||||
}
|
|
||||||
p.tasks <- f
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PDUStreamProvider) Setup() {
|
|
||||||
p.StreamProvider.Setup()
|
|
||||||
p.tasks = make(chan func(), PDU_STREAM_QUEUESIZE)
|
|
||||||
|
|
||||||
p.latestMutex.Lock()
|
p.latestMutex.Lock()
|
||||||
defer p.latestMutex.Unlock()
|
defer p.latestMutex.Unlock()
|
||||||
|
|
||||||
id, err := p.DB.MaxStreamPositionForPDUs(context.Background())
|
id, err := snapshot.MaxStreamPositionForPDUs(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -79,6 +56,7 @@ func (p *PDUStreamProvider) Setup() {
|
||||||
|
|
||||||
func (p *PDUStreamProvider) CompleteSync(
|
func (p *PDUStreamProvider) CompleteSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
snapshot storage.DatabaseTransaction,
|
||||||
req *types.SyncRequest,
|
req *types.SyncRequest,
|
||||||
) types.StreamPosition {
|
) types.StreamPosition {
|
||||||
from := types.StreamPosition(0)
|
from := types.StreamPosition(0)
|
||||||
|
@ -94,7 +72,7 @@ func (p *PDUStreamProvider) CompleteSync(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract room state and recent events for all rooms the user is joined to.
|
// Extract room state and recent events for all rooms the user is joined to.
|
||||||
joinedRoomIDs, err := p.DB.RoomIDsWithMembership(ctx, req.Device.UserID, gomatrixserverlib.Join)
|
joinedRoomIDs, err := snapshot.RoomIDsWithMembership(ctx, req.Device.UserID, gomatrixserverlib.Join)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
req.Log.WithError(err).Error("p.DB.RoomIDsWithMembership failed")
|
req.Log.WithError(err).Error("p.DB.RoomIDsWithMembership failed")
|
||||||
return from
|
return from
|
||||||
|
@ -103,7 +81,7 @@ func (p *PDUStreamProvider) CompleteSync(
|
||||||
stateFilter := req.Filter.Room.State
|
stateFilter := req.Filter.Room.State
|
||||||
eventFilter := req.Filter.Room.Timeline
|
eventFilter := req.Filter.Room.Timeline
|
||||||
|
|
||||||
if err = p.addIgnoredUsersToFilter(ctx, req, &eventFilter); err != nil {
|
if err = p.addIgnoredUsersToFilter(ctx, snapshot, req, &eventFilter); err != nil {
|
||||||
req.Log.WithError(err).Error("unable to update event filter with ignored users")
|
req.Log.WithError(err).Error("unable to update event filter with ignored users")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -117,33 +95,20 @@ func (p *PDUStreamProvider) CompleteSync(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build up a /sync response. Add joined rooms.
|
// Build up a /sync response. Add joined rooms.
|
||||||
var reqMutex sync.Mutex
|
for _, roomID := range joinedRoomIDs {
|
||||||
var reqWaitGroup sync.WaitGroup
|
jr, jerr := p.getJoinResponseForCompleteSync(
|
||||||
reqWaitGroup.Add(len(joinedRoomIDs))
|
ctx, snapshot, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, false,
|
||||||
for _, room := range joinedRoomIDs {
|
)
|
||||||
roomID := room
|
if jerr != nil {
|
||||||
p.queue(func() {
|
req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed")
|
||||||
defer reqWaitGroup.Done()
|
continue // return from
|
||||||
|
}
|
||||||
jr, jerr := p.getJoinResponseForCompleteSync(
|
req.Response.Rooms.Join[roomID] = *jr
|
||||||
ctx, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, false,
|
req.Rooms[roomID] = gomatrixserverlib.Join
|
||||||
)
|
|
||||||
if jerr != nil {
|
|
||||||
req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
reqMutex.Lock()
|
|
||||||
defer reqMutex.Unlock()
|
|
||||||
req.Response.Rooms.Join[roomID] = *jr
|
|
||||||
req.Rooms[roomID] = gomatrixserverlib.Join
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
reqWaitGroup.Wait()
|
|
||||||
|
|
||||||
// Add peeked rooms.
|
// Add peeked rooms.
|
||||||
peeks, err := p.DB.PeeksInRange(ctx, req.Device.UserID, req.Device.ID, r)
|
peeks, err := snapshot.PeeksInRange(ctx, req.Device.UserID, req.Device.ID, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
req.Log.WithError(err).Error("p.DB.PeeksInRange failed")
|
req.Log.WithError(err).Error("p.DB.PeeksInRange failed")
|
||||||
return from
|
return from
|
||||||
|
@ -152,11 +117,11 @@ func (p *PDUStreamProvider) CompleteSync(
|
||||||
if !peek.Deleted {
|
if !peek.Deleted {
|
||||||
var jr *types.JoinResponse
|
var jr *types.JoinResponse
|
||||||
jr, err = p.getJoinResponseForCompleteSync(
|
jr, err = p.getJoinResponseForCompleteSync(
|
||||||
ctx, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, true,
|
ctx, snapshot, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, true,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed")
|
req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed")
|
||||||
return from
|
continue // return from
|
||||||
}
|
}
|
||||||
req.Response.Rooms.Peek[peek.RoomID] = *jr
|
req.Response.Rooms.Peek[peek.RoomID] = *jr
|
||||||
}
|
}
|
||||||
|
@ -167,6 +132,7 @@ func (p *PDUStreamProvider) CompleteSync(
|
||||||
|
|
||||||
func (p *PDUStreamProvider) IncrementalSync(
|
func (p *PDUStreamProvider) IncrementalSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
snapshot storage.DatabaseTransaction,
|
||||||
req *types.SyncRequest,
|
req *types.SyncRequest,
|
||||||
from, to types.StreamPosition,
|
from, to types.StreamPosition,
|
||||||
) (newPos types.StreamPosition) {
|
) (newPos types.StreamPosition) {
|
||||||
|
@ -184,12 +150,12 @@ func (p *PDUStreamProvider) IncrementalSync(
|
||||||
eventFilter := req.Filter.Room.Timeline
|
eventFilter := req.Filter.Room.Timeline
|
||||||
|
|
||||||
if req.WantFullState {
|
if req.WantFullState {
|
||||||
if stateDeltas, syncJoinedRooms, err = p.DB.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
|
if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
|
||||||
req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed")
|
req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if stateDeltas, syncJoinedRooms, err = p.DB.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
|
if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
|
||||||
req.Log.WithError(err).Error("p.DB.GetStateDeltas failed")
|
req.Log.WithError(err).Error("p.DB.GetStateDeltas failed")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -203,7 +169,7 @@ func (p *PDUStreamProvider) IncrementalSync(
|
||||||
return to
|
return to
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = p.addIgnoredUsersToFilter(ctx, req, &eventFilter); err != nil {
|
if err = p.addIgnoredUsersToFilter(ctx, snapshot, req, &eventFilter); err != nil {
|
||||||
req.Log.WithError(err).Error("unable to update event filter with ignored users")
|
req.Log.WithError(err).Error("unable to update event filter with ignored users")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -222,7 +188,7 @@ func (p *PDUStreamProvider) IncrementalSync(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var pos types.StreamPosition
|
var pos types.StreamPosition
|
||||||
if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, newRange, delta, &eventFilter, &stateFilter, req.Response); err != nil {
|
if pos, err = p.addRoomDeltaToResponse(ctx, snapshot, req.Device, newRange, delta, &eventFilter, &stateFilter, req.Response); err != nil {
|
||||||
req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed")
|
req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed")
|
||||||
return to
|
return to
|
||||||
}
|
}
|
||||||
|
@ -244,6 +210,7 @@ func (p *PDUStreamProvider) IncrementalSync(
|
||||||
// nolint:gocyclo
|
// nolint:gocyclo
|
||||||
func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
snapshot storage.DatabaseTransaction,
|
||||||
device *userapi.Device,
|
device *userapi.Device,
|
||||||
r types.Range,
|
r types.Range,
|
||||||
delta types.StateDelta,
|
delta types.StateDelta,
|
||||||
|
@ -260,7 +227,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
||||||
// This is all "okay" assuming history_visibility == "shared" which it is by default.
|
// This is all "okay" assuming history_visibility == "shared" which it is by default.
|
||||||
r.To = delta.MembershipPos
|
r.To = delta.MembershipPos
|
||||||
}
|
}
|
||||||
recentStreamEvents, limited, err := p.DB.RecentEvents(
|
recentStreamEvents, limited, err := snapshot.RecentEvents(
|
||||||
ctx, delta.RoomID, r,
|
ctx, delta.RoomID, r,
|
||||||
eventFilter, true, true,
|
eventFilter, true, true,
|
||||||
)
|
)
|
||||||
|
@ -270,9 +237,9 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
||||||
}
|
}
|
||||||
return r.From, fmt.Errorf("p.DB.RecentEvents: %w", err)
|
return r.From, fmt.Errorf("p.DB.RecentEvents: %w", err)
|
||||||
}
|
}
|
||||||
recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents)
|
recentEvents := snapshot.StreamEventsToEvents(device, recentStreamEvents)
|
||||||
delta.StateEvents = removeDuplicates(delta.StateEvents, recentEvents) // roll back
|
delta.StateEvents = removeDuplicates(delta.StateEvents, recentEvents) // roll back
|
||||||
prevBatch, err := p.DB.GetBackwardTopologyPos(ctx, recentStreamEvents)
|
prevBatch, err := snapshot.GetBackwardTopologyPos(ctx, recentStreamEvents)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return r.From, fmt.Errorf("p.DB.GetBackwardTopologyPos: %w", err)
|
return r.From, fmt.Errorf("p.DB.GetBackwardTopologyPos: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -291,7 +258,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
||||||
latestPosition := r.To
|
latestPosition := r.To
|
||||||
updateLatestPosition := func(mostRecentEventID string) {
|
updateLatestPosition := func(mostRecentEventID string) {
|
||||||
var pos types.StreamPosition
|
var pos types.StreamPosition
|
||||||
if _, pos, err = p.DB.PositionInTopology(ctx, mostRecentEventID); err == nil {
|
if _, pos, err = snapshot.PositionInTopology(ctx, mostRecentEventID); err == nil {
|
||||||
switch {
|
switch {
|
||||||
case r.Backwards && pos < latestPosition:
|
case r.Backwards && pos < latestPosition:
|
||||||
fallthrough
|
fallthrough
|
||||||
|
@ -303,7 +270,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
||||||
|
|
||||||
if stateFilter.LazyLoadMembers {
|
if stateFilter.LazyLoadMembers {
|
||||||
delta.StateEvents, err = p.lazyLoadMembers(
|
delta.StateEvents, err = p.lazyLoadMembers(
|
||||||
ctx, delta.RoomID, true, limited, stateFilter,
|
ctx, snapshot, delta.RoomID, true, limited, stateFilter,
|
||||||
device, recentEvents, delta.StateEvents,
|
device, recentEvents, delta.StateEvents,
|
||||||
)
|
)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
|
@ -320,7 +287,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Applies the history visibility rules
|
// Applies the history visibility rules
|
||||||
events, err := applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, delta.RoomID, device.UserID, eventFilter.Limit, recentEvents)
|
events, err := applyHistoryVisibilityFilter(ctx, snapshot, p.rsAPI, delta.RoomID, device.UserID, eventFilter.Limit, recentEvents)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Error("unable to apply history visibility filter")
|
logrus.WithError(err).Error("unable to apply history visibility filter")
|
||||||
}
|
}
|
||||||
|
@ -336,7 +303,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
||||||
case gomatrixserverlib.Join:
|
case gomatrixserverlib.Join:
|
||||||
jr := types.NewJoinResponse()
|
jr := types.NewJoinResponse()
|
||||||
if hasMembershipChange {
|
if hasMembershipChange {
|
||||||
p.addRoomSummary(ctx, jr, delta.RoomID, device.UserID, latestPosition)
|
p.addRoomSummary(ctx, snapshot, jr, delta.RoomID, device.UserID, latestPosition)
|
||||||
}
|
}
|
||||||
jr.Timeline.PrevBatch = &prevBatch
|
jr.Timeline.PrevBatch = &prevBatch
|
||||||
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
|
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
|
||||||
|
@ -376,7 +343,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
||||||
// sure we always return the required events in the timeline.
|
// sure we always return the required events in the timeline.
|
||||||
func applyHistoryVisibilityFilter(
|
func applyHistoryVisibilityFilter(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
db storage.Database,
|
snapshot storage.DatabaseTransaction,
|
||||||
rsAPI roomserverAPI.SyncRoomserverAPI,
|
rsAPI roomserverAPI.SyncRoomserverAPI,
|
||||||
roomID, userID string,
|
roomID, userID string,
|
||||||
limit int,
|
limit int,
|
||||||
|
@ -384,7 +351,7 @@ func applyHistoryVisibilityFilter(
|
||||||
) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||||
// We need to make sure we always include the latest states events, if they are in the timeline.
|
// We need to make sure we always include the latest states events, if they are in the timeline.
|
||||||
// We grep at least limit * 2 events, to ensure we really get the needed events.
|
// We grep at least limit * 2 events, to ensure we really get the needed events.
|
||||||
stateEvents, err := db.CurrentState(ctx, roomID, &gomatrixserverlib.StateFilter{Limit: limit * 2}, nil)
|
stateEvents, err := snapshot.CurrentState(ctx, roomID, &gomatrixserverlib.StateFilter{Limit: limit * 2}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Not a fatal error, we can continue without the stateEvents,
|
// Not a fatal error, we can continue without the stateEvents,
|
||||||
// they are only needed if there are state events in the timeline.
|
// they are only needed if there are state events in the timeline.
|
||||||
|
@ -395,7 +362,7 @@ func applyHistoryVisibilityFilter(
|
||||||
alwaysIncludeIDs[ev.EventID()] = struct{}{}
|
alwaysIncludeIDs[ev.EventID()] = struct{}{}
|
||||||
}
|
}
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
events, err := internal.ApplyHistoryVisibilityFilter(ctx, db, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync")
|
events, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -408,10 +375,10 @@ func applyHistoryVisibilityFilter(
|
||||||
return events, nil
|
return events, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) {
|
func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, snapshot storage.DatabaseTransaction, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) {
|
||||||
// Work out how many members are in the room.
|
// Work out how many members are in the room.
|
||||||
joinedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition)
|
joinedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition)
|
||||||
invitedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, latestPosition)
|
invitedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, latestPosition)
|
||||||
|
|
||||||
jr.Summary.JoinedMemberCount = &joinedCount
|
jr.Summary.JoinedMemberCount = &joinedCount
|
||||||
jr.Summary.InvitedMemberCount = &invitedCount
|
jr.Summary.InvitedMemberCount = &invitedCount
|
||||||
|
@ -439,7 +406,7 @@ func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinRe
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
heroes, err := p.DB.GetRoomHeroes(ctx, roomID, userID, []string{"join", "invite"})
|
heroes, err := snapshot.GetRoomHeroes(ctx, roomID, userID, []string{"join", "invite"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -449,6 +416,7 @@ func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinRe
|
||||||
|
|
||||||
func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
snapshot storage.DatabaseTransaction,
|
||||||
roomID string,
|
roomID string,
|
||||||
r types.Range,
|
r types.Range,
|
||||||
stateFilter *gomatrixserverlib.StateFilter,
|
stateFilter *gomatrixserverlib.StateFilter,
|
||||||
|
@ -460,7 +428,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
||||||
jr = types.NewJoinResponse()
|
jr = types.NewJoinResponse()
|
||||||
// TODO: When filters are added, we may need to call this multiple times to get enough events.
|
// TODO: When filters are added, we may need to call this multiple times to get enough events.
|
||||||
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
|
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
|
||||||
recentStreamEvents, limited, err := p.DB.RecentEvents(
|
recentStreamEvents, limited, err := snapshot.RecentEvents(
|
||||||
ctx, roomID, r, eventFilter, true, true,
|
ctx, roomID, r, eventFilter, true, true,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -484,7 +452,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
stateEvents, err := p.DB.CurrentState(ctx, roomID, stateFilter, excludingEventIDs)
|
stateEvents, err := snapshot.CurrentState(ctx, roomID, stateFilter, excludingEventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -494,7 +462,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
||||||
var prevBatch *types.TopologyToken
|
var prevBatch *types.TopologyToken
|
||||||
if len(recentStreamEvents) > 0 {
|
if len(recentStreamEvents) > 0 {
|
||||||
var backwardTopologyPos, backwardStreamPos types.StreamPosition
|
var backwardTopologyPos, backwardStreamPos types.StreamPosition
|
||||||
backwardTopologyPos, backwardStreamPos, err = p.DB.PositionInTopology(ctx, recentStreamEvents[0].EventID())
|
backwardTopologyPos, backwardStreamPos, err = snapshot.PositionInTopology(ctx, recentStreamEvents[0].EventID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -505,18 +473,18 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
||||||
prevBatch.Decrement()
|
prevBatch.Decrement()
|
||||||
}
|
}
|
||||||
|
|
||||||
p.addRoomSummary(ctx, jr, roomID, device.UserID, r.From)
|
p.addRoomSummary(ctx, snapshot, jr, roomID, device.UserID, r.From)
|
||||||
|
|
||||||
// We don't include a device here as we don't need to send down
|
// We don't include a device here as we don't need to send down
|
||||||
// transaction IDs for complete syncs, but we do it anyway because Sytest demands it for:
|
// transaction IDs for complete syncs, but we do it anyway because Sytest demands it for:
|
||||||
// "Can sync a room with a message with a transaction id" - which does a complete sync to check.
|
// "Can sync a room with a message with a transaction id" - which does a complete sync to check.
|
||||||
recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents)
|
recentEvents := snapshot.StreamEventsToEvents(device, recentStreamEvents)
|
||||||
stateEvents = removeDuplicates(stateEvents, recentEvents)
|
stateEvents = removeDuplicates(stateEvents, recentEvents)
|
||||||
|
|
||||||
events := recentEvents
|
events := recentEvents
|
||||||
// Only apply history visibility checks if the response is for joined rooms
|
// Only apply history visibility checks if the response is for joined rooms
|
||||||
if !isPeek {
|
if !isPeek {
|
||||||
events, err = applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, roomID, device.UserID, eventFilter.Limit, recentEvents)
|
events, err = applyHistoryVisibilityFilter(ctx, snapshot, p.rsAPI, roomID, device.UserID, eventFilter.Limit, recentEvents)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Error("unable to apply history visibility filter")
|
logrus.WithError(err).Error("unable to apply history visibility filter")
|
||||||
}
|
}
|
||||||
|
@ -530,7 +498,8 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
stateEvents, err = p.lazyLoadMembers(ctx, roomID,
|
stateEvents, err = p.lazyLoadMembers(
|
||||||
|
ctx, snapshot, roomID,
|
||||||
false, limited, stateFilter,
|
false, limited, stateFilter,
|
||||||
device, recentEvents, stateEvents,
|
device, recentEvents, stateEvents,
|
||||||
)
|
)
|
||||||
|
@ -549,7 +518,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PDUStreamProvider) lazyLoadMembers(
|
func (p *PDUStreamProvider) lazyLoadMembers(
|
||||||
ctx context.Context, roomID string,
|
ctx context.Context, snapshot storage.DatabaseTransaction, roomID string,
|
||||||
incremental, limited bool, stateFilter *gomatrixserverlib.StateFilter,
|
incremental, limited bool, stateFilter *gomatrixserverlib.StateFilter,
|
||||||
device *userapi.Device,
|
device *userapi.Device,
|
||||||
timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent,
|
timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent,
|
||||||
|
@ -598,7 +567,7 @@ func (p *PDUStreamProvider) lazyLoadMembers(
|
||||||
filter.Limit = stateFilter.Limit
|
filter.Limit = stateFilter.Limit
|
||||||
filter.Senders = &wantUsers
|
filter.Senders = &wantUsers
|
||||||
filter.Types = &[]string{gomatrixserverlib.MRoomMember}
|
filter.Types = &[]string{gomatrixserverlib.MRoomMember}
|
||||||
memberships, err := p.DB.GetStateEventsForRoom(ctx, roomID, &filter)
|
memberships, err := snapshot.GetStateEventsForRoom(ctx, roomID, &filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return stateEvents, err
|
return stateEvents, err
|
||||||
}
|
}
|
||||||
|
@ -612,8 +581,8 @@ func (p *PDUStreamProvider) lazyLoadMembers(
|
||||||
|
|
||||||
// addIgnoredUsersToFilter adds ignored users to the eventfilter and
|
// addIgnoredUsersToFilter adds ignored users to the eventfilter and
|
||||||
// the syncreq itself for further use in streams.
|
// the syncreq itself for further use in streams.
|
||||||
func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error {
|
func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, snapshot storage.DatabaseTransaction, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error {
|
||||||
ignores, err := p.DB.IgnoresForUser(ctx, req.Device.UserID)
|
ignores, err := snapshot.IgnoresForUser(ctx, req.Device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -23,20 +23,26 @@ import (
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/syncapi/notifier"
|
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type PresenceStreamProvider struct {
|
type PresenceStreamProvider struct {
|
||||||
StreamProvider
|
DefaultStreamProvider
|
||||||
// cache contains previously sent presence updates to avoid unneeded updates
|
// cache contains previously sent presence updates to avoid unneeded updates
|
||||||
cache sync.Map
|
cache sync.Map
|
||||||
notifier *notifier.Notifier
|
notifier *notifier.Notifier
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PresenceStreamProvider) Setup() {
|
func (p *PresenceStreamProvider) Setup(
|
||||||
p.StreamProvider.Setup()
|
ctx context.Context, snapshot storage.DatabaseTransaction,
|
||||||
|
) {
|
||||||
|
p.DefaultStreamProvider.Setup(ctx, snapshot)
|
||||||
|
|
||||||
id, err := p.DB.MaxStreamPositionForPresence(context.Background())
|
p.latestMutex.Lock()
|
||||||
|
defer p.latestMutex.Unlock()
|
||||||
|
|
||||||
|
id, err := snapshot.MaxStreamPositionForPresence(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -45,18 +51,20 @@ func (p *PresenceStreamProvider) Setup() {
|
||||||
|
|
||||||
func (p *PresenceStreamProvider) CompleteSync(
|
func (p *PresenceStreamProvider) CompleteSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
snapshot storage.DatabaseTransaction,
|
||||||
req *types.SyncRequest,
|
req *types.SyncRequest,
|
||||||
) types.StreamPosition {
|
) types.StreamPosition {
|
||||||
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
|
return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PresenceStreamProvider) IncrementalSync(
|
func (p *PresenceStreamProvider) IncrementalSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
snapshot storage.DatabaseTransaction,
|
||||||
req *types.SyncRequest,
|
req *types.SyncRequest,
|
||||||
from, to types.StreamPosition,
|
from, to types.StreamPosition,
|
||||||
) types.StreamPosition {
|
) types.StreamPosition {
|
||||||
// We pull out a larger number than the filter asks for, since we're filtering out events later
|
// We pull out a larger number than the filter asks for, since we're filtering out events later
|
||||||
presences, err := p.DB.PresenceAfter(ctx, from, gomatrixserverlib.EventFilter{Limit: 1000})
|
presences, err := snapshot.PresenceAfter(ctx, from, gomatrixserverlib.EventFilter{Limit: 1000})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
req.Log.WithError(err).Error("p.DB.PresenceAfter failed")
|
req.Log.WithError(err).Error("p.DB.PresenceAfter failed")
|
||||||
return from
|
return from
|
||||||
|
@ -84,7 +92,7 @@ func (p *PresenceStreamProvider) IncrementalSync(
|
||||||
}
|
}
|
||||||
// Bear in mind that this might return nil, but at least populating
|
// Bear in mind that this might return nil, but at least populating
|
||||||
// a nil means that there's a map entry so we won't repeat this call.
|
// a nil means that there's a map entry so we won't repeat this call.
|
||||||
presences[roomUsers[i]], err = p.DB.GetPresence(ctx, roomUsers[i])
|
presences[roomUsers[i]], err = snapshot.GetPresence(ctx, roomUsers[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
req.Log.WithError(err).Error("unable to query presence for user")
|
req.Log.WithError(err).Error("unable to query presence for user")
|
||||||
return from
|
return from
|
||||||
|
|
|
@ -4,18 +4,24 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ReceiptStreamProvider struct {
|
type ReceiptStreamProvider struct {
|
||||||
StreamProvider
|
DefaultStreamProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ReceiptStreamProvider) Setup() {
|
func (p *ReceiptStreamProvider) Setup(
|
||||||
p.StreamProvider.Setup()
|
ctx context.Context, snapshot storage.DatabaseTransaction,
|
||||||
|
) {
|
||||||
|
p.DefaultStreamProvider.Setup(ctx, snapshot)
|
||||||
|
|
||||||
id, err := p.DB.MaxStreamPositionForReceipts(context.Background())
|
p.latestMutex.Lock()
|
||||||
|
defer p.latestMutex.Unlock()
|
||||||
|
|
||||||
|
id, err := snapshot.MaxStreamPositionForReceipts(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -24,13 +30,15 @@ func (p *ReceiptStreamProvider) Setup() {
|
||||||
|
|
||||||
func (p *ReceiptStreamProvider) CompleteSync(
|
func (p *ReceiptStreamProvider) CompleteSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
snapshot storage.DatabaseTransaction,
|
||||||
req *types.SyncRequest,
|
req *types.SyncRequest,
|
||||||
) types.StreamPosition {
|
) types.StreamPosition {
|
||||||
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
|
return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ReceiptStreamProvider) IncrementalSync(
|
func (p *ReceiptStreamProvider) IncrementalSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
snapshot storage.DatabaseTransaction,
|
||||||
req *types.SyncRequest,
|
req *types.SyncRequest,
|
||||||
from, to types.StreamPosition,
|
from, to types.StreamPosition,
|
||||||
) types.StreamPosition {
|
) types.StreamPosition {
|
||||||
|
@ -41,7 +49,7 @@ func (p *ReceiptStreamProvider) IncrementalSync(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
lastPos, receipts, err := p.DB.RoomReceiptsAfter(ctx, joinedRooms, from)
|
lastPos, receipts, err := snapshot.RoomReceiptsAfter(ctx, joinedRooms, from)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
req.Log.WithError(err).Error("p.DB.RoomReceiptsAfter failed")
|
req.Log.WithError(err).Error("p.DB.RoomReceiptsAfter failed")
|
||||||
return from
|
return from
|
||||||
|
|
|
@ -3,17 +3,23 @@ package streams
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SendToDeviceStreamProvider struct {
|
type SendToDeviceStreamProvider struct {
|
||||||
StreamProvider
|
DefaultStreamProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *SendToDeviceStreamProvider) Setup() {
|
func (p *SendToDeviceStreamProvider) Setup(
|
||||||
p.StreamProvider.Setup()
|
ctx context.Context, snapshot storage.DatabaseTransaction,
|
||||||
|
) {
|
||||||
|
p.DefaultStreamProvider.Setup(ctx, snapshot)
|
||||||
|
|
||||||
id, err := p.DB.MaxStreamPositionForSendToDeviceMessages(context.Background())
|
p.latestMutex.Lock()
|
||||||
|
defer p.latestMutex.Unlock()
|
||||||
|
|
||||||
|
id, err := snapshot.MaxStreamPositionForSendToDeviceMessages(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -22,18 +28,20 @@ func (p *SendToDeviceStreamProvider) Setup() {
|
||||||
|
|
||||||
func (p *SendToDeviceStreamProvider) CompleteSync(
|
func (p *SendToDeviceStreamProvider) CompleteSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
snapshot storage.DatabaseTransaction,
|
||||||
req *types.SyncRequest,
|
req *types.SyncRequest,
|
||||||
) types.StreamPosition {
|
) types.StreamPosition {
|
||||||
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
|
return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *SendToDeviceStreamProvider) IncrementalSync(
|
func (p *SendToDeviceStreamProvider) IncrementalSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
snapshot storage.DatabaseTransaction,
|
||||||
req *types.SyncRequest,
|
req *types.SyncRequest,
|
||||||
from, to types.StreamPosition,
|
from, to types.StreamPosition,
|
||||||
) types.StreamPosition {
|
) types.StreamPosition {
|
||||||
// See if we have any new tasks to do for the send-to-device messaging.
|
// See if we have any new tasks to do for the send-to-device messaging.
|
||||||
lastPos, events, err := p.DB.SendToDeviceUpdatesForSync(req.Context, req.Device.UserID, req.Device.ID, from, to)
|
lastPos, events, err := snapshot.SendToDeviceUpdatesForSync(req.Context, req.Device.UserID, req.Device.ID, from, to)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
req.Log.WithError(err).Error("p.DB.SendToDeviceUpdatesForSync failed")
|
req.Log.WithError(err).Error("p.DB.SendToDeviceUpdatesForSync failed")
|
||||||
return from
|
return from
|
||||||
|
|
|
@ -5,24 +5,27 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TypingStreamProvider struct {
|
type TypingStreamProvider struct {
|
||||||
StreamProvider
|
DefaultStreamProvider
|
||||||
EDUCache *caching.EDUCache
|
EDUCache *caching.EDUCache
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *TypingStreamProvider) CompleteSync(
|
func (p *TypingStreamProvider) CompleteSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
snapshot storage.DatabaseTransaction,
|
||||||
req *types.SyncRequest,
|
req *types.SyncRequest,
|
||||||
) types.StreamPosition {
|
) types.StreamPosition {
|
||||||
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
|
return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *TypingStreamProvider) IncrementalSync(
|
func (p *TypingStreamProvider) IncrementalSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
snapshot storage.DatabaseTransaction,
|
||||||
req *types.SyncRequest,
|
req *types.SyncRequest,
|
||||||
from, to types.StreamPosition,
|
from, to types.StreamPosition,
|
||||||
) types.StreamPosition {
|
) types.StreamPosition {
|
||||||
|
|
28
syncapi/streams/streamprovider.go
Normal file
28
syncapi/streams/streamprovider.go
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
package streams
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type StreamProvider interface {
|
||||||
|
Setup(ctx context.Context, snapshot storage.DatabaseTransaction)
|
||||||
|
|
||||||
|
// Advance will update the latest position of the stream based on
|
||||||
|
// an update and will wake callers waiting on StreamNotifyAfter.
|
||||||
|
Advance(latest types.StreamPosition)
|
||||||
|
|
||||||
|
// CompleteSync will update the response to include all updates as needed
|
||||||
|
// for a complete sync. It will always return immediately.
|
||||||
|
CompleteSync(ctx context.Context, snapshot storage.DatabaseTransaction, req *types.SyncRequest) types.StreamPosition
|
||||||
|
|
||||||
|
// IncrementalSync will update the response to include all updates between
|
||||||
|
// the from and to sync positions. It will always return immediately,
|
||||||
|
// making no changes if the range contains no updates.
|
||||||
|
IncrementalSync(ctx context.Context, snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, to types.StreamPosition) types.StreamPosition
|
||||||
|
|
||||||
|
// LatestPosition returns the latest stream position for this stream.
|
||||||
|
LatestPosition(ctx context.Context) types.StreamPosition
|
||||||
|
}
|
|
@ -13,15 +13,15 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Streams struct {
|
type Streams struct {
|
||||||
PDUStreamProvider types.StreamProvider
|
PDUStreamProvider StreamProvider
|
||||||
TypingStreamProvider types.StreamProvider
|
TypingStreamProvider StreamProvider
|
||||||
ReceiptStreamProvider types.StreamProvider
|
ReceiptStreamProvider StreamProvider
|
||||||
InviteStreamProvider types.StreamProvider
|
InviteStreamProvider StreamProvider
|
||||||
SendToDeviceStreamProvider types.StreamProvider
|
SendToDeviceStreamProvider StreamProvider
|
||||||
AccountDataStreamProvider types.StreamProvider
|
AccountDataStreamProvider StreamProvider
|
||||||
DeviceListStreamProvider types.StreamProvider
|
DeviceListStreamProvider StreamProvider
|
||||||
NotificationDataStreamProvider types.StreamProvider
|
NotificationDataStreamProvider StreamProvider
|
||||||
PresenceStreamProvider types.StreamProvider
|
PresenceStreamProvider StreamProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSyncStreamProviders(
|
func NewSyncStreamProviders(
|
||||||
|
@ -31,51 +31,58 @@ func NewSyncStreamProviders(
|
||||||
) *Streams {
|
) *Streams {
|
||||||
streams := &Streams{
|
streams := &Streams{
|
||||||
PDUStreamProvider: &PDUStreamProvider{
|
PDUStreamProvider: &PDUStreamProvider{
|
||||||
StreamProvider: StreamProvider{DB: d},
|
DefaultStreamProvider: DefaultStreamProvider{DB: d},
|
||||||
lazyLoadCache: lazyLoadCache,
|
lazyLoadCache: lazyLoadCache,
|
||||||
rsAPI: rsAPI,
|
rsAPI: rsAPI,
|
||||||
notifier: notifier,
|
notifier: notifier,
|
||||||
},
|
},
|
||||||
TypingStreamProvider: &TypingStreamProvider{
|
TypingStreamProvider: &TypingStreamProvider{
|
||||||
StreamProvider: StreamProvider{DB: d},
|
DefaultStreamProvider: DefaultStreamProvider{DB: d},
|
||||||
EDUCache: eduCache,
|
EDUCache: eduCache,
|
||||||
},
|
},
|
||||||
ReceiptStreamProvider: &ReceiptStreamProvider{
|
ReceiptStreamProvider: &ReceiptStreamProvider{
|
||||||
StreamProvider: StreamProvider{DB: d},
|
DefaultStreamProvider: DefaultStreamProvider{DB: d},
|
||||||
},
|
},
|
||||||
InviteStreamProvider: &InviteStreamProvider{
|
InviteStreamProvider: &InviteStreamProvider{
|
||||||
StreamProvider: StreamProvider{DB: d},
|
DefaultStreamProvider: DefaultStreamProvider{DB: d},
|
||||||
},
|
},
|
||||||
SendToDeviceStreamProvider: &SendToDeviceStreamProvider{
|
SendToDeviceStreamProvider: &SendToDeviceStreamProvider{
|
||||||
StreamProvider: StreamProvider{DB: d},
|
DefaultStreamProvider: DefaultStreamProvider{DB: d},
|
||||||
},
|
},
|
||||||
AccountDataStreamProvider: &AccountDataStreamProvider{
|
AccountDataStreamProvider: &AccountDataStreamProvider{
|
||||||
StreamProvider: StreamProvider{DB: d},
|
DefaultStreamProvider: DefaultStreamProvider{DB: d},
|
||||||
userAPI: userAPI,
|
userAPI: userAPI,
|
||||||
},
|
},
|
||||||
NotificationDataStreamProvider: &NotificationDataStreamProvider{
|
NotificationDataStreamProvider: &NotificationDataStreamProvider{
|
||||||
StreamProvider: StreamProvider{DB: d},
|
DefaultStreamProvider: DefaultStreamProvider{DB: d},
|
||||||
},
|
},
|
||||||
DeviceListStreamProvider: &DeviceListStreamProvider{
|
DeviceListStreamProvider: &DeviceListStreamProvider{
|
||||||
StreamProvider: StreamProvider{DB: d},
|
DefaultStreamProvider: DefaultStreamProvider{DB: d},
|
||||||
rsAPI: rsAPI,
|
rsAPI: rsAPI,
|
||||||
keyAPI: keyAPI,
|
keyAPI: keyAPI,
|
||||||
},
|
},
|
||||||
PresenceStreamProvider: &PresenceStreamProvider{
|
PresenceStreamProvider: &PresenceStreamProvider{
|
||||||
StreamProvider: StreamProvider{DB: d},
|
DefaultStreamProvider: DefaultStreamProvider{DB: d},
|
||||||
notifier: notifier,
|
notifier: notifier,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
streams.PDUStreamProvider.Setup()
|
ctx := context.TODO()
|
||||||
streams.TypingStreamProvider.Setup()
|
snapshot, err := d.NewDatabaseSnapshot(ctx)
|
||||||
streams.ReceiptStreamProvider.Setup()
|
if err != nil {
|
||||||
streams.InviteStreamProvider.Setup()
|
panic(err)
|
||||||
streams.SendToDeviceStreamProvider.Setup()
|
}
|
||||||
streams.AccountDataStreamProvider.Setup()
|
defer snapshot.Rollback() // nolint:errcheck
|
||||||
streams.NotificationDataStreamProvider.Setup()
|
|
||||||
streams.DeviceListStreamProvider.Setup()
|
streams.PDUStreamProvider.Setup(ctx, snapshot)
|
||||||
streams.PresenceStreamProvider.Setup()
|
streams.TypingStreamProvider.Setup(ctx, snapshot)
|
||||||
|
streams.ReceiptStreamProvider.Setup(ctx, snapshot)
|
||||||
|
streams.InviteStreamProvider.Setup(ctx, snapshot)
|
||||||
|
streams.SendToDeviceStreamProvider.Setup(ctx, snapshot)
|
||||||
|
streams.AccountDataStreamProvider.Setup(ctx, snapshot)
|
||||||
|
streams.NotificationDataStreamProvider.Setup(ctx, snapshot)
|
||||||
|
streams.DeviceListStreamProvider.Setup(ctx, snapshot)
|
||||||
|
streams.PresenceStreamProvider.Setup(ctx, snapshot)
|
||||||
|
|
||||||
return streams
|
return streams
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,16 +8,18 @@ import (
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type StreamProvider struct {
|
type DefaultStreamProvider struct {
|
||||||
DB storage.Database
|
DB storage.Database
|
||||||
latest types.StreamPosition
|
latest types.StreamPosition
|
||||||
latestMutex sync.RWMutex
|
latestMutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *StreamProvider) Setup() {
|
func (p *DefaultStreamProvider) Setup(
|
||||||
|
ctx context.Context, snapshot storage.DatabaseTransaction,
|
||||||
|
) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *StreamProvider) Advance(
|
func (p *DefaultStreamProvider) Advance(
|
||||||
latest types.StreamPosition,
|
latest types.StreamPosition,
|
||||||
) {
|
) {
|
||||||
p.latestMutex.Lock()
|
p.latestMutex.Lock()
|
||||||
|
@ -28,7 +30,7 @@ func (p *StreamProvider) Advance(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *StreamProvider) LatestPosition(
|
func (p *DefaultStreamProvider) LatestPosition(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
) types.StreamPosition {
|
) types.StreamPosition {
|
||||||
p.latestMutex.RLock()
|
p.latestMutex.RLock()
|
||||||
|
|
|
@ -305,6 +305,13 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
|
||||||
syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync immediately")
|
syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync immediately")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
snapshot, err := rp.db.NewDatabaseSnapshot(req.Context())
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to acquire database snapshot for sync request")
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
|
defer snapshot.Rollback() // nolint:errcheck
|
||||||
|
|
||||||
if syncReq.Since.IsEmpty() {
|
if syncReq.Since.IsEmpty() {
|
||||||
// Complete sync
|
// Complete sync
|
||||||
syncReq.Response.NextBatch = types.StreamingToken{
|
syncReq.Response.NextBatch = types.StreamingToken{
|
||||||
|
@ -312,70 +319,70 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
|
||||||
// might advance while processing other streams, resulting in flakey
|
// might advance while processing other streams, resulting in flakey
|
||||||
// tests.
|
// tests.
|
||||||
DeviceListPosition: rp.streams.DeviceListStreamProvider.CompleteSync(
|
DeviceListPosition: rp.streams.DeviceListStreamProvider.CompleteSync(
|
||||||
syncReq.Context, syncReq,
|
syncReq.Context, snapshot, syncReq,
|
||||||
),
|
),
|
||||||
PDUPosition: rp.streams.PDUStreamProvider.CompleteSync(
|
PDUPosition: rp.streams.PDUStreamProvider.CompleteSync(
|
||||||
syncReq.Context, syncReq,
|
syncReq.Context, snapshot, syncReq,
|
||||||
),
|
),
|
||||||
TypingPosition: rp.streams.TypingStreamProvider.CompleteSync(
|
TypingPosition: rp.streams.TypingStreamProvider.CompleteSync(
|
||||||
syncReq.Context, syncReq,
|
syncReq.Context, snapshot, syncReq,
|
||||||
),
|
),
|
||||||
ReceiptPosition: rp.streams.ReceiptStreamProvider.CompleteSync(
|
ReceiptPosition: rp.streams.ReceiptStreamProvider.CompleteSync(
|
||||||
syncReq.Context, syncReq,
|
syncReq.Context, snapshot, syncReq,
|
||||||
),
|
),
|
||||||
InvitePosition: rp.streams.InviteStreamProvider.CompleteSync(
|
InvitePosition: rp.streams.InviteStreamProvider.CompleteSync(
|
||||||
syncReq.Context, syncReq,
|
syncReq.Context, snapshot, syncReq,
|
||||||
),
|
),
|
||||||
SendToDevicePosition: rp.streams.SendToDeviceStreamProvider.CompleteSync(
|
SendToDevicePosition: rp.streams.SendToDeviceStreamProvider.CompleteSync(
|
||||||
syncReq.Context, syncReq,
|
syncReq.Context, snapshot, syncReq,
|
||||||
),
|
),
|
||||||
AccountDataPosition: rp.streams.AccountDataStreamProvider.CompleteSync(
|
AccountDataPosition: rp.streams.AccountDataStreamProvider.CompleteSync(
|
||||||
syncReq.Context, syncReq,
|
syncReq.Context, snapshot, syncReq,
|
||||||
),
|
),
|
||||||
NotificationDataPosition: rp.streams.NotificationDataStreamProvider.CompleteSync(
|
NotificationDataPosition: rp.streams.NotificationDataStreamProvider.CompleteSync(
|
||||||
syncReq.Context, syncReq,
|
syncReq.Context, snapshot, syncReq,
|
||||||
),
|
),
|
||||||
PresencePosition: rp.streams.PresenceStreamProvider.CompleteSync(
|
PresencePosition: rp.streams.PresenceStreamProvider.CompleteSync(
|
||||||
syncReq.Context, syncReq,
|
syncReq.Context, snapshot, syncReq,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Incremental sync
|
// Incremental sync
|
||||||
syncReq.Response.NextBatch = types.StreamingToken{
|
syncReq.Response.NextBatch = types.StreamingToken{
|
||||||
PDUPosition: rp.streams.PDUStreamProvider.IncrementalSync(
|
PDUPosition: rp.streams.PDUStreamProvider.IncrementalSync(
|
||||||
syncReq.Context, syncReq,
|
syncReq.Context, snapshot, syncReq,
|
||||||
syncReq.Since.PDUPosition, currentPos.PDUPosition,
|
syncReq.Since.PDUPosition, currentPos.PDUPosition,
|
||||||
),
|
),
|
||||||
TypingPosition: rp.streams.TypingStreamProvider.IncrementalSync(
|
TypingPosition: rp.streams.TypingStreamProvider.IncrementalSync(
|
||||||
syncReq.Context, syncReq,
|
syncReq.Context, snapshot, syncReq,
|
||||||
syncReq.Since.TypingPosition, currentPos.TypingPosition,
|
syncReq.Since.TypingPosition, currentPos.TypingPosition,
|
||||||
),
|
),
|
||||||
ReceiptPosition: rp.streams.ReceiptStreamProvider.IncrementalSync(
|
ReceiptPosition: rp.streams.ReceiptStreamProvider.IncrementalSync(
|
||||||
syncReq.Context, syncReq,
|
syncReq.Context, snapshot, syncReq,
|
||||||
syncReq.Since.ReceiptPosition, currentPos.ReceiptPosition,
|
syncReq.Since.ReceiptPosition, currentPos.ReceiptPosition,
|
||||||
),
|
),
|
||||||
InvitePosition: rp.streams.InviteStreamProvider.IncrementalSync(
|
InvitePosition: rp.streams.InviteStreamProvider.IncrementalSync(
|
||||||
syncReq.Context, syncReq,
|
syncReq.Context, snapshot, syncReq,
|
||||||
syncReq.Since.InvitePosition, currentPos.InvitePosition,
|
syncReq.Since.InvitePosition, currentPos.InvitePosition,
|
||||||
),
|
),
|
||||||
SendToDevicePosition: rp.streams.SendToDeviceStreamProvider.IncrementalSync(
|
SendToDevicePosition: rp.streams.SendToDeviceStreamProvider.IncrementalSync(
|
||||||
syncReq.Context, syncReq,
|
syncReq.Context, snapshot, syncReq,
|
||||||
syncReq.Since.SendToDevicePosition, currentPos.SendToDevicePosition,
|
syncReq.Since.SendToDevicePosition, currentPos.SendToDevicePosition,
|
||||||
),
|
),
|
||||||
AccountDataPosition: rp.streams.AccountDataStreamProvider.IncrementalSync(
|
AccountDataPosition: rp.streams.AccountDataStreamProvider.IncrementalSync(
|
||||||
syncReq.Context, syncReq,
|
syncReq.Context, snapshot, syncReq,
|
||||||
syncReq.Since.AccountDataPosition, currentPos.AccountDataPosition,
|
syncReq.Since.AccountDataPosition, currentPos.AccountDataPosition,
|
||||||
),
|
),
|
||||||
NotificationDataPosition: rp.streams.NotificationDataStreamProvider.IncrementalSync(
|
NotificationDataPosition: rp.streams.NotificationDataStreamProvider.IncrementalSync(
|
||||||
syncReq.Context, syncReq,
|
syncReq.Context, snapshot, syncReq,
|
||||||
syncReq.Since.NotificationDataPosition, currentPos.NotificationDataPosition,
|
syncReq.Since.NotificationDataPosition, currentPos.NotificationDataPosition,
|
||||||
),
|
),
|
||||||
DeviceListPosition: rp.streams.DeviceListStreamProvider.IncrementalSync(
|
DeviceListPosition: rp.streams.DeviceListStreamProvider.IncrementalSync(
|
||||||
syncReq.Context, syncReq,
|
syncReq.Context, snapshot, syncReq,
|
||||||
syncReq.Since.DeviceListPosition, currentPos.DeviceListPosition,
|
syncReq.Since.DeviceListPosition, currentPos.DeviceListPosition,
|
||||||
),
|
),
|
||||||
PresencePosition: rp.streams.PresenceStreamProvider.IncrementalSync(
|
PresencePosition: rp.streams.PresenceStreamProvider.IncrementalSync(
|
||||||
syncReq.Context, syncReq,
|
syncReq.Context, snapshot, syncReq,
|
||||||
syncReq.Since.PresencePosition, currentPos.PresencePosition,
|
syncReq.Since.PresencePosition, currentPos.PresencePosition,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
@ -437,9 +444,15 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("newSyncRequest failed")
|
util.GetLogger(req.Context()).WithError(err).Error("newSyncRequest failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), syncReq, fromToken.PDUPosition, toToken.PDUPosition)
|
snapshot, err := rp.db.NewDatabaseSnapshot(req.Context())
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to acquire database snapshot for key change")
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
|
defer snapshot.Rollback() // nolint:errcheck
|
||||||
|
rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), snapshot, syncReq, fromToken.PDUPosition, toToken.PDUPosition)
|
||||||
_, _, err = internal.DeviceListCatchup(
|
_, _, err = internal.DeviceListCatchup(
|
||||||
req.Context(), rp.db, rp.keyAPI, rp.rsAPI, syncReq.Device.UserID,
|
req.Context(), snapshot, rp.keyAPI, rp.rsAPI, syncReq.Device.UserID,
|
||||||
syncReq.Response, fromToken.DeviceListPosition, toToken.DeviceListPosition,
|
syncReq.Response, fromToken.DeviceListPosition, toToken.DeviceListPosition,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -41,23 +41,3 @@ func (r *SyncRequest) IsRoomPresent(roomID string) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type StreamProvider interface {
|
|
||||||
Setup()
|
|
||||||
|
|
||||||
// Advance will update the latest position of the stream based on
|
|
||||||
// an update and will wake callers waiting on StreamNotifyAfter.
|
|
||||||
Advance(latest StreamPosition)
|
|
||||||
|
|
||||||
// CompleteSync will update the response to include all updates as needed
|
|
||||||
// for a complete sync. It will always return immediately.
|
|
||||||
CompleteSync(ctx context.Context, req *SyncRequest) StreamPosition
|
|
||||||
|
|
||||||
// IncrementalSync will update the response to include all updates between
|
|
||||||
// the from and to sync positions. It will always return immediately,
|
|
||||||
// making no changes if the range contains no updates.
|
|
||||||
IncrementalSync(ctx context.Context, req *SyncRequest, from, to StreamPosition) StreamPosition
|
|
||||||
|
|
||||||
// LatestPosition returns the latest stream position for this stream.
|
|
||||||
LatestPosition(ctx context.Context) StreamPosition
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in a new issue