Rename to transaction instead of snapshot

This commit is contained in:
Neil Alexander 2022-09-30 10:27:05 +01:00
parent ceb530466f
commit 395b3e67c0
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
20 changed files with 121 additions and 118 deletions

View file

@ -100,7 +100,7 @@ func (ev eventVisibility) allowed() (allowed bool) {
// Returns the filtered events and an error, if any. // Returns the filtered events and an error, if any.
func ApplyHistoryVisibilityFilter( func ApplyHistoryVisibilityFilter(
ctx context.Context, ctx context.Context,
syncDB storage.DatabaseSnapshot, syncDB storage.DatabaseTransaction,
rsAPI api.SyncRoomserverAPI, rsAPI api.SyncRoomserverAPI,
events []*gomatrixserverlib.HeaderedEvent, events []*gomatrixserverlib.HeaderedEvent,
alwaysIncludeEventIDs map[string]struct{}, alwaysIncludeEventIDs map[string]struct{},

View file

@ -194,7 +194,7 @@ func Context(
// by combining the events before and after the context event. Returns the filtered events, // by combining the events before and after the context event. Returns the filtered events,
// and an error, if any. // and an error, if any.
func applyHistoryVisibilityOnContextEvents( func applyHistoryVisibilityOnContextEvents(
ctx context.Context, snapshot storage.DatabaseSnapshot, rsAPI roomserver.SyncRoomserverAPI, ctx context.Context, snapshot storage.DatabaseTransaction, rsAPI roomserver.SyncRoomserverAPI,
eventsBefore, eventsAfter []*gomatrixserverlib.HeaderedEvent, eventsBefore, eventsAfter []*gomatrixserverlib.HeaderedEvent,
userID string, userID string,
) (filteredBefore, filteredAfter []*gomatrixserverlib.HeaderedEvent, err error) { ) (filteredBefore, filteredAfter []*gomatrixserverlib.HeaderedEvent, err error) {
@ -228,7 +228,7 @@ func applyHistoryVisibilityOnContextEvents(
return filteredBefore, filteredAfter, nil return filteredBefore, filteredAfter, nil
} }
func getStartEnd(ctx context.Context, snapshot storage.DatabaseSnapshot, startEvents, endEvents []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) { func getStartEnd(ctx context.Context, snapshot storage.DatabaseTransaction, startEvents, endEvents []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) {
if len(startEvents) > 0 { if len(startEvents) > 0 {
start, err = snapshot.EventPositionInTopology(ctx, startEvents[0].EventID()) start, err = snapshot.EventPositionInTopology(ctx, startEvents[0].EventID())
if err != nil { if err != nil {

View file

@ -39,7 +39,7 @@ import (
type messagesReq struct { type messagesReq struct {
ctx context.Context ctx context.Context
db storage.Database db storage.Database
snapshot storage.DatabaseSnapshot snapshot storage.DatabaseTransaction
rsAPI api.SyncRoomserverAPI rsAPI api.SyncRoomserverAPI
cfg *config.SyncAPI cfg *config.SyncAPI
roomID string roomID string
@ -71,7 +71,7 @@ func OnIncomingMessagesRequest(
) util.JSONResponse { ) util.JSONResponse {
var err error var err error
snapshot, err := db.NewDatabaseWritable(req.Context()) snapshot, err := db.NewDatabaseTransaction(req.Context())
if err != nil { if err != nil {
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -247,7 +247,7 @@ func OnIncomingMessagesRequest(
// LazyLoadMembers enabled. // LazyLoadMembers enabled.
func (m *messagesResp) applyLazyLoadMembers( func (m *messagesResp) applyLazyLoadMembers(
ctx context.Context, ctx context.Context,
db storage.DatabaseSnapshot, db storage.DatabaseTransaction,
roomID string, roomID string,
device *userapi.Device, device *userapi.Device,
lazyLoad bool, lazyLoad bool,
@ -561,7 +561,7 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]
// Returns an error if there was an issue with retrieving the latest position // Returns an error if there was an issue with retrieving the latest position
// from the database // from the database
func setToDefault( func setToDefault(
ctx context.Context, snapshot storage.DatabaseSnapshot, backwardOrdering bool, ctx context.Context, snapshot storage.DatabaseTransaction, backwardOrdering bool,
roomID string, roomID string,
) (to types.TopologyToken, err error) { ) (to types.TopologyToken, err error) {
if backwardOrdering { if backwardOrdering {

View file

@ -258,7 +258,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
// contextEvents returns the events around a given eventID // contextEvents returns the events around a given eventID
func contextEvents( func contextEvents(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot, snapshot storage.DatabaseTransaction,
event *gomatrixserverlib.HeaderedEvent, event *gomatrixserverlib.HeaderedEvent,
roomFilter *gomatrixserverlib.RoomEventFilter, roomFilter *gomatrixserverlib.RoomEventFilter,
searchReq SearchRequest, searchReq SearchRequest,

View file

@ -26,7 +26,7 @@ import (
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
) )
type DatabaseSnapshot interface { type DatabaseTransaction interface {
SharedUsers SharedUsers
MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error)
@ -111,8 +111,8 @@ type Database interface {
Presence Presence
Notifications Notifications
NewDatabaseSnapshot(ctx context.Context) (*shared.DatabaseSnapshot, error) NewDatabaseSnapshot(ctx context.Context) (*shared.DatabaseTransaction, error)
NewDatabaseWritable(ctx context.Context) (*shared.DatabaseSnapshot, error) NewDatabaseTransaction(ctx context.Context) (*shared.DatabaseTransaction, error)
// Events lookups a list of event by their event ID. // Events lookups a list of event by their event ID.
// Returns a list of events matching the requested IDs found in the database. // Returns a list of events matching the requested IDs found in the database.

View file

@ -55,31 +55,34 @@ type Database struct {
Presence tables.Presence Presence tables.Presence
} }
func (d *Database) NewDatabaseSnapshot(ctx context.Context) (*DatabaseSnapshot, error) { func (d *Database) NewDatabaseSnapshot(ctx context.Context) (*DatabaseTransaction, error) {
txn, err := d.DB.BeginTx(ctx, &sql.TxOptions{ return d.NewDatabaseTransaction(ctx) // TODO: revert
// Set the isolation level so that we see a snapshot of the database. /*
// In PostgreSQL repeatable read transactions will see a snapshot taken txn, err := d.DB.BeginTx(ctx, &sql.TxOptions{
// at the first query, and since the transaction is read-only it can't // Set the isolation level so that we see a snapshot of the database.
// run into any serialisation errors. // In PostgreSQL repeatable read transactions will see a snapshot taken
// https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ // at the first query, and since the transaction is read-only it can't
Isolation: sql.LevelRepeatableRead, // run into any serialisation errors.
ReadOnly: true, // https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ
}) Isolation: sql.LevelRepeatableRead,
if err != nil { ReadOnly: true,
return nil, err })
} if err != nil {
return &DatabaseSnapshot{ return nil, err
Database: d, }
txn: txn, return &DatabaseTransaction{
}, nil Database: d,
txn: txn,
}, nil
*/
} }
func (d *Database) NewDatabaseWritable(ctx context.Context) (*DatabaseSnapshot, error) { func (d *Database) NewDatabaseTransaction(ctx context.Context) (*DatabaseTransaction, error) {
txn, err := d.DB.BeginTx(ctx, nil) txn, err := d.DB.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &DatabaseSnapshot{ return &DatabaseTransaction{
Database: d, Database: d,
txn: txn, txn: txn,
}, nil }, nil

View file

@ -11,26 +11,26 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
type DatabaseSnapshot struct { type DatabaseTransaction struct {
*Database *Database
txn *sql.Tx txn *sql.Tx
} }
func (d *DatabaseSnapshot) Commit() error { func (d *DatabaseTransaction) Commit() error {
if d.txn == nil { if d.txn == nil {
return nil return nil
} }
return d.txn.Commit() return d.txn.Commit()
} }
func (d *DatabaseSnapshot) Rollback() error { func (d *DatabaseTransaction) Rollback() error {
if d.txn == nil { if d.txn == nil {
return nil return nil
} }
return d.txn.Rollback() return d.txn.Rollback()
} }
func (d *DatabaseSnapshot) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) { func (d *DatabaseTransaction) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) {
id, err := d.OutputEvents.SelectMaxEventID(ctx, d.txn) id, err := d.OutputEvents.SelectMaxEventID(ctx, d.txn)
if err != nil { if err != nil {
return 0, fmt.Errorf("d.OutputEvents.SelectMaxEventID: %w", err) return 0, fmt.Errorf("d.OutputEvents.SelectMaxEventID: %w", err)
@ -38,7 +38,7 @@ func (d *DatabaseSnapshot) MaxStreamPositionForPDUs(ctx context.Context) (types.
return types.StreamPosition(id), nil return types.StreamPosition(id), nil
} }
func (d *DatabaseSnapshot) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) { func (d *DatabaseTransaction) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) {
id, err := d.Receipts.SelectMaxReceiptID(ctx, d.txn) id, err := d.Receipts.SelectMaxReceiptID(ctx, d.txn)
if err != nil { if err != nil {
return 0, fmt.Errorf("d.Receipts.SelectMaxReceiptID: %w", err) return 0, fmt.Errorf("d.Receipts.SelectMaxReceiptID: %w", err)
@ -46,7 +46,7 @@ func (d *DatabaseSnapshot) MaxStreamPositionForReceipts(ctx context.Context) (ty
return types.StreamPosition(id), nil return types.StreamPosition(id), nil
} }
func (d *DatabaseSnapshot) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) { func (d *DatabaseTransaction) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) {
id, err := d.Invites.SelectMaxInviteID(ctx, d.txn) id, err := d.Invites.SelectMaxInviteID(ctx, d.txn)
if err != nil { if err != nil {
return 0, fmt.Errorf("d.Invites.SelectMaxInviteID: %w", err) return 0, fmt.Errorf("d.Invites.SelectMaxInviteID: %w", err)
@ -54,7 +54,7 @@ func (d *DatabaseSnapshot) MaxStreamPositionForInvites(ctx context.Context) (typ
return types.StreamPosition(id), nil return types.StreamPosition(id), nil
} }
func (d *DatabaseSnapshot) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) { func (d *DatabaseTransaction) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) {
id, err := d.SendToDevice.SelectMaxSendToDeviceMessageID(ctx, d.txn) id, err := d.SendToDevice.SelectMaxSendToDeviceMessageID(ctx, d.txn)
if err != nil { if err != nil {
return 0, fmt.Errorf("d.SendToDevice.SelectMaxSendToDeviceMessageID: %w", err) return 0, fmt.Errorf("d.SendToDevice.SelectMaxSendToDeviceMessageID: %w", err)
@ -62,7 +62,7 @@ func (d *DatabaseSnapshot) MaxStreamPositionForSendToDeviceMessages(ctx context.
return types.StreamPosition(id), nil return types.StreamPosition(id), nil
} }
func (d *DatabaseSnapshot) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) { func (d *DatabaseTransaction) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) {
id, err := d.AccountData.SelectMaxAccountDataID(ctx, d.txn) id, err := d.AccountData.SelectMaxAccountDataID(ctx, d.txn)
if err != nil { if err != nil {
return 0, fmt.Errorf("d.Invites.SelectMaxAccountDataID: %w", err) return 0, fmt.Errorf("d.Invites.SelectMaxAccountDataID: %w", err)
@ -70,7 +70,7 @@ func (d *DatabaseSnapshot) MaxStreamPositionForAccountData(ctx context.Context)
return types.StreamPosition(id), nil return types.StreamPosition(id), nil
} }
func (d *DatabaseSnapshot) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) { func (d *DatabaseTransaction) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) {
id, err := d.NotificationData.SelectMaxID(ctx, d.txn) id, err := d.NotificationData.SelectMaxID(ctx, d.txn)
if err != nil { if err != nil {
return 0, fmt.Errorf("d.NotificationData.SelectMaxID: %w", err) return 0, fmt.Errorf("d.NotificationData.SelectMaxID: %w", err)
@ -78,39 +78,39 @@ func (d *DatabaseSnapshot) MaxStreamPositionForNotificationData(ctx context.Cont
return types.StreamPosition(id), nil return types.StreamPosition(id), nil
} }
func (d *DatabaseSnapshot) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { func (d *DatabaseTransaction) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
return d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilterPart, excludeEventIDs) return d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilterPart, excludeEventIDs)
} }
func (d *DatabaseSnapshot) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) { func (d *DatabaseTransaction) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) {
return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, d.txn, userID, membership) return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, d.txn, userID, membership)
} }
func (d *DatabaseSnapshot) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) { func (d *DatabaseTransaction) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) {
return d.Memberships.SelectMembershipCount(ctx, d.txn, roomID, membership, pos) return d.Memberships.SelectMembershipCount(ctx, d.txn, roomID, membership, pos)
} }
func (d *DatabaseSnapshot) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) { func (d *DatabaseTransaction) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) {
return d.Memberships.SelectHeroes(ctx, d.txn, roomID, userID, memberships) return d.Memberships.SelectHeroes(ctx, d.txn, roomID, userID, memberships)
} }
func (d *DatabaseSnapshot) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) { func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) {
return d.OutputEvents.SelectRecentEvents(ctx, d.txn, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents) return d.OutputEvents.SelectRecentEvents(ctx, d.txn, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents)
} }
func (d *DatabaseSnapshot) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) { func (d *DatabaseTransaction) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) {
return d.Topology.SelectPositionInTopology(ctx, d.txn, eventID) return d.Topology.SelectPositionInTopology(ctx, d.txn, eventID)
} }
func (d *DatabaseSnapshot) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) { func (d *DatabaseTransaction) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) {
return d.Invites.SelectInviteEventsInRange(ctx, d.txn, targetUserID, r) return d.Invites.SelectInviteEventsInRange(ctx, d.txn, targetUserID, r)
} }
func (d *DatabaseSnapshot) PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) { func (d *DatabaseTransaction) PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) {
return d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, deviceID, r) return d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, deviceID, r)
} }
func (d *DatabaseSnapshot) RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) { func (d *DatabaseTransaction) RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) {
return d.Receipts.SelectRoomReceiptsAfter(ctx, d.txn, roomIDs, streamPos) return d.Receipts.SelectRoomReceiptsAfter(ctx, d.txn, roomIDs, streamPos)
} }
@ -119,7 +119,7 @@ func (d *DatabaseSnapshot) RoomReceiptsAfter(ctx context.Context, roomIDs []stri
// If an event is not found in the database then it will be omitted from the list. // If an event is not found in the database then it will be omitted from the list.
// Returns an error if there was a problem talking with the database. // Returns an error if there was a problem talking with the database.
// Does not include any transaction IDs in the returned events. // Does not include any transaction IDs in the returned events.
func (d *DatabaseSnapshot) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { func (d *DatabaseTransaction) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
streamEvents, err := d.OutputEvents.SelectEvents(ctx, d.txn, eventIDs, nil, false) streamEvents, err := d.OutputEvents.SelectEvents(ctx, d.txn, eventIDs, nil, false)
if err != nil { if err != nil {
return nil, err return nil, err
@ -130,29 +130,29 @@ func (d *DatabaseSnapshot) Events(ctx context.Context, eventIDs []string) ([]*go
return d.StreamEventsToEvents(nil, streamEvents), nil return d.StreamEventsToEvents(nil, streamEvents), nil
} }
func (d *DatabaseSnapshot) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { func (d *DatabaseTransaction) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
return d.CurrentRoomState.SelectJoinedUsers(ctx, d.txn) return d.CurrentRoomState.SelectJoinedUsers(ctx, d.txn)
} }
func (d *DatabaseSnapshot) AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) { func (d *DatabaseTransaction) AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) {
return d.CurrentRoomState.SelectJoinedUsersInRoom(ctx, d.txn, roomIDs) return d.CurrentRoomState.SelectJoinedUsersInRoom(ctx, d.txn, roomIDs)
} }
func (d *DatabaseSnapshot) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) { func (d *DatabaseTransaction) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) {
return d.Peeks.SelectPeekingDevices(ctx, d.txn) return d.Peeks.SelectPeekingDevices(ctx, d.txn)
} }
func (d *DatabaseSnapshot) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) { func (d *DatabaseTransaction) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) {
return d.CurrentRoomState.SelectSharedUsers(ctx, d.txn, userID, otherUserIDs) return d.CurrentRoomState.SelectSharedUsers(ctx, d.txn, userID, otherUserIDs)
} }
func (d *DatabaseSnapshot) GetStateEvent( func (d *DatabaseTransaction) GetStateEvent(
ctx context.Context, roomID, evType, stateKey string, ctx context.Context, roomID, evType, stateKey string,
) (*gomatrixserverlib.HeaderedEvent, error) { ) (*gomatrixserverlib.HeaderedEvent, error) {
return d.CurrentRoomState.SelectStateEvent(ctx, d.txn, roomID, evType, stateKey) return d.CurrentRoomState.SelectStateEvent(ctx, d.txn, roomID, evType, stateKey)
} }
func (d *DatabaseSnapshot) GetStateEventsForRoom( func (d *DatabaseTransaction) GetStateEventsForRoom(
ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter, ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter,
) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) { ) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) {
stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilter, nil) stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilter, nil)
@ -164,14 +164,14 @@ func (d *DatabaseSnapshot) GetStateEventsForRoom(
// Returns a map following the format data[roomID] = []dataTypes // Returns a map following the format data[roomID] = []dataTypes
// If no data is retrieved, returns an empty map // If no data is retrieved, returns an empty map
// If there was an issue with the retrieval, returns an error // If there was an issue with the retrieval, returns an error
func (d *DatabaseSnapshot) GetAccountDataInRange( func (d *DatabaseTransaction) GetAccountDataInRange(
ctx context.Context, userID string, r types.Range, ctx context.Context, userID string, r types.Range,
accountDataFilterPart *gomatrixserverlib.EventFilter, accountDataFilterPart *gomatrixserverlib.EventFilter,
) (map[string][]string, types.StreamPosition, error) { ) (map[string][]string, types.StreamPosition, error) {
return d.AccountData.SelectAccountDataInRange(ctx, d.txn, userID, r, accountDataFilterPart) return d.AccountData.SelectAccountDataInRange(ctx, d.txn, userID, r, accountDataFilterPart)
} }
func (d *DatabaseSnapshot) GetEventsInTopologicalRange( func (d *DatabaseTransaction) GetEventsInTopologicalRange(
ctx context.Context, ctx context.Context,
from, to *types.TopologyToken, from, to *types.TopologyToken,
roomID string, roomID string,
@ -207,13 +207,13 @@ func (d *DatabaseSnapshot) GetEventsInTopologicalRange(
return return
} }
func (d *DatabaseSnapshot) BackwardExtremitiesForRoom( func (d *DatabaseTransaction) BackwardExtremitiesForRoom(
ctx context.Context, roomID string, ctx context.Context, roomID string,
) (backwardExtremities map[string][]string, err error) { ) (backwardExtremities map[string][]string, err error) {
return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, d.txn, roomID) return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, d.txn, roomID)
} }
func (d *DatabaseSnapshot) MaxTopologicalPosition( func (d *DatabaseTransaction) MaxTopologicalPosition(
ctx context.Context, roomID string, ctx context.Context, roomID string,
) (types.TopologyToken, error) { ) (types.TopologyToken, error) {
depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID) depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID)
@ -223,7 +223,7 @@ func (d *DatabaseSnapshot) MaxTopologicalPosition(
return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil
} }
func (d *DatabaseSnapshot) EventPositionInTopology( func (d *DatabaseTransaction) EventPositionInTopology(
ctx context.Context, eventID string, ctx context.Context, eventID string,
) (types.TopologyToken, error) { ) (types.TopologyToken, error) {
depth, stream, err := d.Topology.SelectPositionInTopology(ctx, d.txn, eventID) depth, stream, err := d.Topology.SelectPositionInTopology(ctx, d.txn, eventID)
@ -233,7 +233,7 @@ func (d *DatabaseSnapshot) EventPositionInTopology(
return types.TopologyToken{Depth: depth, PDUPosition: stream}, nil return types.TopologyToken{Depth: depth, PDUPosition: stream}, nil
} }
func (d *DatabaseSnapshot) StreamToTopologicalPosition( func (d *DatabaseTransaction) StreamToTopologicalPosition(
ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool, ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool,
) (types.TopologyToken, error) { ) (types.TopologyToken, error) {
topoPos, err := d.Topology.SelectStreamToTopologicalPosition(ctx, d.txn, roomID, streamPos, backwardOrdering) topoPos, err := d.Topology.SelectStreamToTopologicalPosition(ctx, d.txn, roomID, streamPos, backwardOrdering)
@ -255,7 +255,7 @@ func (d *DatabaseSnapshot) StreamToTopologicalPosition(
// GetBackwardTopologyPos retrieves the backward topology position, i.e. the position of the // GetBackwardTopologyPos retrieves the backward topology position, i.e. the position of the
// oldest event in the room's topology. // oldest event in the room's topology.
func (d *DatabaseSnapshot) GetBackwardTopologyPos( func (d *DatabaseTransaction) GetBackwardTopologyPos(
ctx context.Context, ctx context.Context,
events []types.StreamEvent, events []types.StreamEvent,
) (types.TopologyToken, error) { ) (types.TopologyToken, error) {
@ -276,7 +276,7 @@ func (d *DatabaseSnapshot) GetBackwardTopologyPos(
// exclusive of oldPos, inclusive of newPos, for the rooms in which // exclusive of oldPos, inclusive of newPos, for the rooms in which
// the user has new membership events. // the user has new membership events.
// A list of joined room IDs is also returned in case the caller needs it. // A list of joined room IDs is also returned in case the caller needs it.
func (d *DatabaseSnapshot) GetStateDeltas( func (d *DatabaseTransaction) GetStateDeltas(
ctx context.Context, device *userapi.Device, ctx context.Context, device *userapi.Device,
r types.Range, userID string, r types.Range, userID string,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
@ -403,7 +403,7 @@ func (d *DatabaseSnapshot) GetStateDeltas(
// requests with full_state=true. // requests with full_state=true.
// Fetches full state for all joined rooms and uses selectStateInRange to get // Fetches full state for all joined rooms and uses selectStateInRange to get
// updates for other rooms. // updates for other rooms.
func (d *DatabaseSnapshot) GetStateDeltasForFullStateSync( func (d *DatabaseTransaction) GetStateDeltasForFullStateSync(
ctx context.Context, device *userapi.Device, ctx context.Context, device *userapi.Device,
r types.Range, userID string, r types.Range, userID string,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
@ -513,7 +513,7 @@ func (d *DatabaseSnapshot) GetStateDeltasForFullStateSync(
return result, joinedRoomIDs, nil return result, joinedRoomIDs, nil
} }
func (d *DatabaseSnapshot) currentStateStreamEventsForRoom( func (d *DatabaseTransaction) currentStateStreamEventsForRoom(
ctx context.Context, roomID string, ctx context.Context, roomID string,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
@ -528,7 +528,7 @@ func (d *DatabaseSnapshot) currentStateStreamEventsForRoom(
return s, nil return s, nil
} }
func (d *DatabaseSnapshot) SendToDeviceUpdatesForSync( func (d *DatabaseTransaction) SendToDeviceUpdatesForSync(
ctx context.Context, ctx context.Context,
userID, deviceID string, userID, deviceID string,
from, to types.StreamPosition, from, to types.StreamPosition,
@ -545,12 +545,12 @@ func (d *DatabaseSnapshot) SendToDeviceUpdatesForSync(
return lastPos, events, nil return lastPos, events, nil
} }
func (d *DatabaseSnapshot) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) { func (d *DatabaseTransaction) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) {
_, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, d.txn, roomIDs, streamPos) _, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, d.txn, roomIDs, streamPos)
return receipts, err return receipts, err
} }
func (d *DatabaseSnapshot) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, rooms map[string]string) (map[string]*eventutil.NotificationData, error) { func (d *DatabaseTransaction) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, rooms map[string]string) (map[string]*eventutil.NotificationData, error) {
roomIDs := make([]string, 0, len(rooms)) roomIDs := make([]string, 0, len(rooms))
for roomID, membership := range rooms { for roomID, membership := range rooms {
if membership != gomatrixserverlib.Join { if membership != gomatrixserverlib.Join {
@ -561,14 +561,14 @@ func (d *DatabaseSnapshot) GetUserUnreadNotificationCountsForRooms(ctx context.C
return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, d.txn, userID, roomIDs) return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, d.txn, userID, roomIDs)
} }
func (d *DatabaseSnapshot) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { func (d *DatabaseTransaction) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
return d.Presence.GetPresenceForUser(ctx, d.txn, userID) return d.Presence.GetPresenceForUser(ctx, d.txn, userID)
} }
func (d *DatabaseSnapshot) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {
return d.Presence.GetPresenceAfter(ctx, d.txn, after, filter) return d.Presence.GetPresenceAfter(ctx, d.txn, after, filter)
} }
func (d *DatabaseSnapshot) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) { func (d *DatabaseTransaction) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) {
return d.Presence.GetMaxPresenceID(ctx, d.txn) return d.Presence.GetMaxPresenceID(ctx, d.txn)
} }

View file

@ -49,15 +49,15 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
return &d, nil return &d, nil
} }
func (d *SyncServerDatasource) NewDatabaseSnapshot(ctx context.Context) (*shared.DatabaseSnapshot, error) { func (d *SyncServerDatasource) NewDatabaseSnapshot(ctx context.Context) (*shared.DatabaseTransaction, error) {
return &shared.DatabaseSnapshot{ return &shared.DatabaseTransaction{
Database: &d.Database, Database: &d.Database,
// not setting a transaction because SQLite doesn't support it // not setting a transaction because SQLite doesn't support it
}, nil }, nil
} }
func (d *SyncServerDatasource) NewDatabaseWritable(ctx context.Context) (*shared.DatabaseSnapshot, error) { func (d *SyncServerDatasource) NewDatabaseTransaction(ctx context.Context) (*shared.DatabaseTransaction, error) {
return &shared.DatabaseSnapshot{ return &shared.DatabaseTransaction{
Database: &d.Database, Database: &d.Database,
// not setting a transaction because SQLite doesn't support it // not setting a transaction because SQLite doesn't support it
}, nil }, nil

View file

@ -60,7 +60,7 @@ func TestWriteEvents(t *testing.T) {
}) })
} }
func WithSnapshot(t *testing.T, db storage.Database, f func(snapshot storage.DatabaseSnapshot)) { func WithSnapshot(t *testing.T, db storage.Database, f func(snapshot storage.DatabaseTransaction)) {
snapshot, err := db.NewDatabaseSnapshot(ctx) snapshot, err := db.NewDatabaseSnapshot(ctx)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -91,7 +91,7 @@ func TestRecentEventsPDU(t *testing.T) {
MustWriteEvents(t, db, test.NewRoom(t, alice).Events()) MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
var latest types.StreamPosition var latest types.StreamPosition
WithSnapshot(t, db, func(snapshot storage.DatabaseSnapshot) { WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
var err error var err error
if latest, err = snapshot.MaxStreamPositionForPDUs(ctx); err != nil { if latest, err = snapshot.MaxStreamPositionForPDUs(ctx); err != nil {
t.Fatal("failed to get MaxStreamPositionForPDUs: %w", err) t.Fatal("failed to get MaxStreamPositionForPDUs: %w", err)
@ -157,7 +157,7 @@ func TestRecentEventsPDU(t *testing.T) {
var gotEvents []types.StreamEvent var gotEvents []types.StreamEvent
var limited bool var limited bool
filter.Limit = tc.Limit filter.Limit = tc.Limit
WithSnapshot(t, db, func(snapshot storage.DatabaseSnapshot) { WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
var err error var err error
gotEvents, limited, err = snapshot.RecentEvents(ctx, r.ID, types.Range{ gotEvents, limited, err = snapshot.RecentEvents(ctx, r.ID, types.Range{
From: tc.From, From: tc.From,
@ -197,7 +197,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
events := r.Events() events := r.Events()
_ = MustWriteEvents(t, db, events) _ = MustWriteEvents(t, db, events)
WithSnapshot(t, db, func(snapshot storage.DatabaseSnapshot) { WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
from, err := snapshot.MaxTopologicalPosition(ctx, r.ID) from, err := snapshot.MaxTopologicalPosition(ctx, r.ID)
if err != nil { if err != nil {
t.Fatalf("failed to get MaxTopologicalPosition: %s", err) t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
@ -436,7 +436,7 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point there should be no messages. We haven't sent anything // At this point there should be no messages. We haven't sent anything
// yet. // yet.
WithSnapshot(t, db, func(snapshot storage.DatabaseSnapshot) { WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
_, events, err := snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100) _, events, err := snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -456,7 +456,7 @@ func TestSendToDeviceBehaviour(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
WithSnapshot(t, db, func(snapshot storage.DatabaseSnapshot) { WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
// At this point we should get exactly one message. We're sending the sync position // At this point we should get exactly one message. We're sending the sync position
// that we were given from the update and the send-to-device update will be updated // that we were given from the update and the send-to-device update will be updated
// in the database to reflect that this was the sync position we sent the message at. // in the database to reflect that this was the sync position we sent the message at.
@ -486,7 +486,7 @@ func TestSendToDeviceBehaviour(t *testing.T) {
return return
} }
WithSnapshot(t, db, func(snapshot storage.DatabaseSnapshot) { WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
// At this point we should now have no updates, because we've progressed the sync // At this point we should now have no updates, because we've progressed the sync
// position. Therefore the update from before will not be sent again. // position. Therefore the update from before will not be sent again.
var events []types.SendToDeviceEvent var events []types.SendToDeviceEvent
@ -523,7 +523,7 @@ func TestSendToDeviceBehaviour(t *testing.T) {
lastPos = streamPos lastPos = streamPos
} }
WithSnapshot(t, db, func(snapshot storage.DatabaseSnapshot) { WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
_, events, err := snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, lastPos) _, events, err := snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, lastPos)
if err != nil { if err != nil {
t.Fatalf("unable to get events: %v", err) t.Fatalf("unable to get events: %v", err)

View file

@ -16,7 +16,7 @@ type AccountDataStreamProvider struct {
} }
func (p *AccountDataStreamProvider) Setup( func (p *AccountDataStreamProvider) Setup(
ctx context.Context, snapshot storage.DatabaseSnapshot, ctx context.Context, snapshot storage.DatabaseTransaction,
) { ) {
p.DefaultStreamProvider.Setup(ctx, snapshot) p.DefaultStreamProvider.Setup(ctx, snapshot)
@ -32,7 +32,7 @@ func (p *AccountDataStreamProvider) Setup(
func (p *AccountDataStreamProvider) CompleteSync( func (p *AccountDataStreamProvider) CompleteSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot, snapshot storage.DatabaseTransaction,
req *types.SyncRequest, req *types.SyncRequest,
) types.StreamPosition { ) types.StreamPosition {
return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
@ -40,7 +40,7 @@ func (p *AccountDataStreamProvider) CompleteSync(
func (p *AccountDataStreamProvider) IncrementalSync( func (p *AccountDataStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot, snapshot storage.DatabaseTransaction,
req *types.SyncRequest, req *types.SyncRequest,
from, to types.StreamPosition, from, to types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {

View file

@ -18,7 +18,7 @@ type DeviceListStreamProvider struct {
func (p *DeviceListStreamProvider) CompleteSync( func (p *DeviceListStreamProvider) CompleteSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot, snapshot storage.DatabaseTransaction,
req *types.SyncRequest, req *types.SyncRequest,
) types.StreamPosition { ) types.StreamPosition {
return p.LatestPosition(ctx) return p.LatestPosition(ctx)
@ -26,7 +26,7 @@ func (p *DeviceListStreamProvider) CompleteSync(
func (p *DeviceListStreamProvider) IncrementalSync( func (p *DeviceListStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot, snapshot storage.DatabaseTransaction,
req *types.SyncRequest, req *types.SyncRequest,
from, to types.StreamPosition, from, to types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {

View file

@ -18,7 +18,7 @@ type InviteStreamProvider struct {
} }
func (p *InviteStreamProvider) Setup( func (p *InviteStreamProvider) Setup(
ctx context.Context, snapshot storage.DatabaseSnapshot, ctx context.Context, snapshot storage.DatabaseTransaction,
) { ) {
p.DefaultStreamProvider.Setup(ctx, snapshot) p.DefaultStreamProvider.Setup(ctx, snapshot)
@ -34,7 +34,7 @@ func (p *InviteStreamProvider) Setup(
func (p *InviteStreamProvider) CompleteSync( func (p *InviteStreamProvider) CompleteSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot, snapshot storage.DatabaseTransaction,
req *types.SyncRequest, req *types.SyncRequest,
) types.StreamPosition { ) types.StreamPosition {
return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
@ -42,7 +42,7 @@ func (p *InviteStreamProvider) CompleteSync(
func (p *InviteStreamProvider) IncrementalSync( func (p *InviteStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot, snapshot storage.DatabaseTransaction,
req *types.SyncRequest, req *types.SyncRequest,
from, to types.StreamPosition, from, to types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {

View file

@ -12,7 +12,7 @@ type NotificationDataStreamProvider struct {
} }
func (p *NotificationDataStreamProvider) Setup( func (p *NotificationDataStreamProvider) Setup(
ctx context.Context, snapshot storage.DatabaseSnapshot, ctx context.Context, snapshot storage.DatabaseTransaction,
) { ) {
p.DefaultStreamProvider.Setup(ctx, snapshot) p.DefaultStreamProvider.Setup(ctx, snapshot)
@ -28,7 +28,7 @@ func (p *NotificationDataStreamProvider) Setup(
func (p *NotificationDataStreamProvider) CompleteSync( func (p *NotificationDataStreamProvider) CompleteSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot, snapshot storage.DatabaseTransaction,
req *types.SyncRequest, req *types.SyncRequest,
) types.StreamPosition { ) types.StreamPosition {
return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
@ -36,7 +36,7 @@ func (p *NotificationDataStreamProvider) CompleteSync(
func (p *NotificationDataStreamProvider) IncrementalSync( func (p *NotificationDataStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot, snapshot storage.DatabaseTransaction,
req *types.SyncRequest, req *types.SyncRequest,
from, _ types.StreamPosition, from, _ types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {

View file

@ -40,7 +40,7 @@ type PDUStreamProvider struct {
} }
func (p *PDUStreamProvider) Setup( func (p *PDUStreamProvider) Setup(
ctx context.Context, snapshot storage.DatabaseSnapshot, ctx context.Context, snapshot storage.DatabaseTransaction,
) { ) {
p.DefaultStreamProvider.Setup(ctx, snapshot) p.DefaultStreamProvider.Setup(ctx, snapshot)
@ -56,7 +56,7 @@ func (p *PDUStreamProvider) Setup(
func (p *PDUStreamProvider) CompleteSync( func (p *PDUStreamProvider) CompleteSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot, snapshot storage.DatabaseTransaction,
req *types.SyncRequest, req *types.SyncRequest,
) types.StreamPosition { ) types.StreamPosition {
from := types.StreamPosition(0) from := types.StreamPosition(0)
@ -132,7 +132,7 @@ func (p *PDUStreamProvider) CompleteSync(
func (p *PDUStreamProvider) IncrementalSync( func (p *PDUStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot, snapshot storage.DatabaseTransaction,
req *types.SyncRequest, req *types.SyncRequest,
from, to types.StreamPosition, from, to types.StreamPosition,
) (newPos types.StreamPosition) { ) (newPos types.StreamPosition) {
@ -210,7 +210,7 @@ func (p *PDUStreamProvider) IncrementalSync(
// nolint:gocyclo // nolint:gocyclo
func (p *PDUStreamProvider) addRoomDeltaToResponse( func (p *PDUStreamProvider) addRoomDeltaToResponse(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot, snapshot storage.DatabaseTransaction,
device *userapi.Device, device *userapi.Device,
r types.Range, r types.Range,
delta types.StateDelta, delta types.StateDelta,
@ -343,7 +343,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
// sure we always return the required events in the timeline. // sure we always return the required events in the timeline.
func applyHistoryVisibilityFilter( func applyHistoryVisibilityFilter(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot, snapshot storage.DatabaseTransaction,
rsAPI roomserverAPI.SyncRoomserverAPI, rsAPI roomserverAPI.SyncRoomserverAPI,
roomID, userID string, roomID, userID string,
limit int, limit int,
@ -375,7 +375,7 @@ func applyHistoryVisibilityFilter(
return events, nil return events, nil
} }
func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, snapshot storage.DatabaseSnapshot, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) { func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, snapshot storage.DatabaseTransaction, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) {
// Work out how many members are in the room. // Work out how many members are in the room.
joinedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition) joinedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition)
invitedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, latestPosition) invitedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, latestPosition)
@ -416,7 +416,7 @@ func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, snapshot storage
func (p *PDUStreamProvider) getJoinResponseForCompleteSync( func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot, snapshot storage.DatabaseTransaction,
roomID string, roomID string,
r types.Range, r types.Range,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
@ -518,7 +518,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
} }
func (p *PDUStreamProvider) lazyLoadMembers( func (p *PDUStreamProvider) lazyLoadMembers(
ctx context.Context, snapshot storage.DatabaseSnapshot, roomID string, ctx context.Context, snapshot storage.DatabaseTransaction, roomID string,
incremental, limited bool, stateFilter *gomatrixserverlib.StateFilter, incremental, limited bool, stateFilter *gomatrixserverlib.StateFilter,
device *userapi.Device, device *userapi.Device,
timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent, timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent,
@ -581,7 +581,7 @@ func (p *PDUStreamProvider) lazyLoadMembers(
// addIgnoredUsersToFilter adds ignored users to the eventfilter and // addIgnoredUsersToFilter adds ignored users to the eventfilter and
// the syncreq itself for further use in streams. // the syncreq itself for further use in streams.
func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, snapshot storage.DatabaseSnapshot, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error { func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, snapshot storage.DatabaseTransaction, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error {
ignores, err := snapshot.IgnoresForUser(ctx, req.Device.UserID) ignores, err := snapshot.IgnoresForUser(ctx, req.Device.UserID)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {

View file

@ -35,7 +35,7 @@ type PresenceStreamProvider struct {
} }
func (p *PresenceStreamProvider) Setup( func (p *PresenceStreamProvider) Setup(
ctx context.Context, snapshot storage.DatabaseSnapshot, ctx context.Context, snapshot storage.DatabaseTransaction,
) { ) {
p.DefaultStreamProvider.Setup(ctx, snapshot) p.DefaultStreamProvider.Setup(ctx, snapshot)
@ -51,7 +51,7 @@ func (p *PresenceStreamProvider) Setup(
func (p *PresenceStreamProvider) CompleteSync( func (p *PresenceStreamProvider) CompleteSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot, snapshot storage.DatabaseTransaction,
req *types.SyncRequest, req *types.SyncRequest,
) types.StreamPosition { ) types.StreamPosition {
return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
@ -59,7 +59,7 @@ func (p *PresenceStreamProvider) CompleteSync(
func (p *PresenceStreamProvider) IncrementalSync( func (p *PresenceStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot, snapshot storage.DatabaseTransaction,
req *types.SyncRequest, req *types.SyncRequest,
from, to types.StreamPosition, from, to types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {

View file

@ -14,7 +14,7 @@ type ReceiptStreamProvider struct {
} }
func (p *ReceiptStreamProvider) Setup( func (p *ReceiptStreamProvider) Setup(
ctx context.Context, snapshot storage.DatabaseSnapshot, ctx context.Context, snapshot storage.DatabaseTransaction,
) { ) {
p.DefaultStreamProvider.Setup(ctx, snapshot) p.DefaultStreamProvider.Setup(ctx, snapshot)
@ -30,7 +30,7 @@ func (p *ReceiptStreamProvider) Setup(
func (p *ReceiptStreamProvider) CompleteSync( func (p *ReceiptStreamProvider) CompleteSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot, snapshot storage.DatabaseTransaction,
req *types.SyncRequest, req *types.SyncRequest,
) types.StreamPosition { ) types.StreamPosition {
return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
@ -38,7 +38,7 @@ func (p *ReceiptStreamProvider) CompleteSync(
func (p *ReceiptStreamProvider) IncrementalSync( func (p *ReceiptStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot, snapshot storage.DatabaseTransaction,
req *types.SyncRequest, req *types.SyncRequest,
from, to types.StreamPosition, from, to types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {

View file

@ -12,7 +12,7 @@ type SendToDeviceStreamProvider struct {
} }
func (p *SendToDeviceStreamProvider) Setup( func (p *SendToDeviceStreamProvider) Setup(
ctx context.Context, snapshot storage.DatabaseSnapshot, ctx context.Context, snapshot storage.DatabaseTransaction,
) { ) {
p.DefaultStreamProvider.Setup(ctx, snapshot) p.DefaultStreamProvider.Setup(ctx, snapshot)
@ -28,7 +28,7 @@ func (p *SendToDeviceStreamProvider) Setup(
func (p *SendToDeviceStreamProvider) CompleteSync( func (p *SendToDeviceStreamProvider) CompleteSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot, snapshot storage.DatabaseTransaction,
req *types.SyncRequest, req *types.SyncRequest,
) types.StreamPosition { ) types.StreamPosition {
return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
@ -36,7 +36,7 @@ func (p *SendToDeviceStreamProvider) CompleteSync(
func (p *SendToDeviceStreamProvider) IncrementalSync( func (p *SendToDeviceStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot, snapshot storage.DatabaseTransaction,
req *types.SyncRequest, req *types.SyncRequest,
from, to types.StreamPosition, from, to types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {

View file

@ -17,7 +17,7 @@ type TypingStreamProvider struct {
func (p *TypingStreamProvider) CompleteSync( func (p *TypingStreamProvider) CompleteSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot, snapshot storage.DatabaseTransaction,
req *types.SyncRequest, req *types.SyncRequest,
) types.StreamPosition { ) types.StreamPosition {
return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
@ -25,7 +25,7 @@ func (p *TypingStreamProvider) CompleteSync(
func (p *TypingStreamProvider) IncrementalSync( func (p *TypingStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseSnapshot, snapshot storage.DatabaseTransaction,
req *types.SyncRequest, req *types.SyncRequest,
from, to types.StreamPosition, from, to types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {

View file

@ -8,7 +8,7 @@ import (
) )
type StreamProvider interface { type StreamProvider interface {
Setup(ctx context.Context, snapshot storage.DatabaseSnapshot) Setup(ctx context.Context, snapshot storage.DatabaseTransaction)
// Advance will update the latest position of the stream based on // Advance will update the latest position of the stream based on
// an update and will wake callers waiting on StreamNotifyAfter. // an update and will wake callers waiting on StreamNotifyAfter.
@ -16,12 +16,12 @@ type StreamProvider interface {
// CompleteSync will update the response to include all updates as needed // CompleteSync will update the response to include all updates as needed
// for a complete sync. It will always return immediately. // for a complete sync. It will always return immediately.
CompleteSync(ctx context.Context, snapshot storage.DatabaseSnapshot, req *types.SyncRequest) types.StreamPosition CompleteSync(ctx context.Context, snapshot storage.DatabaseTransaction, req *types.SyncRequest) types.StreamPosition
// IncrementalSync will update the response to include all updates between // IncrementalSync will update the response to include all updates between
// the from and to sync positions. It will always return immediately, // the from and to sync positions. It will always return immediately,
// making no changes if the range contains no updates. // making no changes if the range contains no updates.
IncrementalSync(ctx context.Context, snapshot storage.DatabaseSnapshot, req *types.SyncRequest, from, to types.StreamPosition) types.StreamPosition IncrementalSync(ctx context.Context, snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, to types.StreamPosition) types.StreamPosition
// LatestPosition returns the latest stream position for this stream. // LatestPosition returns the latest stream position for this stream.
LatestPosition(ctx context.Context) types.StreamPosition LatestPosition(ctx context.Context) types.StreamPosition

View file

@ -15,7 +15,7 @@ type DefaultStreamProvider struct {
} }
func (p *DefaultStreamProvider) Setup( func (p *DefaultStreamProvider) Setup(
ctx context.Context, snapshot storage.DatabaseSnapshot, ctx context.Context, snapshot storage.DatabaseTransaction,
) { ) {
} }