mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-11-23 14:51:56 -06:00
c136a450d5
Fixes #2803. Also refactors the presence stream so that it no longer hits the database once per user; it now queries all users at once.
691 lines
25 KiB
Go
691 lines
25 KiB
Go
package shared
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
|
|
"github.com/matrix-org/gomatrixserverlib"
|
|
|
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
|
"github.com/matrix-org/dendrite/syncapi/types"
|
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
|
)
|
|
|
|
type DatabaseTransaction struct {
|
|
*Database
|
|
ctx context.Context
|
|
txn *sql.Tx
|
|
}
|
|
|
|
func (d *DatabaseTransaction) Commit() error {
|
|
if d.txn == nil {
|
|
return nil
|
|
}
|
|
return d.txn.Commit()
|
|
}
|
|
|
|
func (d *DatabaseTransaction) Rollback() error {
|
|
if d.txn == nil {
|
|
return nil
|
|
}
|
|
return d.txn.Rollback()
|
|
}
|
|
|
|
func (d *DatabaseTransaction) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) {
|
|
id, err := d.OutputEvents.SelectMaxEventID(ctx, d.txn)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("d.OutputEvents.SelectMaxEventID: %w", err)
|
|
}
|
|
return types.StreamPosition(id), nil
|
|
}
|
|
|
|
func (d *DatabaseTransaction) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) {
|
|
id, err := d.Receipts.SelectMaxReceiptID(ctx, d.txn)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("d.Receipts.SelectMaxReceiptID: %w", err)
|
|
}
|
|
return types.StreamPosition(id), nil
|
|
}
|
|
|
|
func (d *DatabaseTransaction) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) {
|
|
id, err := d.Invites.SelectMaxInviteID(ctx, d.txn)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("d.Invites.SelectMaxInviteID: %w", err)
|
|
}
|
|
return types.StreamPosition(id), nil
|
|
}
|
|
|
|
func (d *DatabaseTransaction) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) {
|
|
id, err := d.SendToDevice.SelectMaxSendToDeviceMessageID(ctx, d.txn)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("d.SendToDevice.SelectMaxSendToDeviceMessageID: %w", err)
|
|
}
|
|
return types.StreamPosition(id), nil
|
|
}
|
|
|
|
func (d *DatabaseTransaction) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) {
|
|
id, err := d.AccountData.SelectMaxAccountDataID(ctx, d.txn)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("d.Invites.SelectMaxAccountDataID: %w", err)
|
|
}
|
|
return types.StreamPosition(id), nil
|
|
}
|
|
|
|
func (d *DatabaseTransaction) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) {
|
|
id, err := d.NotificationData.SelectMaxID(ctx, d.txn)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("d.NotificationData.SelectMaxID: %w", err)
|
|
}
|
|
return types.StreamPosition(id), nil
|
|
}
|
|
|
|
func (d *DatabaseTransaction) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
|
return d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilterPart, excludeEventIDs)
|
|
}
|
|
|
|
func (d *DatabaseTransaction) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) {
|
|
return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, d.txn, userID, membership)
|
|
}
|
|
|
|
func (d *DatabaseTransaction) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) {
|
|
return d.Memberships.SelectMembershipCount(ctx, d.txn, roomID, membership, pos)
|
|
}
|
|
|
|
func (d *DatabaseTransaction) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) {
|
|
return d.Memberships.SelectHeroes(ctx, d.txn, roomID, userID, memberships)
|
|
}
|
|
|
|
func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) {
|
|
return d.OutputEvents.SelectRecentEvents(ctx, d.txn, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents)
|
|
}
|
|
|
|
func (d *DatabaseTransaction) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) {
|
|
return d.Topology.SelectPositionInTopology(ctx, d.txn, eventID)
|
|
}
|
|
|
|
func (d *DatabaseTransaction) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) {
|
|
return d.Invites.SelectInviteEventsInRange(ctx, d.txn, targetUserID, r)
|
|
}
|
|
|
|
func (d *DatabaseTransaction) PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) {
|
|
return d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, deviceID, r)
|
|
}
|
|
|
|
func (d *DatabaseTransaction) RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) {
|
|
return d.Receipts.SelectRoomReceiptsAfter(ctx, d.txn, roomIDs, streamPos)
|
|
}
|
|
|
|
// Events lookups a list of event by their event ID.
|
|
// Returns a list of events matching the requested IDs found in the database.
|
|
// If an event is not found in the database then it will be omitted from the list.
|
|
// Returns an error if there was a problem talking with the database.
|
|
// Does not include any transaction IDs in the returned events.
|
|
func (d *DatabaseTransaction) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
|
streamEvents, err := d.OutputEvents.SelectEvents(ctx, d.txn, eventIDs, nil, false)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// We don't include a device here as we only include transaction IDs in
|
|
// incremental syncs.
|
|
return d.StreamEventsToEvents(nil, streamEvents), nil
|
|
}
|
|
|
|
func (d *DatabaseTransaction) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
|
|
return d.CurrentRoomState.SelectJoinedUsers(ctx, d.txn)
|
|
}
|
|
|
|
func (d *DatabaseTransaction) AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) {
|
|
return d.CurrentRoomState.SelectJoinedUsersInRoom(ctx, d.txn, roomIDs)
|
|
}
|
|
|
|
func (d *DatabaseTransaction) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) {
|
|
return d.Peeks.SelectPeekingDevices(ctx, d.txn)
|
|
}
|
|
|
|
func (d *DatabaseTransaction) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) {
|
|
return d.CurrentRoomState.SelectSharedUsers(ctx, d.txn, userID, otherUserIDs)
|
|
}
|
|
|
|
func (d *DatabaseTransaction) GetStateEvent(
|
|
ctx context.Context, roomID, evType, stateKey string,
|
|
) (*gomatrixserverlib.HeaderedEvent, error) {
|
|
return d.CurrentRoomState.SelectStateEvent(ctx, d.txn, roomID, evType, stateKey)
|
|
}
|
|
|
|
func (d *DatabaseTransaction) GetStateEventsForRoom(
|
|
ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter,
|
|
) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) {
|
|
stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilter, nil)
|
|
return
|
|
}
|
|
|
|
// GetAccountDataInRange returns all account data for a given user inserted or
|
|
// updated between two given positions
|
|
// Returns a map following the format data[roomID] = []dataTypes
|
|
// If no data is retrieved, returns an empty map
|
|
// If there was an issue with the retrieval, returns an error
|
|
func (d *DatabaseTransaction) GetAccountDataInRange(
|
|
ctx context.Context, userID string, r types.Range,
|
|
accountDataFilterPart *gomatrixserverlib.EventFilter,
|
|
) (map[string][]string, types.StreamPosition, error) {
|
|
return d.AccountData.SelectAccountDataInRange(ctx, d.txn, userID, r, accountDataFilterPart)
|
|
}
|
|
|
|
func (d *DatabaseTransaction) GetEventsInTopologicalRange(
|
|
ctx context.Context,
|
|
from, to *types.TopologyToken,
|
|
roomID string,
|
|
filter *gomatrixserverlib.RoomEventFilter,
|
|
backwardOrdering bool,
|
|
) (events []types.StreamEvent, err error) {
|
|
var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition
|
|
if backwardOrdering {
|
|
// Backward ordering means the 'from' token has a higher depth than the 'to' token
|
|
minDepth = to.Depth
|
|
maxDepth = from.Depth
|
|
// for cases where we have say 5 events with the same depth, the TopologyToken needs to
|
|
// know which of the 5 the client has seen. This is done by using the PDU position.
|
|
// Events with the same maxDepth but less than this PDU position will be returned.
|
|
maxStreamPosForMaxDepth = from.PDUPosition
|
|
} else {
|
|
// Forward ordering means the 'from' token has a lower depth than the 'to' token.
|
|
minDepth = from.Depth
|
|
maxDepth = to.Depth
|
|
}
|
|
|
|
// Select the event IDs from the defined range.
|
|
var eIDs []string
|
|
eIDs, err = d.Topology.SelectEventIDsInRange(
|
|
ctx, d.txn, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering,
|
|
)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// Retrieve the events' contents using their IDs.
|
|
events, err = d.OutputEvents.SelectEvents(ctx, d.txn, eIDs, filter, true)
|
|
return
|
|
}
|
|
|
|
func (d *DatabaseTransaction) BackwardExtremitiesForRoom(
|
|
ctx context.Context, roomID string,
|
|
) (backwardExtremities map[string][]string, err error) {
|
|
return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, d.txn, roomID)
|
|
}
|
|
|
|
func (d *DatabaseTransaction) MaxTopologicalPosition(
|
|
ctx context.Context, roomID string,
|
|
) (types.TopologyToken, error) {
|
|
depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID)
|
|
if err != nil {
|
|
return types.TopologyToken{}, err
|
|
}
|
|
return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil
|
|
}
|
|
|
|
func (d *DatabaseTransaction) EventPositionInTopology(
|
|
ctx context.Context, eventID string,
|
|
) (types.TopologyToken, error) {
|
|
depth, stream, err := d.Topology.SelectPositionInTopology(ctx, d.txn, eventID)
|
|
if err != nil {
|
|
return types.TopologyToken{}, err
|
|
}
|
|
return types.TopologyToken{Depth: depth, PDUPosition: stream}, nil
|
|
}
|
|
|
|
func (d *DatabaseTransaction) StreamToTopologicalPosition(
|
|
ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool,
|
|
) (types.TopologyToken, error) {
|
|
topoPos, err := d.Topology.SelectStreamToTopologicalPosition(ctx, d.txn, roomID, streamPos, backwardOrdering)
|
|
switch {
|
|
case err == sql.ErrNoRows && backwardOrdering: // no events in range, going backward
|
|
return types.TopologyToken{PDUPosition: streamPos}, nil
|
|
case err == sql.ErrNoRows && !backwardOrdering: // no events in range, going forward
|
|
topoPos, streamPos, err = d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID)
|
|
if err != nil {
|
|
return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectMaxPositionInTopology: %w", err)
|
|
}
|
|
return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil
|
|
case err != nil: // some other error happened
|
|
return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectStreamToTopologicalPosition: %w", err)
|
|
default:
|
|
return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil
|
|
}
|
|
}
|
|
|
|
// GetBackwardTopologyPos retrieves the backward topology position, i.e. the position of the
|
|
// oldest event in the room's topology.
|
|
func (d *DatabaseTransaction) GetBackwardTopologyPos(
|
|
ctx context.Context,
|
|
events []*gomatrixserverlib.HeaderedEvent,
|
|
) (types.TopologyToken, error) {
|
|
zeroToken := types.TopologyToken{}
|
|
if len(events) == 0 {
|
|
return zeroToken, nil
|
|
}
|
|
pos, spos, err := d.Topology.SelectPositionInTopology(ctx, d.txn, events[0].EventID())
|
|
if err != nil {
|
|
return zeroToken, err
|
|
}
|
|
tok := types.TopologyToken{Depth: pos, PDUPosition: spos}
|
|
tok.Decrement()
|
|
return tok, nil
|
|
}
|
|
|
|
// GetStateDeltas returns the state deltas between fromPos and toPos,
|
|
// exclusive of oldPos, inclusive of newPos, for the rooms in which
|
|
// the user has new membership events.
|
|
// A list of joined room IDs is also returned in case the caller needs it.
|
|
// nolint:gocyclo
|
|
func (d *DatabaseTransaction) GetStateDeltas(
|
|
ctx context.Context, device *userapi.Device,
|
|
r types.Range, userID string,
|
|
stateFilter *gomatrixserverlib.StateFilter,
|
|
) (deltas []types.StateDelta, joinedRoomsIDs []string, err error) {
|
|
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
|
|
// - Get membership list changes for this user in this sync response
|
|
// - For each room which has membership list changes:
|
|
// * Check if the room is 'newly joined' (insufficient to just check for a join event because we allow dupe joins TODO).
|
|
// If it is, then we need to send the full room state down (and 'limited' is always true).
|
|
// * Check if user is still CURRENTLY invited to the room. If so, add room to 'invited' block.
|
|
// * Check if the user is CURRENTLY (TODO) left/banned. If so, add room to 'archived' block.
|
|
// - Get all CURRENTLY joined rooms, and add them to 'joined' block.
|
|
|
|
// Look up all memberships for the user. We only care about rooms that a
|
|
// user has ever interacted with — joined to, kicked/banned from, left.
|
|
memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, d.txn, userID)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil, nil
|
|
}
|
|
return nil, nil, err
|
|
}
|
|
|
|
allRoomIDs := make([]string, 0, len(memberships))
|
|
joinedRoomIDs := make([]string, 0, len(memberships))
|
|
for roomID, membership := range memberships {
|
|
allRoomIDs = append(allRoomIDs, roomID)
|
|
if membership == gomatrixserverlib.Join {
|
|
joinedRoomIDs = append(joinedRoomIDs, roomID)
|
|
}
|
|
}
|
|
|
|
// get all the state events ever (i.e. for all available rooms) between these two positions
|
|
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, nil, allRoomIDs)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil, nil
|
|
}
|
|
return nil, nil, err
|
|
}
|
|
state, err := d.fetchStateEvents(ctx, d.txn, stateNeeded, eventMap)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil, nil
|
|
}
|
|
return nil, nil, err
|
|
}
|
|
|
|
// get all the state events ever (i.e. for all available rooms) between these two positions
|
|
stateNeededFiltered, eventMapFiltered, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil, nil
|
|
}
|
|
return nil, nil, err
|
|
}
|
|
stateFiltered, err := d.fetchStateEvents(ctx, d.txn, stateNeededFiltered, eventMapFiltered)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil, nil
|
|
}
|
|
return nil, nil, err
|
|
}
|
|
|
|
// find out which rooms this user is peeking, if any.
|
|
// We do this before joins so any peeks get overwritten
|
|
peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, device.ID, r)
|
|
if err != nil && err != sql.ErrNoRows {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// add peek blocks
|
|
for _, peek := range peeks {
|
|
if peek.New {
|
|
// send full room state down instead of a delta
|
|
var s []types.StreamEvent
|
|
s, err = d.currentStateStreamEventsForRoom(ctx, peek.RoomID, stateFilter)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
continue
|
|
}
|
|
return nil, nil, err
|
|
}
|
|
state[peek.RoomID] = s
|
|
}
|
|
if !peek.Deleted {
|
|
deltas = append(deltas, types.StateDelta{
|
|
Membership: gomatrixserverlib.Peek,
|
|
StateEvents: d.StreamEventsToEvents(device, state[peek.RoomID]),
|
|
RoomID: peek.RoomID,
|
|
})
|
|
}
|
|
}
|
|
|
|
// handle newly joined rooms and non-joined rooms
|
|
newlyJoinedRooms := make(map[string]bool, len(state))
|
|
for roomID, stateStreamEvents := range state {
|
|
for _, ev := range stateStreamEvents {
|
|
// Look for our membership in the state events and skip over any
|
|
// membership events that are not related to us.
|
|
membership, prevMembership := getMembershipFromEvent(ev.Event, userID)
|
|
if membership == "" {
|
|
continue
|
|
}
|
|
|
|
if membership == gomatrixserverlib.Join {
|
|
// If our membership is now join but the previous membership wasn't
|
|
// then this is a "join transition", so we'll insert this room.
|
|
if prevMembership != membership {
|
|
newlyJoinedRooms[roomID] = true
|
|
// Get the full room state, as we'll send that down for a newly
|
|
// joined room instead of a delta.
|
|
var s []types.StreamEvent
|
|
if s, err = d.currentStateStreamEventsForRoom(ctx, roomID, stateFilter); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
continue
|
|
}
|
|
return nil, nil, err
|
|
}
|
|
|
|
// Add the information for this room into the state so that
|
|
// it will get added with all of the rest of the joined rooms.
|
|
stateFiltered[roomID] = s
|
|
}
|
|
|
|
// We won't add joined rooms into the delta at this point as they
|
|
// are added later on.
|
|
continue
|
|
}
|
|
|
|
deltas = append(deltas, types.StateDelta{
|
|
Membership: membership,
|
|
MembershipPos: ev.StreamPosition,
|
|
StateEvents: d.StreamEventsToEvents(device, stateFiltered[roomID]),
|
|
RoomID: roomID,
|
|
})
|
|
break
|
|
}
|
|
}
|
|
|
|
// Finally, add in currently joined rooms, including those from the
|
|
// join transitions above.
|
|
for _, joinedRoomID := range joinedRoomIDs {
|
|
deltas = append(deltas, types.StateDelta{
|
|
Membership: gomatrixserverlib.Join,
|
|
StateEvents: d.StreamEventsToEvents(device, stateFiltered[joinedRoomID]),
|
|
RoomID: joinedRoomID,
|
|
NewlyJoined: newlyJoinedRooms[joinedRoomID],
|
|
})
|
|
}
|
|
|
|
return deltas, joinedRoomIDs, nil
|
|
}
|
|
|
|
// GetStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync
|
|
// requests with full_state=true.
|
|
// Fetches full state for all joined rooms and uses selectStateInRange to get
|
|
// updates for other rooms.
|
|
func (d *DatabaseTransaction) GetStateDeltasForFullStateSync(
|
|
ctx context.Context, device *userapi.Device,
|
|
r types.Range, userID string,
|
|
stateFilter *gomatrixserverlib.StateFilter,
|
|
) ([]types.StateDelta, []string, error) {
|
|
// Look up all memberships for the user. We only care about rooms that a
|
|
// user has ever interacted with — joined to, kicked/banned from, left.
|
|
memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, d.txn, userID)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil, nil
|
|
}
|
|
return nil, nil, err
|
|
}
|
|
|
|
allRoomIDs := make([]string, 0, len(memberships))
|
|
joinedRoomIDs := make([]string, 0, len(memberships))
|
|
for roomID, membership := range memberships {
|
|
allRoomIDs = append(allRoomIDs, roomID)
|
|
if membership == gomatrixserverlib.Join {
|
|
joinedRoomIDs = append(joinedRoomIDs, roomID)
|
|
}
|
|
}
|
|
|
|
// Use a reasonable initial capacity
|
|
deltas := make(map[string]types.StateDelta)
|
|
|
|
peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, device.ID, r)
|
|
if err != nil && err != sql.ErrNoRows {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// Add full states for all peeking rooms
|
|
for _, peek := range peeks {
|
|
if !peek.Deleted {
|
|
s, stateErr := d.currentStateStreamEventsForRoom(ctx, peek.RoomID, stateFilter)
|
|
if stateErr != nil {
|
|
if stateErr == sql.ErrNoRows {
|
|
continue
|
|
}
|
|
return nil, nil, stateErr
|
|
}
|
|
deltas[peek.RoomID] = types.StateDelta{
|
|
Membership: gomatrixserverlib.Peek,
|
|
StateEvents: d.StreamEventsToEvents(device, s),
|
|
RoomID: peek.RoomID,
|
|
}
|
|
}
|
|
}
|
|
|
|
// Get all the state events ever between these two positions
|
|
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil, nil
|
|
}
|
|
return nil, nil, err
|
|
}
|
|
state, err := d.fetchStateEvents(ctx, d.txn, stateNeeded, eventMap)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil, nil
|
|
}
|
|
return nil, nil, err
|
|
}
|
|
|
|
for roomID, stateStreamEvents := range state {
|
|
for _, ev := range stateStreamEvents {
|
|
if membership, _ := getMembershipFromEvent(ev.Event, userID); membership != "" {
|
|
if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above.
|
|
deltas[roomID] = types.StateDelta{
|
|
Membership: membership,
|
|
MembershipPos: ev.StreamPosition,
|
|
StateEvents: d.StreamEventsToEvents(device, stateStreamEvents),
|
|
RoomID: roomID,
|
|
}
|
|
}
|
|
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add full states for all joined rooms
|
|
for _, joinedRoomID := range joinedRoomIDs {
|
|
s, stateErr := d.currentStateStreamEventsForRoom(ctx, joinedRoomID, stateFilter)
|
|
if stateErr != nil {
|
|
if stateErr == sql.ErrNoRows {
|
|
continue
|
|
}
|
|
return nil, nil, stateErr
|
|
}
|
|
deltas[joinedRoomID] = types.StateDelta{
|
|
Membership: gomatrixserverlib.Join,
|
|
StateEvents: d.StreamEventsToEvents(device, s),
|
|
RoomID: joinedRoomID,
|
|
}
|
|
}
|
|
|
|
// Create a response array.
|
|
result := make([]types.StateDelta, len(deltas))
|
|
i := 0
|
|
for _, delta := range deltas {
|
|
result[i] = delta
|
|
i++
|
|
}
|
|
|
|
return result, joinedRoomIDs, nil
|
|
}
|
|
|
|
func (d *DatabaseTransaction) currentStateStreamEventsForRoom(
|
|
ctx context.Context, roomID string,
|
|
stateFilter *gomatrixserverlib.StateFilter,
|
|
) ([]types.StreamEvent, error) {
|
|
allState, err := d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilter, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
s := make([]types.StreamEvent, len(allState))
|
|
for i := 0; i < len(s); i++ {
|
|
s[i] = types.StreamEvent{HeaderedEvent: allState[i], StreamPosition: 0}
|
|
}
|
|
return s, nil
|
|
}
|
|
|
|
func (d *DatabaseTransaction) SendToDeviceUpdatesForSync(
|
|
ctx context.Context,
|
|
userID, deviceID string,
|
|
from, to types.StreamPosition,
|
|
) (types.StreamPosition, []types.SendToDeviceEvent, error) {
|
|
// First of all, get our send-to-device updates for this user.
|
|
lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, d.txn, userID, deviceID, from, to)
|
|
if err != nil {
|
|
return from, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
|
|
}
|
|
// If there's nothing to do then stop here.
|
|
if len(events) == 0 {
|
|
return to, nil, nil
|
|
}
|
|
return lastPos, events, nil
|
|
}
|
|
|
|
func (d *DatabaseTransaction) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) {
|
|
_, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, d.txn, roomIDs, streamPos)
|
|
return receipts, err
|
|
}
|
|
|
|
func (d *DatabaseTransaction) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, rooms map[string]string) (map[string]*eventutil.NotificationData, error) {
|
|
roomIDs := make([]string, 0, len(rooms))
|
|
for roomID, membership := range rooms {
|
|
if membership != gomatrixserverlib.Join {
|
|
continue
|
|
}
|
|
roomIDs = append(roomIDs, roomID)
|
|
}
|
|
return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, d.txn, userID, roomIDs)
|
|
}
|
|
|
|
func (d *DatabaseTransaction) GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) {
|
|
return d.Presence.GetPresenceForUsers(ctx, d.txn, userIDs)
|
|
}
|
|
|
|
func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {
|
|
return d.Presence.GetPresenceAfter(ctx, d.txn, after, filter)
|
|
}
|
|
|
|
func (d *DatabaseTransaction) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) {
|
|
return d.Presence.GetMaxPresenceID(ctx, d.txn)
|
|
}
|
|
|
|
func (d *DatabaseTransaction) MaxStreamPositionForRelations(ctx context.Context) (types.StreamPosition, error) {
|
|
id, err := d.Relations.SelectMaxRelationID(ctx, d.txn)
|
|
return types.StreamPosition(id), err
|
|
}
|
|
|
|
func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (
|
|
events []types.StreamEvent, prevBatch, nextBatch string, err error,
|
|
) {
|
|
r := types.Range{
|
|
From: from,
|
|
To: to,
|
|
Backwards: backwards,
|
|
}
|
|
|
|
if r.Backwards && r.From == 0 {
|
|
// If we're working backwards (dir=b) and there's no ?from= specified then
|
|
// we will automatically want to work backwards from the current position,
|
|
// so find out what that is.
|
|
if r.From, err = d.MaxStreamPositionForRelations(ctx); err != nil {
|
|
return nil, "", "", fmt.Errorf("d.MaxStreamPositionForRelations: %w", err)
|
|
}
|
|
// The result normally isn't inclusive of the event *at* the ?from=
|
|
// position, so add 1 here so that we include the most recent relation.
|
|
r.From++
|
|
} else if !r.Backwards && r.To == 0 {
|
|
// If we're working forwards (dir=f) and there's no ?to= specified then
|
|
// we will automatically want to work forwards towards the current position,
|
|
// so find out what that is.
|
|
if r.To, err = d.MaxStreamPositionForRelations(ctx); err != nil {
|
|
return nil, "", "", fmt.Errorf("d.MaxStreamPositionForRelations: %w", err)
|
|
}
|
|
}
|
|
|
|
// First look up any relations from the database. We add one to the limit here
|
|
// so that we can tell if we're overflowing, as we will only set the "next_batch"
|
|
// in the response if we are.
|
|
relations, _, err := d.Relations.SelectRelationsInRange(ctx, d.txn, roomID, eventID, relType, eventType, r, limit+1)
|
|
if err != nil {
|
|
return nil, "", "", fmt.Errorf("d.Relations.SelectRelationsInRange: %w", err)
|
|
}
|
|
|
|
// If we specified a relation type then just get those results, otherwise collate
|
|
// them from all of the returned relation types.
|
|
entries := []types.RelationEntry{}
|
|
if relType != "" {
|
|
entries = relations[relType]
|
|
} else {
|
|
for _, e := range relations {
|
|
entries = append(entries, e...)
|
|
}
|
|
}
|
|
|
|
// If there were no entries returned, there were no relations, so stop at this point.
|
|
if len(entries) == 0 {
|
|
return nil, "", "", nil
|
|
}
|
|
|
|
// Otherwise, let's try and work out what sensible prev_batch and next_batch values
|
|
// could be. We've requested an extra event by adding one to the limit already so
|
|
// that we can determine whether or not to provide a "next_batch", so trim off that
|
|
// event off the end if needs be.
|
|
if len(entries) > limit {
|
|
entries = entries[:len(entries)-1]
|
|
nextBatch = fmt.Sprintf("%d", entries[len(entries)-1].Position)
|
|
}
|
|
// TODO: set prevBatch? doesn't seem to affect the tests...
|
|
|
|
// Extract all of the event IDs from the relation entries so that we can pull the
|
|
// events out of the database. Then go and fetch the events.
|
|
eventIDs := make([]string, 0, len(entries))
|
|
for _, entry := range entries {
|
|
eventIDs = append(eventIDs, entry.EventID)
|
|
}
|
|
events, err = d.OutputEvents.SelectEvents(ctx, d.txn, eventIDs, nil, true)
|
|
if err != nil {
|
|
return nil, "", "", fmt.Errorf("d.OutputEvents.SelectEvents: %w", err)
|
|
}
|
|
|
|
return events, prevBatch, nextBatch, nil
|
|
}
|