diff --git a/syncapi/internal/history_visibility.go b/syncapi/internal/history_visibility.go index 52a816f89..98fc8bfc4 100644 --- a/syncapi/internal/history_visibility.go +++ b/syncapi/internal/history_visibility.go @@ -16,73 +16,52 @@ package internal import ( "context" - "database/sql" - "fmt" "math" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/gomatrixserverlib" - "github.com/sirupsen/logrus" "github.com/tidwall/gjson" ) -type HistoryVisibility string - -const ( - WorldReadable HistoryVisibility = "world_readable" - Joined HistoryVisibility = "joined" - Shared HistoryVisibility = "shared" - Default HistoryVisibility = "default" - Invited HistoryVisibility = "invited" -) - -var historyVisibilityPriority = map[HistoryVisibility]uint8{ - WorldReadable: 0, - Shared: 1, - Default: 1, // as per the spec, default == shared - Invited: 2, - Joined: 3, +var historyVisibilityPriority = map[gomatrixserverlib.HistoryVisibility]uint8{ + gomatrixserverlib.WorldReadable: 0, + gomatrixserverlib.HistoryVisibilityShared: 1, + gomatrixserverlib.HistoryVisibilityInvited: 2, + gomatrixserverlib.HistoryVisibilityJoined: 3, } -// EventVisibility contains the history visibility and membership state at a given event -type EventVisibility struct { - Visibility HistoryVisibility - MembershipAtEvent string - MembershipCurrent string +// eventVisibility contains the history visibility and membership state at a given event +type eventVisibility struct { + visibility gomatrixserverlib.HistoryVisibility + membershipAtEvent string + membershipCurrent string } -// Visibility is a map from event_id to EvVis, which contains the history visibility and membership for a given user. -type Visibility map[string]EventVisibility - -// allowed checks the Visibility map if the user is allowed to see the given event. -func (v Visibility) allowed(eventID string) (allowed bool) { - ev, ok := v[eventID] - if !ok { - return false - } - switch ev.Visibility { - case WorldReadable: +// allowed checks the eventVisibility if the user is allowed to see the event. +func (ev eventVisibility) allowed() (allowed bool) { + switch ev.visibility { + case gomatrixserverlib.HistoryVisibilityWorldReadable: // If the history_visibility was set to world_readable, allow. return true - case Joined: + case gomatrixserverlib.HistoryVisibilityJoined: // If the user’s membership was join, allow. - if ev.MembershipAtEvent == gomatrixserverlib.Join { + if ev.membershipAtEvent == gomatrixserverlib.Join { return true } return false - case Shared, Default: + case gomatrixserverlib.HistoryVisibilityShared: // If the user’s membership was join, allow. // If history_visibility was set to shared, and the user joined the room at any point after the event was sent, allow. - if ev.MembershipAtEvent == gomatrixserverlib.Join || ev.MembershipCurrent == gomatrixserverlib.Join { + if ev.membershipAtEvent == gomatrixserverlib.Join || ev.membershipCurrent == gomatrixserverlib.Join { return true } return false - case Invited: + case gomatrixserverlib.HistoryVisibilityInvited: // If the user’s membership was join, allow. - if ev.MembershipAtEvent == gomatrixserverlib.Join { + if ev.membershipAtEvent == gomatrixserverlib.Join { return true } - if ev.MembershipAtEvent == gomatrixserverlib.Invite { + if ev.membershipAtEvent == gomatrixserverlib.Invite { return true } return false @@ -101,11 +80,22 @@ func ApplyHistoryVisibilityFilter( userID string, ) ([]*gomatrixserverlib.HeaderedEvent, error) { eventsFiltered := make([]*gomatrixserverlib.HeaderedEvent, 0, len(events)) - stateForEvents, err := getStateForEvents(ctx, syncDB, events, userID) - if err != nil { - return eventsFiltered, err + if len(events) == 0 { + return events, nil } + + // try to get the current membership of the user + membershipCurrent, _, err := syncDB.SelectMembershipForUser(ctx, events[0].RoomID(), userID, math.MaxInt64) + if err != nil { + return nil, err + } + for _, ev := range events { + event, err := visibilityForEvent(ctx, syncDB, ev, userID) + if err != nil { + return eventsFiltered, err + } + event.membershipCurrent = membershipCurrent // Always include specific state events for /sync responses if alwaysIncludeEventIDs != nil { if _, ok := alwaysIncludeEventIDs[ev.EventID()]; ok { @@ -121,73 +111,35 @@ func ApplyHistoryVisibilityFilter( // Handle history visibility changes if hisVis, err := ev.HistoryVisibility(); err == nil { prevHisVis := gjson.GetBytes(ev.Unsigned(), "prev_content.history_visibility").String() - if oldPrio, ok := historyVisibilityPriority[HistoryVisibility(prevHisVis)]; ok { + if oldPrio, ok := historyVisibilityPriority[gomatrixserverlib.HistoryVisibility(prevHisVis)]; ok { // no OK check, since this should have been validated when setting the value - newPrio := historyVisibilityPriority[HistoryVisibility(hisVis)] + newPrio := historyVisibilityPriority[hisVis] if oldPrio < newPrio { - sfe := stateForEvents[ev.EventID()] - sfe.Visibility = HistoryVisibility(prevHisVis) - stateForEvents[ev.EventID()] = sfe + event.visibility = gomatrixserverlib.HistoryVisibility(prevHisVis) } } } // do the actual check - if stateForEvents.allowed(ev.EventID()) { + allowed := event.allowed() + if allowed { eventsFiltered = append(eventsFiltered, ev) } } return eventsFiltered, nil } -// getStateForEvents returns a Visibility map containing the state before and at the given events. -func getStateForEvents(ctx context.Context, db storage.Database, events []*gomatrixserverlib.HeaderedEvent, userID string) (Visibility, error) { - result := make(map[string]EventVisibility, len(events)) - if len(events) == 0 { - return result, nil - } - var ( - membershipCurrent string - err error - ) - // try to get the current membership of the user - membershipCurrent, _, err = db.SelectMembershipForUser(ctx, events[0].RoomID(), userID, math.MaxInt64) +// visibilityForEvent returns an eventVisibility containing the visibility and the membership at the given event. +// Returns an error if the database returns an error. +func visibilityForEvent(ctx context.Context, db storage.Database, event *gomatrixserverlib.HeaderedEvent, userID string) (eventVisibility, error) { + // get the membership event + var membershipAtEvent string + membershipAtEvent, _, err := db.SelectMembershipForUser(ctx, event.RoomID(), userID, event.Depth()) if err != nil { - return nil, err + return eventVisibility{}, err } - for _, ev := range events { - // get the event topology position - pos, err := db.EventPositionInTopology(ctx, ev.EventID()) - if err != nil { - return nil, fmt.Errorf("initial event does not exist: %w", err) - } - // By default if no history_visibility is set, or if the value is not understood, the visibility is assumed to be shared - var hisVis = gomatrixserverlib.HistoryVisibilityShared - historyEvent, _, err := db.SelectTopologicalEvent(ctx, int(pos.Depth), "m.room.history_visibility", ev.RoomID()) - if err != nil { - if err != sql.ErrNoRows { - return nil, err - } - logrus.WithError(err).Debugf("unable to get history event, defaulting to %s", Shared) - } else { - hisVis, err = historyEvent.HistoryVisibility() - if err != nil { - hisVis = gomatrixserverlib.HistoryVisibilityShared - } - } - // get the membership event - var membership string - membership, _, err = db.SelectMembershipForUser(ctx, ev.RoomID(), userID, int64(pos.Depth)) - if err != nil { - return nil, err - } - // finally create the mapping - result[ev.EventID()] = EventVisibility{ - Visibility: HistoryVisibility(hisVis), - MembershipAtEvent: membership, - MembershipCurrent: membershipCurrent, - } - } - - return result, nil + return eventVisibility{ + visibility: event.Visibility, + membershipAtEvent: membershipAtEvent, + }, nil } diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index ef1f87d3d..abb75e667 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -349,7 +349,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { wantResult result }{ { - historyVisibility: "world_readable", + historyVisibility: gomatrixserverlib.HistoryVisibilityWorldReadable, wantResult: result{ seeWithoutJoin: true, seeBeforeJoin: true, @@ -357,7 +357,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { }, }, { - historyVisibility: "shared", + historyVisibility: gomatrixserverlib.HistoryVisibilityShared, wantResult: result{ seeWithoutJoin: false, seeBeforeJoin: true, @@ -365,7 +365,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { }, }, { - historyVisibility: "invited", + historyVisibility: gomatrixserverlib.HistoryVisibilityInvited, wantResult: result{ seeWithoutJoin: false, seeBeforeJoin: false, @@ -373,7 +373,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { }, }, { - historyVisibility: "joined", + historyVisibility: gomatrixserverlib.HistoryVisibilityJoined, wantResult: result{ seeWithoutJoin: false, seeBeforeJoin: false, @@ -390,6 +390,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { base, close := testrig.CreateBaseDendrite(t, dbType) defer close() + _ = close jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) @@ -406,7 +407,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { beforeJoinEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("Before invite in a %s room", tc.historyVisibility)}) testrig.MustPublishMsgs(t, jsctx, toNATSMsgs(t, base, room.Events()...)...) testrig.MustPublishMsgs(t, jsctx, toNATSMsgs(t, base, beforeJoinEv)...) - time.Sleep(100 * time.Millisecond) + time.Sleep(200 * time.Millisecond) // There is only one event, we expect only to be able to see this, if the room is world_readable w := httptest.NewRecorder() @@ -438,7 +439,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After join in a %s room", tc.historyVisibility)}), ) testrig.MustPublishMsgs(t, jsctx, msgs...) - time.Sleep(time.Millisecond * 100) + time.Sleep(200 * time.Millisecond) // Verify the messages after/before invite are visible or not w = httptest.NewRecorder() @@ -495,6 +496,7 @@ func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input ...*gomatrixserverl NewRoomEvent: &rsapi.OutputNewRoomEvent{ Event: ev, AddsStateEventIDs: addsStateIDs, + HistoryVisibility: ev.Visibility, }, }) }