mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-17 03:43:11 -06:00
Make backfill work for shared history visibility
This commit is contained in:
parent
a4b9edb28e
commit
9bdf8465e3
|
|
@ -27,7 +27,7 @@ func IsServerAllowed(
|
||||||
serverCurrentlyInRoom bool,
|
serverCurrentlyInRoom bool,
|
||||||
authEvents []gomatrixserverlib.Event,
|
authEvents []gomatrixserverlib.Event,
|
||||||
) bool {
|
) bool {
|
||||||
historyVisibility := historyVisibilityForRoom(authEvents)
|
historyVisibility := HistoryVisibilityForRoom(authEvents)
|
||||||
|
|
||||||
// 1. If the history_visibility was set to world_readable, allow.
|
// 1. If the history_visibility was set to world_readable, allow.
|
||||||
if historyVisibility == "world_readable" {
|
if historyVisibility == "world_readable" {
|
||||||
|
|
@ -52,7 +52,7 @@ func IsServerAllowed(
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func historyVisibilityForRoom(authEvents []gomatrixserverlib.Event) string {
|
func HistoryVisibilityForRoom(authEvents []gomatrixserverlib.Event) string {
|
||||||
// https://matrix.org/docs/spec/client_server/r0.6.0#id87
|
// https://matrix.org/docs/spec/client_server/r0.6.0#id87
|
||||||
// By default if no history_visibility is set, or if the value is not understood, the visibility is assumed to be shared.
|
// By default if no history_visibility is set, or if the value is not understood, the visibility is assumed to be shared.
|
||||||
visibility := "shared"
|
visibility := "shared"
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package query
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/auth"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
@ -114,7 +115,9 @@ func (b *backfillRequester) calculateNewStateIDs(targetEvent, prevEvent gomatrix
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, event gomatrixserverlib.HeaderedEvent, eventIDs []string) (map[string]*gomatrixserverlib.Event, error) {
|
func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatrixserverlib.RoomVersion,
|
||||||
|
event gomatrixserverlib.HeaderedEvent, eventIDs []string) (map[string]*gomatrixserverlib.Event, error) {
|
||||||
|
|
||||||
// try to fetch the events from the database first
|
// try to fetch the events from the database first
|
||||||
events, err := b.ProvideEvents(roomVer, eventIDs)
|
events, err := b.ProvideEvents(roomVer, eventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -152,6 +155,7 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr
|
||||||
// will be servers that are in the room already. The entries at the beginning are preferred servers
|
// will be servers that are in the room already. The entries at the beginning are preferred servers
|
||||||
// and will be tried first. An empty list will fail the request.
|
// and will be tried first. An empty list will fail the request.
|
||||||
func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID string) (servers []gomatrixserverlib.ServerName) {
|
func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID string) (servers []gomatrixserverlib.ServerName) {
|
||||||
|
logrus.Infof("ServersAtEvent room %s event %d", roomID, eventID)
|
||||||
// getMembershipsBeforeEventNID requires a NID, so retrieving the NID for
|
// getMembershipsBeforeEventNID requires a NID, so retrieving the NID for
|
||||||
// the event is necessary.
|
// the event is necessary.
|
||||||
NIDs, err := b.db.EventNIDs(ctx, []string{eventID})
|
NIDs, err := b.db.EventNIDs(ctx, []string{eventID})
|
||||||
|
|
@ -160,18 +164,33 @@ func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
stateEntries, err := stateBeforeEvent(ctx, b.db, NIDs[eventID])
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// possibly return all joined servers depending on history visiblity
|
||||||
|
memberEventsFromVis, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logrus.Infof("ServersAtEvent including %d current events from history visibility", len(memberEventsFromVis))
|
||||||
|
|
||||||
// Retrieve all "m.room.member" state events of "join" membership, which
|
// Retrieve all "m.room.member" state events of "join" membership, which
|
||||||
// contains the list of users in the room before the event, therefore all
|
// contains the list of users in the room before the event, therefore all
|
||||||
// the servers in it at that moment.
|
// the servers in it at that moment.
|
||||||
events, err := getMembershipsBeforeEventNID(ctx, b.db, NIDs[eventID], true)
|
memberEvents, err := getMembershipsAtState(ctx, b.db, stateEntries, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event")
|
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
memberEvents = append(memberEvents, memberEventsFromVis...)
|
||||||
|
|
||||||
// Store the server names in a temporary map to avoid duplicates.
|
// Store the server names in a temporary map to avoid duplicates.
|
||||||
serverSet := make(map[gomatrixserverlib.ServerName]bool)
|
serverSet := make(map[gomatrixserverlib.ServerName]bool)
|
||||||
for _, event := range events {
|
for _, event := range memberEvents {
|
||||||
serverSet[event.Origin()] = true
|
serverSet[event.Origin()] = true
|
||||||
}
|
}
|
||||||
for server := range serverSet {
|
for server := range serverSet {
|
||||||
|
|
@ -186,7 +205,9 @@ func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID
|
||||||
|
|
||||||
// Backfill performs a backfill request to the given server.
|
// Backfill performs a backfill request to the given server.
|
||||||
// https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid
|
// https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid
|
||||||
func (b *backfillRequester) Backfill(ctx context.Context, server gomatrixserverlib.ServerName, roomID string, fromEventIDs []string, limit int) (*gomatrixserverlib.Transaction, error) {
|
func (b *backfillRequester) Backfill(ctx context.Context, server gomatrixserverlib.ServerName, roomID string,
|
||||||
|
fromEventIDs []string, limit int) (*gomatrixserverlib.Transaction, error) {
|
||||||
|
|
||||||
tx, err := b.fedClient.Backfill(ctx, server, roomID, limit, fromEventIDs)
|
tx, err := b.fedClient.Backfill(ctx, server, roomID, limit, fromEventIDs)
|
||||||
return &tx, err
|
return &tx, err
|
||||||
}
|
}
|
||||||
|
|
@ -215,3 +236,44 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
|
||||||
}
|
}
|
||||||
return events, nil
|
return events, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// joinEventsFromHistoryVisibility returns all CURRENTLY joined members if the provided state indicated a 'shared' history visibility.
|
||||||
|
// TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just
|
||||||
|
// pull all events and then filter by that table.
|
||||||
|
func joinEventsFromHistoryVisibility(
|
||||||
|
ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry) ([]types.Event, error) {
|
||||||
|
|
||||||
|
var eventNIDs []types.EventNID
|
||||||
|
for _, entry := range stateEntries {
|
||||||
|
// Filter the events to retrieve to only keep the membership events
|
||||||
|
if entry.EventTypeNID == types.MRoomHistoryVisibilityNID && entry.EventStateKeyNID == types.EmptyStateKeyNID {
|
||||||
|
eventNIDs = append(eventNIDs, entry.EventNID)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get all of the events in this state
|
||||||
|
stateEvents, err := db.Events(ctx, eventNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
events := make([]gomatrixserverlib.Event, len(stateEvents))
|
||||||
|
for i := range stateEvents {
|
||||||
|
events[i] = stateEvents[i].Event
|
||||||
|
}
|
||||||
|
visibility := auth.HistoryVisibilityForRoom(events)
|
||||||
|
if visibility != "shared" {
|
||||||
|
logrus.Infof("ServersAtEvent history visibility not shared: %s", visibility)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
// get joined members
|
||||||
|
roomNID, err := db.RoomNID(ctx, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomNID, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return db.Events(ctx, joinEventNIDs)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -286,7 +286,12 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
|
||||||
|
|
||||||
events, err = r.DB.Events(ctx, eventNIDs)
|
events, err = r.DB.Events(ctx, eventNIDs)
|
||||||
} else {
|
} else {
|
||||||
events, err = getMembershipsBeforeEventNID(ctx, r.DB, membershipEventNID, request.JoinedOnly)
|
stateEntries, err := stateBeforeEvent(ctx, r.DB, membershipEventNID)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
events, err = getMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -301,15 +306,8 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getMembershipsBeforeEventNID takes the numeric ID of an event and fetches the state
|
func stateBeforeEvent(ctx context.Context, db storage.Database, eventNID types.EventNID) ([]types.StateEntry, error) {
|
||||||
// of the event's room as it was when this event was fired, then filters the state events to
|
|
||||||
// only keep the "m.room.member" events with a "join" membership. These events are returned.
|
|
||||||
// Returns an error if there was an issue fetching the events.
|
|
||||||
func getMembershipsBeforeEventNID(
|
|
||||||
ctx context.Context, db storage.Database, eventNID types.EventNID, joinedOnly bool,
|
|
||||||
) ([]types.Event, error) {
|
|
||||||
roomState := state.NewStateResolution(db)
|
roomState := state.NewStateResolution(db)
|
||||||
events := []types.Event{}
|
|
||||||
// Lookup the event NID
|
// Lookup the event NID
|
||||||
eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
|
eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -323,10 +321,15 @@ func getMembershipsBeforeEventNID(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch the state as it was when this event was fired
|
// Fetch the state as it was when this event was fired
|
||||||
stateEntries, err := roomState.LoadCombinedStateAfterEvents(ctx, prevState)
|
return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
|
||||||
if err != nil {
|
}
|
||||||
return nil, err
|
|
||||||
}
|
// getMembershipsAtState filters the state events to
|
||||||
|
// only keep the "m.room.member" events with a "join" membership. These events are returned.
|
||||||
|
// Returns an error if there was an issue fetching the events.
|
||||||
|
func getMembershipsAtState(
|
||||||
|
ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool,
|
||||||
|
) ([]types.Event, error) {
|
||||||
|
|
||||||
var eventNIDs []types.EventNID
|
var eventNIDs []types.EventNID
|
||||||
for _, entry := range stateEntries {
|
for _, entry := range stateEntries {
|
||||||
|
|
@ -347,6 +350,7 @@ func getMembershipsBeforeEventNID(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter the events to only keep the "join" membership events
|
// Filter the events to only keep the "join" membership events
|
||||||
|
var events []types.Event
|
||||||
for _, event := range stateEvents {
|
for _, event := range stateEvents {
|
||||||
membership, err := event.Membership()
|
membership, err := event.Membership()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -563,20 +567,23 @@ func (r *RoomserverQueryAPI) backfillViaFederation(ctx context.Context, req *api
|
||||||
if !ok {
|
if !ok {
|
||||||
// this should be impossible as all events returned must have pass Step 5 of the PDU checks
|
// this should be impossible as all events returned must have pass Step 5 of the PDU checks
|
||||||
// which requires a list of state IDs.
|
// which requires a list of state IDs.
|
||||||
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to find state IDs for event which passed auth checks")
|
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to find state IDs for event which passed auth checks")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
var entries []types.StateEntry
|
var entries []types.StateEntry
|
||||||
if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs); err != nil {
|
if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs); err != nil {
|
||||||
|
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to get state entries for event")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var beforeStateSnapshotNID types.StateSnapshotNID
|
var beforeStateSnapshotNID types.StateSnapshotNID
|
||||||
if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil {
|
if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil {
|
||||||
|
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist state entries to get snapshot nid")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
util.GetLogger(ctx).Infof("Backfilled event %s (nid=%d) getting snapshot %v with entries %+v", ev.EventID(), ev.EventNID, beforeStateSnapshotNID, entries)
|
||||||
if err = r.DB.SetState(ctx, ev.EventNID, beforeStateSnapshotNID); err != nil {
|
if err = r.DB.SetState(ctx, ev.EventNID, beforeStateSnapshotNID); err != nil {
|
||||||
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to set state before event")
|
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist snapshot nid")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -857,7 +864,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []gomatrixse
|
||||||
var stateAtEvent types.StateAtEvent
|
var stateAtEvent types.StateAtEvent
|
||||||
roomNID, stateAtEvent, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids)
|
roomNID, stateAtEvent, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to store backfilled event")
|
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
backfilledEventMap[ev.EventID()] = types.Event{
|
backfilledEventMap[ev.EventID()] = types.Event{
|
||||||
|
|
|
||||||
|
|
@ -86,7 +86,10 @@ func (v StateResolution) LoadStateAtEvent(
|
||||||
) ([]types.StateEntry, error) {
|
) ([]types.StateEntry, error) {
|
||||||
snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID)
|
snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %s", eventID, err)
|
||||||
|
}
|
||||||
|
if snapshotNID == 0 {
|
||||||
|
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID)
|
||||||
}
|
}
|
||||||
|
|
||||||
stateEntries, err := v.LoadStateAtSnapshot(ctx, snapshotNID)
|
stateEntries, err := v.LoadStateAtSnapshot(ctx, snapshotNID)
|
||||||
|
|
|
||||||
|
|
@ -48,11 +48,6 @@ const insertEventSQL = `
|
||||||
ON CONFLICT DO NOTHING;
|
ON CONFLICT DO NOTHING;
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertEventResultSQL = `
|
|
||||||
SELECT event_nid, state_snapshot_nid FROM roomserver_events
|
|
||||||
WHERE rowid = last_insert_rowid();
|
|
||||||
`
|
|
||||||
|
|
||||||
const selectEventSQL = "" +
|
const selectEventSQL = "" +
|
||||||
"SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1"
|
"SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1"
|
||||||
|
|
||||||
|
|
@ -126,7 +121,6 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
|
||||||
return statementList{
|
return statementList{
|
||||||
{&s.insertEventStmt, insertEventSQL},
|
{&s.insertEventStmt, insertEventSQL},
|
||||||
{&s.insertEventResultStmt, insertEventResultSQL},
|
|
||||||
{&s.selectEventStmt, selectEventSQL},
|
{&s.selectEventStmt, selectEventSQL},
|
||||||
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
|
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
|
||||||
{&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
|
{&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
|
||||||
|
|
@ -153,18 +147,22 @@ func (s *eventStatements) insertEvent(
|
||||||
authEventNIDs []types.EventNID,
|
authEventNIDs []types.EventNID,
|
||||||
depth int64,
|
depth int64,
|
||||||
) (types.EventNID, types.StateSnapshotNID, error) {
|
) (types.EventNID, types.StateSnapshotNID, error) {
|
||||||
var eventNID int64
|
// attempt to insert: the last_row_id is the event NID
|
||||||
var stateNID int64
|
|
||||||
var err error
|
|
||||||
insertStmt := common.TxStmt(txn, s.insertEventStmt)
|
insertStmt := common.TxStmt(txn, s.insertEventStmt)
|
||||||
resultStmt := common.TxStmt(txn, s.insertEventResultStmt)
|
result, err := insertStmt.ExecContext(
|
||||||
if _, err = insertStmt.ExecContext(
|
|
||||||
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
|
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
|
||||||
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
|
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
|
||||||
); err == nil {
|
)
|
||||||
err = resultStmt.QueryRowContext(ctx).Scan(&eventNID, &stateNID)
|
if err != nil {
|
||||||
|
return 0, 0, err
|
||||||
}
|
}
|
||||||
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
|
modified, err := result.RowsAffected()
|
||||||
|
if modified == 0 && err == nil {
|
||||||
|
return 0, 0, sql.ErrNoRows
|
||||||
|
}
|
||||||
|
// the snapshot will always be 0 at this point
|
||||||
|
eventNID, err := result.LastInsertId()
|
||||||
|
return types.EventNID(eventNID), types.StateSnapshotNID(0), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) selectEvent(
|
func (s *eventStatements) selectEvent(
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue