Update history visibility checks to use gmsl

Update tests
This commit is contained in:
Till Faelligen 2022-07-19 12:37:16 +02:00
parent 1c43dbcc12
commit 910bd9b4a8
2 changed files with 59 additions and 105 deletions

View file

@ -16,73 +16,52 @@ package internal
import ( import (
"context" "context"
"database/sql"
"fmt"
"math" "math"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
type HistoryVisibility string var historyVisibilityPriority = map[gomatrixserverlib.HistoryVisibility]uint8{
gomatrixserverlib.WorldReadable: 0,
const ( gomatrixserverlib.HistoryVisibilityShared: 1,
WorldReadable HistoryVisibility = "world_readable" gomatrixserverlib.HistoryVisibilityInvited: 2,
Joined HistoryVisibility = "joined" gomatrixserverlib.HistoryVisibilityJoined: 3,
Shared HistoryVisibility = "shared"
Default HistoryVisibility = "default"
Invited HistoryVisibility = "invited"
)
var historyVisibilityPriority = map[HistoryVisibility]uint8{
WorldReadable: 0,
Shared: 1,
Default: 1, // as per the spec, default == shared
Invited: 2,
Joined: 3,
} }
// EventVisibility contains the history visibility and membership state at a given event // eventVisibility contains the history visibility and membership state at a given event
type EventVisibility struct { type eventVisibility struct {
Visibility HistoryVisibility visibility gomatrixserverlib.HistoryVisibility
MembershipAtEvent string membershipAtEvent string
MembershipCurrent string membershipCurrent string
} }
// Visibility is a map from event_id to EvVis, which contains the history visibility and membership for a given user. // allowed checks the eventVisibility if the user is allowed to see the event.
type Visibility map[string]EventVisibility func (ev eventVisibility) allowed() (allowed bool) {
switch ev.visibility {
// allowed checks the Visibility map if the user is allowed to see the given event. case gomatrixserverlib.HistoryVisibilityWorldReadable:
func (v Visibility) allowed(eventID string) (allowed bool) {
ev, ok := v[eventID]
if !ok {
return false
}
switch ev.Visibility {
case WorldReadable:
// If the history_visibility was set to world_readable, allow. // If the history_visibility was set to world_readable, allow.
return true return true
case Joined: case gomatrixserverlib.HistoryVisibilityJoined:
// If the users membership was join, allow. // If the users membership was join, allow.
if ev.MembershipAtEvent == gomatrixserverlib.Join { if ev.membershipAtEvent == gomatrixserverlib.Join {
return true return true
} }
return false return false
case Shared, Default: case gomatrixserverlib.HistoryVisibilityShared:
// If the users membership was join, allow. // If the users membership was join, allow.
// If history_visibility was set to shared, and the user joined the room at any point after the event was sent, allow. // If history_visibility was set to shared, and the user joined the room at any point after the event was sent, allow.
if ev.MembershipAtEvent == gomatrixserverlib.Join || ev.MembershipCurrent == gomatrixserverlib.Join { if ev.membershipAtEvent == gomatrixserverlib.Join || ev.membershipCurrent == gomatrixserverlib.Join {
return true return true
} }
return false return false
case Invited: case gomatrixserverlib.HistoryVisibilityInvited:
// If the users membership was join, allow. // If the users membership was join, allow.
if ev.MembershipAtEvent == gomatrixserverlib.Join { if ev.membershipAtEvent == gomatrixserverlib.Join {
return true return true
} }
if ev.MembershipAtEvent == gomatrixserverlib.Invite { if ev.membershipAtEvent == gomatrixserverlib.Invite {
return true return true
} }
return false return false
@ -101,11 +80,22 @@ func ApplyHistoryVisibilityFilter(
userID string, userID string,
) ([]*gomatrixserverlib.HeaderedEvent, error) { ) ([]*gomatrixserverlib.HeaderedEvent, error) {
eventsFiltered := make([]*gomatrixserverlib.HeaderedEvent, 0, len(events)) eventsFiltered := make([]*gomatrixserverlib.HeaderedEvent, 0, len(events))
stateForEvents, err := getStateForEvents(ctx, syncDB, events, userID) if len(events) == 0 {
if err != nil { return events, nil
return eventsFiltered, err
} }
// try to get the current membership of the user
membershipCurrent, _, err := syncDB.SelectMembershipForUser(ctx, events[0].RoomID(), userID, math.MaxInt64)
if err != nil {
return nil, err
}
for _, ev := range events { for _, ev := range events {
event, err := visibilityForEvent(ctx, syncDB, ev, userID)
if err != nil {
return eventsFiltered, err
}
event.membershipCurrent = membershipCurrent
// Always include specific state events for /sync responses // Always include specific state events for /sync responses
if alwaysIncludeEventIDs != nil { if alwaysIncludeEventIDs != nil {
if _, ok := alwaysIncludeEventIDs[ev.EventID()]; ok { if _, ok := alwaysIncludeEventIDs[ev.EventID()]; ok {
@ -121,73 +111,35 @@ func ApplyHistoryVisibilityFilter(
// Handle history visibility changes // Handle history visibility changes
if hisVis, err := ev.HistoryVisibility(); err == nil { if hisVis, err := ev.HistoryVisibility(); err == nil {
prevHisVis := gjson.GetBytes(ev.Unsigned(), "prev_content.history_visibility").String() prevHisVis := gjson.GetBytes(ev.Unsigned(), "prev_content.history_visibility").String()
if oldPrio, ok := historyVisibilityPriority[HistoryVisibility(prevHisVis)]; ok { if oldPrio, ok := historyVisibilityPriority[gomatrixserverlib.HistoryVisibility(prevHisVis)]; ok {
// no OK check, since this should have been validated when setting the value // no OK check, since this should have been validated when setting the value
newPrio := historyVisibilityPriority[HistoryVisibility(hisVis)] newPrio := historyVisibilityPriority[hisVis]
if oldPrio < newPrio { if oldPrio < newPrio {
sfe := stateForEvents[ev.EventID()] event.visibility = gomatrixserverlib.HistoryVisibility(prevHisVis)
sfe.Visibility = HistoryVisibility(prevHisVis)
stateForEvents[ev.EventID()] = sfe
} }
} }
} }
// do the actual check // do the actual check
if stateForEvents.allowed(ev.EventID()) { allowed := event.allowed()
if allowed {
eventsFiltered = append(eventsFiltered, ev) eventsFiltered = append(eventsFiltered, ev)
} }
} }
return eventsFiltered, nil return eventsFiltered, nil
} }
// getStateForEvents returns a Visibility map containing the state before and at the given events. // visibilityForEvent returns an eventVisibility containing the visibility and the membership at the given event.
func getStateForEvents(ctx context.Context, db storage.Database, events []*gomatrixserverlib.HeaderedEvent, userID string) (Visibility, error) { // Returns an error if the database returns an error.
result := make(map[string]EventVisibility, len(events)) func visibilityForEvent(ctx context.Context, db storage.Database, event *gomatrixserverlib.HeaderedEvent, userID string) (eventVisibility, error) {
if len(events) == 0 { // get the membership event
return result, nil var membershipAtEvent string
} membershipAtEvent, _, err := db.SelectMembershipForUser(ctx, event.RoomID(), userID, event.Depth())
var (
membershipCurrent string
err error
)
// try to get the current membership of the user
membershipCurrent, _, err = db.SelectMembershipForUser(ctx, events[0].RoomID(), userID, math.MaxInt64)
if err != nil { if err != nil {
return nil, err return eventVisibility{}, err
} }
for _, ev := range events { return eventVisibility{
// get the event topology position visibility: event.Visibility,
pos, err := db.EventPositionInTopology(ctx, ev.EventID()) membershipAtEvent: membershipAtEvent,
if err != nil { }, nil
return nil, fmt.Errorf("initial event does not exist: %w", err)
}
// By default if no history_visibility is set, or if the value is not understood, the visibility is assumed to be shared
var hisVis = gomatrixserverlib.HistoryVisibilityShared
historyEvent, _, err := db.SelectTopologicalEvent(ctx, int(pos.Depth), "m.room.history_visibility", ev.RoomID())
if err != nil {
if err != sql.ErrNoRows {
return nil, err
}
logrus.WithError(err).Debugf("unable to get history event, defaulting to %s", Shared)
} else {
hisVis, err = historyEvent.HistoryVisibility()
if err != nil {
hisVis = gomatrixserverlib.HistoryVisibilityShared
}
}
// get the membership event
var membership string
membership, _, err = db.SelectMembershipForUser(ctx, ev.RoomID(), userID, int64(pos.Depth))
if err != nil {
return nil, err
}
// finally create the mapping
result[ev.EventID()] = EventVisibility{
Visibility: HistoryVisibility(hisVis),
MembershipAtEvent: membership,
MembershipCurrent: membershipCurrent,
}
}
return result, nil
} }

View file

@ -349,7 +349,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
wantResult result wantResult result
}{ }{
{ {
historyVisibility: "world_readable", historyVisibility: gomatrixserverlib.HistoryVisibilityWorldReadable,
wantResult: result{ wantResult: result{
seeWithoutJoin: true, seeWithoutJoin: true,
seeBeforeJoin: true, seeBeforeJoin: true,
@ -357,7 +357,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
}, },
}, },
{ {
historyVisibility: "shared", historyVisibility: gomatrixserverlib.HistoryVisibilityShared,
wantResult: result{ wantResult: result{
seeWithoutJoin: false, seeWithoutJoin: false,
seeBeforeJoin: true, seeBeforeJoin: true,
@ -365,7 +365,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
}, },
}, },
{ {
historyVisibility: "invited", historyVisibility: gomatrixserverlib.HistoryVisibilityInvited,
wantResult: result{ wantResult: result{
seeWithoutJoin: false, seeWithoutJoin: false,
seeBeforeJoin: false, seeBeforeJoin: false,
@ -373,7 +373,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
}, },
}, },
{ {
historyVisibility: "joined", historyVisibility: gomatrixserverlib.HistoryVisibilityJoined,
wantResult: result{ wantResult: result{
seeWithoutJoin: false, seeWithoutJoin: false,
seeBeforeJoin: false, seeBeforeJoin: false,
@ -390,6 +390,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType) base, close := testrig.CreateBaseDendrite(t, dbType)
defer close() defer close()
_ = close
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
@ -406,7 +407,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
beforeJoinEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("Before invite in a %s room", tc.historyVisibility)}) beforeJoinEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("Before invite in a %s room", tc.historyVisibility)})
testrig.MustPublishMsgs(t, jsctx, toNATSMsgs(t, base, room.Events()...)...) testrig.MustPublishMsgs(t, jsctx, toNATSMsgs(t, base, room.Events()...)...)
testrig.MustPublishMsgs(t, jsctx, toNATSMsgs(t, base, beforeJoinEv)...) testrig.MustPublishMsgs(t, jsctx, toNATSMsgs(t, base, beforeJoinEv)...)
time.Sleep(100 * time.Millisecond) time.Sleep(200 * time.Millisecond)
// There is only one event, we expect only to be able to see this, if the room is world_readable // There is only one event, we expect only to be able to see this, if the room is world_readable
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -438,7 +439,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After join in a %s room", tc.historyVisibility)}), room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After join in a %s room", tc.historyVisibility)}),
) )
testrig.MustPublishMsgs(t, jsctx, msgs...) testrig.MustPublishMsgs(t, jsctx, msgs...)
time.Sleep(time.Millisecond * 100) time.Sleep(200 * time.Millisecond)
// Verify the messages after/before invite are visible or not // Verify the messages after/before invite are visible or not
w = httptest.NewRecorder() w = httptest.NewRecorder()
@ -495,6 +496,7 @@ func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input ...*gomatrixserverl
NewRoomEvent: &rsapi.OutputNewRoomEvent{ NewRoomEvent: &rsapi.OutputNewRoomEvent{
Event: ev, Event: ev,
AddsStateEventIDs: addsStateIDs, AddsStateEventIDs: addsStateIDs,
HistoryVisibility: ev.Visibility,
}, },
}) })
} }