Add actual history_visibility changes for /messages

This commit is contained in:
Till Faelligen 2022-06-02 12:23:29 +02:00
parent f594725b43
commit 706adc936b
3 changed files with 371 additions and 91 deletions

View file

@ -0,0 +1,138 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"database/sql"
"fmt"
"math"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
type HistoryVisibility string
const (
WorldReadable HistoryVisibility = "world_readable"
Joined HistoryVisibility = "joined"
Shared HistoryVisibility = "shared"
Default HistoryVisibility = "default"
Invited HistoryVisibility = "invited"
)
// EventVisibility contains the history visibility and membership state at a given event
type EventVisibility struct {
Visibility HistoryVisibility
MembershipAtEvent string
MembershipCurrent string
MembershipPosition int // the topological position of the membership event
HistoryPosition int // the topological position of the history event
}
// Visibility is a map from event_id to EvVis, which contains the history visibility and membership for a given user.
type Visibility map[string]EventVisibility
// Allowed checks the Visibility map if the user is allowed to see the given event.
func (v Visibility) Allowed(eventID string) (allowed bool) {
ev, ok := v[eventID]
if !ok {
return false
}
switch ev.Visibility {
case WorldReadable:
// If the history_visibility was set to world_readable, allow.
return true
case Joined:
// If the users membership was join, allow.
if ev.MembershipAtEvent == gomatrixserverlib.Join {
return true
}
return false
case Shared, Default:
// If the users membership was join, allow.
// If history_visibility was set to shared, and the user joined the room at any point after the event was sent, allow.
if ev.MembershipAtEvent == gomatrixserverlib.Join || ev.MembershipCurrent == gomatrixserverlib.Join {
return true
}
return false
case Invited:
// If the users membership was join, allow.
if ev.MembershipAtEvent == gomatrixserverlib.Join {
return true
}
if ev.MembershipAtEvent == gomatrixserverlib.Invite {
return true
}
return false
default:
return false
}
}
// GetStateForEvents returns a Visibility map containing the state before and at the given events.
func GetStateForEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.ClientEvent, userID string) (Visibility, error) {
result := make(map[string]EventVisibility, len(events))
var (
membershipCurrent string
err error
)
// try to get the current membership of the user
if len(events) > 0 {
membershipCurrent, _, err = db.SelectMembershipForUser(ctx, events[0].RoomID, userID, math.MaxInt64)
if err != nil {
return nil, err
}
}
for _, ev := range events {
// get the event topology position
pos, err := db.EventPositionInTopology(ctx, ev.EventID)
if err != nil {
return nil, fmt.Errorf("initial event does not exist: %w", err)
}
// By default if no history_visibility is set, or if the value is not understood, the visibility is assumed to be shared
var hisVis = "shared"
historyEvent, historyPos, err := db.SelectTopologicalEvent(ctx, int(pos.Depth), "m.room.history_visibility", ev.RoomID)
if err != nil {
if err != sql.ErrNoRows {
return nil, err
}
logrus.WithError(err).Debugf("unable to get history event, defaulting to %s", Shared)
} else {
hisVis, err = historyEvent.HistoryVisibility()
if err != nil {
hisVis = "shared"
}
}
// get the membership event
var membership string
membership, memberPos, err := db.SelectMembershipForUser(ctx, ev.RoomID, userID, int(pos.Depth))
if err != nil {
return nil, err
}
// finally create the mapping
result[ev.EventID] = EventVisibility{
Visibility: HistoryVisibility(hisVis),
MembershipAtEvent: membership,
MembershipCurrent: membershipCurrent,
MembershipPosition: memberPos,
HistoryPosition: int(historyPos.Depth),
}
}
return result, nil
}

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync" "github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
@ -47,7 +48,7 @@ type messagesReq struct {
filter *gomatrixserverlib.RoomEventFilter filter *gomatrixserverlib.RoomEventFilter
} }
type messagesResp struct { type MessageResp struct {
Start string `json:"start"` Start string `json:"start"`
StartStream string `json:"start_stream,omitempty"` // NOTSPEC: used by Cerulean, so clients can hit /messages then immediately /sync with a latest sync token StartStream string `json:"start_stream,omitempty"` // NOTSPEC: used by Cerulean, so clients can hit /messages then immediately /sync with a latest sync token
End string `json:"end"` End string `json:"end"`
@ -200,11 +201,25 @@ func OnIncomingMessagesRequest(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
// apply history_visibility filter
clientEventsNew := []gomatrixserverlib.ClientEvent{}
var stateForEvents internal.Visibility
stateForEvents, err = internal.GetStateForEvents(req.Context(), db, clientEvents, device.UserID)
if err != nil {
logrus.WithError(err).Error("internal.GetStateForEvents failed")
return jsonerror.InternalServerError()
}
for _, ev := range clientEvents {
if stateForEvents.Allowed(ev.EventID) {
clientEventsNew = append(clientEventsNew, ev)
}
}
// at least fetch the membership events for the users returned in chunk if LazyLoadMembers is set // at least fetch the membership events for the users returned in chunk if LazyLoadMembers is set
state := []gomatrixserverlib.ClientEvent{} state := []gomatrixserverlib.ClientEvent{}
if filter.LazyLoadMembers { if filter.LazyLoadMembers {
membershipToUser := make(map[string]*gomatrixserverlib.HeaderedEvent) membershipToUser := make(map[string]*gomatrixserverlib.HeaderedEvent)
for _, evt := range clientEvents { for _, evt := range clientEventsNew {
// Don't add membership events the client should already know about // Don't add membership events the client should already know about
if _, cached := lazyLoadCache.IsLazyLoadedUserCached(device, roomID, evt.Sender); cached { if _, cached := lazyLoadCache.IsLazyLoadedUserCached(device, roomID, evt.Sender); cached {
continue continue
@ -224,6 +239,8 @@ func OnIncomingMessagesRequest(
} }
} }
logrus.Debugf("Events after filtering: %d vs %d", len(clientEvents), len(clientEventsNew))
util.GetLogger(req.Context()).WithFields(logrus.Fields{ util.GetLogger(req.Context()).WithFields(logrus.Fields{
"from": from.String(), "from": from.String(),
"to": to.String(), "to": to.String(),
@ -233,8 +250,8 @@ func OnIncomingMessagesRequest(
"return_end": end.String(), "return_end": end.String(),
}).Info("Responding") }).Info("Responding")
res := messagesResp{ res := MessageResp{
Chunk: clientEvents, Chunk: clientEventsNew,
Start: start.String(), Start: start.String(),
End: end.String(), End: end.String(),
State: state, State: state,
@ -320,7 +337,6 @@ func (r *messagesReq) retrieveEvents() (
} }
events = reversed(events) events = reversed(events)
} }
events = r.filterHistoryVisible(events)
if len(events) == 0 { if len(events) == 0 {
return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, nil return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, nil
} }
@ -330,89 +346,6 @@ func (r *messagesReq) retrieveEvents() (
return clientEvents, start, end, err return clientEvents, start, end, err
} }
func (r *messagesReq) filterHistoryVisible(events []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
// TODO FIXME: We don't fully implement history visibility yet. To avoid leaking events which the
// user shouldn't see, we check the recent events and remove any prior to the join event of the user
// which is equiv to history_visibility: joined
joinEventIndex := -1
for i, ev := range events {
if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(r.device.UserID) {
membership, _ := ev.Membership()
if membership == "join" {
joinEventIndex = i
break
}
}
}
var result []*gomatrixserverlib.HeaderedEvent
var eventsToCheck []*gomatrixserverlib.HeaderedEvent
if joinEventIndex != -1 {
if r.backwardOrdering {
result = events[:joinEventIndex+1]
eventsToCheck = append(eventsToCheck, result[0])
} else {
result = events[joinEventIndex:]
eventsToCheck = append(eventsToCheck, result[len(result)-1])
}
} else {
eventsToCheck = []*gomatrixserverlib.HeaderedEvent{events[0], events[len(events)-1]}
result = events
}
// make sure the user was in the room for both the earliest and latest events, we need this because
// some backpagination results will not have the join event (e.g if they hit /messages at the join event itself)
wasJoined := true
for _, ev := range eventsToCheck {
var queryRes api.QueryStateAfterEventsResponse
err := r.rsAPI.QueryStateAfterEvents(r.ctx, &api.QueryStateAfterEventsRequest{
RoomID: ev.RoomID(),
PrevEventIDs: ev.PrevEventIDs(),
StateToFetch: []gomatrixserverlib.StateKeyTuple{
{EventType: gomatrixserverlib.MRoomMember, StateKey: r.device.UserID},
{EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: ""},
},
}, &queryRes)
if err != nil {
wasJoined = false
break
}
var hisVisEvent, membershipEvent *gomatrixserverlib.HeaderedEvent
for i := range queryRes.StateEvents {
switch queryRes.StateEvents[i].Type() {
case gomatrixserverlib.MRoomMember:
membershipEvent = queryRes.StateEvents[i]
case gomatrixserverlib.MRoomHistoryVisibility:
hisVisEvent = queryRes.StateEvents[i]
}
}
if hisVisEvent == nil {
return events // apply no filtering as it defaults to Shared.
}
hisVis, _ := hisVisEvent.HistoryVisibility()
if hisVis == "shared" || hisVis == "world_readable" {
return events // apply no filtering
}
if membershipEvent == nil {
wasJoined = false
break
}
membership, err := membershipEvent.Membership()
if err != nil {
wasJoined = false
break
}
if membership != "join" {
wasJoined = false
break
}
}
if !wasJoined {
util.GetLogger(r.ctx).WithField("num_events", len(events)).Warnf("%s was not joined to room during these events, omitting them", r.device.UserID)
return []*gomatrixserverlib.HeaderedEvent{}
}
return result
}
func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) { func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) {
if r.backwardOrdering { if r.backwardOrdering {
start = *r.from start = *r.from

View file

@ -3,6 +3,7 @@ package syncapi
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -13,6 +14,7 @@ import (
rsapi "github.com/matrix-org/dendrite/roomserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi/routing"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/test/testrig"
@ -51,6 +53,12 @@ func (s *syncRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *rsap
return nil return nil
} }
func (s *syncRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *rsapi.QueryMembershipForUserRequest, res *rsapi.QueryMembershipForUserResponse) error {
res.IsRoomForgotten = false
res.RoomExists = true
return nil
}
type syncUserAPI struct { type syncUserAPI struct {
userapi.SyncUserAPI userapi.SyncUserAPI
accounts []userapi.Device accounts []userapi.Device
@ -103,7 +111,7 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
msgs := toNATSMsgs(t, base, room.Events()) msgs := toNATSMsgs(t, base, room.Events()...)
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
testrig.MustPublishMsgs(t, jsctx, msgs...) testrig.MustPublishMsgs(t, jsctx, msgs...)
@ -196,7 +204,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
// m.room.power_levels // m.room.power_levels
// m.room.join_rules // m.room.join_rules
// m.room.history_visibility // m.room.history_visibility
msgs := toNATSMsgs(t, base, room.Events()) msgs := toNATSMsgs(t, base, room.Events()...)
sinceTokens := make([]string, len(msgs)) sinceTokens := make([]string, len(msgs))
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
for i, msg := range msgs { for i, msg := range msgs {
@ -311,7 +319,208 @@ func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) {
} }
func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverlib.HeaderedEvent) []*nats.Msg { // This is mainly what Sytest is doing in "test_history_visibility"
func TestMessageHistoryVisibility(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
testHistoryVisibility(t, dbType)
})
}
func testHistoryVisibility(t *testing.T, dbType test.DBType) {
type result struct {
seeWithoutJoin bool
seeBeforeJoin bool
seeAfterInvite bool
}
// create the users
alice := test.NewUser(t)
bob := test.NewUser(t)
bobDev := userapi.Device{
ID: "BOBID",
UserID: bob.ID,
AccessToken: "BOD_BEARER_TOKEN",
DisplayName: "BOB",
}
// check guest and normaler user accounts
for _, accType := range []userapi.AccountType{userapi.AccountTypeGuest, userapi.AccountTypeUser} {
testCases := []struct {
historyVisibility string
wantResult result
}{
{
historyVisibility: "world_readable",
wantResult: result{
seeWithoutJoin: true,
seeBeforeJoin: true,
seeAfterInvite: true,
},
},
{
historyVisibility: "shared",
wantResult: result{
seeWithoutJoin: false,
seeBeforeJoin: true,
seeAfterInvite: true,
},
},
{
historyVisibility: "invited",
wantResult: result{
seeWithoutJoin: false,
seeBeforeJoin: false,
seeAfterInvite: true,
},
},
{
historyVisibility: "joined",
wantResult: result{
seeWithoutJoin: false,
seeBeforeJoin: false,
seeAfterInvite: false,
},
},
{
historyVisibility: "default",
wantResult: result{
seeWithoutJoin: false,
seeBeforeJoin: true,
seeAfterInvite: true,
},
},
}
bobDev.AccountType = accType
userType := "guest"
if accType == userapi.AccountTypeUser {
userType = "real user"
}
base, close := testrig.CreateBaseDendrite(t, dbType)
defer close()
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{bobDev}}, &syncRoomserverAPI{}, &syncKeyAPI{})
for _, tc := range testCases {
testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType)
t.Run(testname, func(t *testing.T) {
// create a room with the given visibility
room := test.NewRoom(t, alice, test.RoomHistoryVisibility(tc.historyVisibility))
// send the events/messages to NATS to create the rooms
beforeJoinEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("Before invite in a %s room", tc.historyVisibility)})
testrig.MustPublishMsgs(t, jsctx, toNATSMsgs(t, base, room.Events()...)...)
testrig.MustPublishMsgs(t, jsctx, toNATSMsgs(t, base, beforeJoinEv)...)
time.Sleep(100 * time.Millisecond)
// There is only one event, we expect only to be able to see this, if the room is world_readable
w := httptest.NewRecorder()
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{
"access_token": bobDev.AccessToken,
"dir": "b",
})))
if w.Code != 200 {
t.Logf("%s", w.Body.String())
t.Fatalf("got HTTP %d want %d", w.Code, 200)
}
var res routing.MessageResp
if err := json.NewDecoder(w.Body).Decode(&res); err != nil {
t.Errorf("failed to decode response body: %s", err)
}
if tc.wantResult.seeWithoutJoin {
found := false
for _, ev := range res.Chunk {
if ev.EventID == beforeJoinEv.EventID() {
found = true
break
}
}
if !found {
t.Fatalf("expected to see event %s without joining but didn't: %+v", beforeJoinEv.EventID(), res.Chunk)
}
} else {
for _, ev := range res.Chunk {
if ev.EventID == beforeJoinEv.EventID() {
t.Fatalf("expected not to see event %s without joining: %+v", beforeJoinEv.EventID(), string(ev.Content))
}
}
}
// Create invite, a message, join the room and create another message.
msgs := toNATSMsgs(t, base, room.CreateAndInsert(t, alice, "m.room.member", map[string]interface{}{"membership": "invite"}, test.WithStateKey(bob.ID)))
testrig.MustPublishMsgs(t, jsctx, msgs...)
afterInviteEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After invite in a %s room", tc.historyVisibility)})
msgs = toNATSMsgs(t, base,
afterInviteEv,
room.CreateAndInsert(t, bob, "m.room.member", map[string]interface{}{"membership": "join"}, test.WithStateKey(bob.ID)),
room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After join in a %s room", tc.historyVisibility)}),
)
testrig.MustPublishMsgs(t, jsctx, msgs...)
time.Sleep(time.Millisecond * 100)
// Verify the messages after/before invite are visible or not
w = httptest.NewRecorder()
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{
"access_token": bobDev.AccessToken,
"dir": "b",
})))
if w.Code != 200 {
t.Logf("%s", w.Body.String())
t.Fatalf("got HTTP %d want %d", w.Code, 200)
}
if err := json.NewDecoder(w.Body).Decode(&res); err != nil {
t.Errorf("failed to decode response body: %s", err)
}
// verify result for seeBeforeJoin
if tc.wantResult.seeBeforeJoin {
found := false
for _, ev := range res.Chunk {
if ev.EventID == beforeJoinEv.EventID() {
found = true
break
}
}
if !found {
t.Fatalf("expected to see event %s before joining but didn't: %+v", beforeJoinEv.EventID(), res.Chunk)
}
} else {
for _, ev := range res.Chunk {
if ev.EventID == beforeJoinEv.EventID() {
t.Fatalf("expected not to see event %s before joining: %+v", beforeJoinEv.EventID(), string(ev.Content))
}
}
}
// verify result for seeAfterInvite
if tc.wantResult.seeAfterInvite {
found := false
for _, ev := range res.Chunk {
if ev.EventID == afterInviteEv.EventID() {
found = true
break
}
}
if !found {
t.Fatalf("expected to see event %s after invite but didn't: %+v", afterInviteEv.EventID(), res.Chunk)
}
} else {
for _, ev := range res.Chunk {
if ev.EventID == afterInviteEv.EventID() {
t.Fatalf("expected not to see event %s after invite: %+v", afterInviteEv.EventID(), string(ev.Content))
}
}
}
})
}
}
}
func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input ...*gomatrixserverlib.HeaderedEvent) []*nats.Msg {
result := make([]*nats.Msg, len(input)) result := make([]*nats.Msg, len(input))
for i, ev := range input { for i, ev := range input {
var addsStateIDs []string var addsStateIDs []string