diff --git a/internal/mscs/msc2836/msc2836.go b/internal/mscs/msc2836/msc2836.go index b0c17c3cf..27370e41d 100644 --- a/internal/mscs/msc2836/msc2836.go +++ b/internal/mscs/msc2836/msc2836.go @@ -54,7 +54,6 @@ type EventRelationshipRequest struct { IncludeChildren bool `json:"include_children"` Direction string `json:"direction"` Batch string `json:"batch"` - AutoJoin bool `json:"auto_join"` } func NewEventRelationshipRequest(body io.Reader) (*EventRelationshipRequest, error) { @@ -93,7 +92,6 @@ func toClientResponse(res *gomatrixserverlib.MSC2836EventRelationshipsResponse) } // Enable this MSC -// nolint:gocyclo func Enable( base *setup.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationSenderInternalAPI, userAPI userapi.UserInternalAPI, keyRing gomatrixserverlib.JSONVerifier, @@ -140,12 +138,12 @@ func Enable( } type reqCtx struct { - ctx context.Context - rsAPI roomserver.RoomserverInternalAPI - db Database - req *EventRelationshipRequest - userID string - authorisedRoomIDs map[string]gomatrixserverlib.RoomVersion // events from these rooms can be returned TODO remove + ctx context.Context + rsAPI roomserver.RoomserverInternalAPI + db Database + req *EventRelationshipRequest + userID string + roomVersion gomatrixserverlib.RoomVersion // federated request args isFederatedRequest bool @@ -171,7 +169,6 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP fsAPI: fsAPI, isFederatedRequest: false, db: db, - authorisedRoomIDs: make(map[string]gomatrixserverlib.RoomVersion), } res, resErr := rc.process() if resErr != nil { @@ -197,11 +194,10 @@ func federatedEventRelationship( } } rc := reqCtx{ - ctx: ctx, - req: relation, - rsAPI: rsAPI, - db: db, - authorisedRoomIDs: make(map[string]gomatrixserverlib.RoomVersion), + ctx: ctx, + req: relation, + rsAPI: rsAPI, + db: db, // federation args isFederatedRequest: true, fsAPI: fsAPI, @@ -242,6 +238,7 @@ func federatedEventRelationship( } } +// nolint:gocyclo func (rc *reqCtx) process() (*gomatrixserverlib.MSC2836EventRelationshipsResponse, *util.JSONResponse) { var res gomatrixserverlib.MSC2836EventRelationshipsResponse var returnEvents []*gomatrixserverlib.HeaderedEvent @@ -250,12 +247,16 @@ func (rc *reqCtx) process() (*gomatrixserverlib.MSC2836EventRelationshipsRespons if event == nil { event = rc.fetchUnknownEvent(rc.req.EventID, rc.req.RoomID) } + if rc.req.RoomID == "" && event != nil { + rc.req.RoomID = event.RoomID() + } if event == nil || !rc.authorisedToSeeEvent(event) { return nil, &util.JSONResponse{ Code: 403, JSON: jsonerror.Forbidden("Event does not exist or you are not authorised to see it"), } } + rc.roomVersion = event.Version() // Retrieve the event. Add it to response array. returnEvents = append(returnEvents, event) @@ -383,7 +384,7 @@ func (rc *reqCtx) includeParent(childEvent *gomatrixserverlib.HeaderedEvent) (pa func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recentFirst bool) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) { if rc.hasUnexploredChildren(parentID) { // we need to do a remote request to pull in the children as we are missing them locally. - _, roomVer, serversToQuery := rc.getServersForEventID(parentID) + serversToQuery := rc.getServersForEventID(parentID) var result *gomatrixserverlib.MSC2836EventRelationshipsResponse for _, srv := range serversToQuery { res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ @@ -393,7 +394,7 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen MaxBreadth: -1, MaxDepth: 1, // we just want the children from this parent RecentFirst: true, - }, roomVer) + }, rc.roomVersion) if err != nil { util.GetLogger(rc.ctx).WithError(err).WithField("server", srv).Error("includeChildren: failed to call MSC2836EventRelationships") } else { @@ -427,7 +428,6 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen // Begin to walk the thread DAG in the direction specified, either depth or breadth first according to the depth_first flag, // honouring the limit, max_depth and max_breadth values according to the following rules -// nolint: unparam func walkThread( ctx context.Context, db Database, rc *reqCtx, included map[string]bool, limit int, ) ([]*gomatrixserverlib.HeaderedEvent, bool) { @@ -486,13 +486,9 @@ func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverli } -// authorisedToSeeEvent authenticates that the user or server is allowed to see this event. Returns true if allowed to -// see this request. +// authorisedToSeeEvent checks that the user or server is allowed to see this event. Returns true if allowed to +// see this request. This only needs to be done once per room at present as we just check for joined status. func (rc *reqCtx) authorisedToSeeEvent(event *gomatrixserverlib.HeaderedEvent) bool { - authorised, ok := rc.authorisedRoomIDs[event.RoomID()] - if ok { - return len(authorised) > 0 - } if rc.isFederatedRequest { // make sure the server is in this room var res fs.QueryJoinedHostServerNamesInRoomResponse @@ -500,64 +496,70 @@ func (rc *reqCtx) authorisedToSeeEvent(event *gomatrixserverlib.HeaderedEvent) b RoomID: event.RoomID(), }, &res) if err != nil { - util.GetLogger(rc.ctx).WithError(err).Error("authenticateEvent: failed to QueryJoinedHostServerNamesInRoom") + util.GetLogger(rc.ctx).WithError(err).Error("authorisedToSeeEvent: failed to QueryJoinedHostServerNamesInRoom") return false } for _, srv := range res.ServerNames { if srv == rc.serverName { - rc.authorisedRoomIDs[event.RoomID()] = event.Version() return true } } return false } // make sure the user is in this room - joinedToRoom, err := rc.allowedToSeeEvent(event.RoomID(), rc.userID) - if err != nil || !joinedToRoom { + // Allow events if the member is in the room + // TODO: This does not honour history_visibility + // TODO: This does not honour m.room.create content + var queryMembershipRes roomserver.QueryMembershipForUserResponse + err := rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{ + RoomID: event.RoomID(), + UserID: rc.userID, + }, &queryMembershipRes) + if err != nil { + util.GetLogger(rc.ctx).WithError(err).Error("authorisedToSeeEvent: failed to QueryMembershipForUser") return false } - rc.authorisedRoomIDs[event.RoomID()] = event.Version() - return true + return queryMembershipRes.IsInRoom } -func (rc *reqCtx) getServersForEventID(eventID string) (string, gomatrixserverlib.RoomVersion, []gomatrixserverlib.ServerName) { - if len(rc.authorisedRoomIDs) != 1 { +func (rc *reqCtx) getServersForEventID(eventID string) []gomatrixserverlib.ServerName { + if rc.req.RoomID == "" { util.GetLogger(rc.ctx).WithField("event_id", eventID).Error( - "getServersForEventID: thread exists over multiple rooms and reached unknown event, cannot determine room and hence which servers to query", + "getServersForEventID: event exists in unknown room", ) - return "", "", nil + return nil } - var roomID string - var roomVer gomatrixserverlib.RoomVersion - for r, v := range rc.authorisedRoomIDs { - roomID = r - roomVer = v + if rc.roomVersion == "" { + util.GetLogger(rc.ctx).WithField("event_id", eventID).Errorf( + "getServersForEventID: event exists in %s with unknown room version", rc.req.RoomID, + ) + return nil } var queryRes fs.QueryJoinedHostServerNamesInRoomResponse err := rc.fsAPI.QueryJoinedHostServerNamesInRoom(rc.ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{ - RoomID: roomID, + RoomID: rc.req.RoomID, }, &queryRes) if err != nil { util.GetLogger(rc.ctx).WithError(err).Error("getServersForEventID: failed to QueryJoinedHostServerNamesInRoom") - return "", "", nil + return nil } // query up to 5 servers serversToQuery := queryRes.ServerNames if len(serversToQuery) > 5 { serversToQuery = serversToQuery[:5] } - return roomID, roomVer, serversToQuery + return serversToQuery } func (rc *reqCtx) remoteEventRelationships(eventID string) *gomatrixserverlib.MSC2836EventRelationshipsResponse { if rc.isFederatedRequest { return nil // we don't query remote servers for remote requests } - _, roomVer, serversToQuery := rc.getServersForEventID(eventID) + serversToQuery := rc.getServersForEventID(eventID) var res *gomatrixserverlib.MSC2836EventRelationshipsResponse var err error for _, srv := range serversToQuery { - res, err = rc.MSC2836EventRelationships(eventID, srv, roomVer) + res, err = rc.MSC2836EventRelationships(eventID, srv, rc.roomVersion) if err != nil { util.GetLogger(rc.ctx).WithError(err).WithField("server", srv).Error("remoteEventRelationships: failed to call MSC2836EventRelationships") } else { @@ -569,7 +571,6 @@ func (rc *reqCtx) remoteEventRelationships(eventID string) *gomatrixserverlib.MS // lookForEvent returns the event for the event ID given, by trying to query remote servers // if the event ID is unknown via /event_relationships. -// nolint:gocyclo func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent { event := rc.getLocalEvent(eventID) if event == nil { @@ -578,7 +579,7 @@ func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent // inject all the events into the roomserver then return the event in question rc.injectResponseToRoomserver(queryRes) for _, ev := range queryRes.Events { - if ev.EventID() == eventID { + if ev.EventID() == eventID && rc.req.RoomID == ev.RoomID() { return ev.Headered(ev.Version()) } } @@ -595,46 +596,12 @@ func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent } } } - - if rc.authorisedToSeeEvent(event) { + if rc.req.RoomID == event.RoomID() { return event } - if !rc.isFederatedRequest && rc.req.AutoJoin { - // attempt to join the room then recheck auth, but only for local users - var joinRes roomserver.PerformJoinResponse - rc.rsAPI.PerformJoin(rc.ctx, &roomserver.PerformJoinRequest{ - UserID: rc.userID, - Content: map[string]interface{}{}, - RoomIDOrAlias: event.RoomID(), - }, &joinRes) - if joinRes.Error != nil { - util.GetLogger(rc.ctx).WithError(joinRes.Error).WithField("room_id", event.RoomID()).Error("Failed to auto-join room") - return nil - } - delete(rc.authorisedRoomIDs, event.RoomID()) - if rc.authorisedToSeeEvent(event) { - return event - } - } return nil } -func (rc *reqCtx) allowedToSeeEvent(roomID, userID string) (bool, error) { - // Allow events if the member is in the room - // TODO: This does not honour history_visibility - // TODO: This does not honour m.room.create content - var queryMembershipRes roomserver.QueryMembershipForUserResponse - err := rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{ - RoomID: roomID, - UserID: userID, - }, &queryMembershipRes) - if err != nil { - util.GetLogger(rc.ctx).WithError(err).Error("allowedToSeeEvent: failed to QueryMembershipForUser") - return false, err - } - return queryMembershipRes.IsInRoom, nil -} - func (rc *reqCtx) getLocalEvent(eventID string) *gomatrixserverlib.HeaderedEvent { var queryEventsRes roomserver.QueryEventsByIDResponse err := rc.rsAPI.QueryEventsByID(rc.ctx, &roomserver.QueryEventsByIDRequest{ diff --git a/internal/mscs/msc2836/msc2836_test.go b/internal/mscs/msc2836/msc2836_test.go index b56fd234d..eb1da4ffe 100644 --- a/internal/mscs/msc2836/msc2836_test.go +++ b/internal/mscs/msc2836/msc2836_test.go @@ -47,9 +47,7 @@ func TestMSC2836(t *testing.T) { alice := "@alice:localhost" bob := "@bob:localhost" charlie := "@charlie:localhost" - roomIDA := "!alice:localhost" - roomIDB := "!bob:localhost" - roomIDC := "!charlie:localhost" + roomID := "!alice:localhost" // give access tokens to all three users nopUserAPI := &testUserAPI{ accessTokens: make(map[string]userapi.Device), @@ -70,7 +68,7 @@ func TestMSC2836(t *testing.T) { UserID: charlie, } eventA := mustCreateEvent(t, fledglingEvent{ - RoomID: roomIDA, + RoomID: roomID, Sender: alice, Type: "m.room.message", Content: map[string]interface{}{ @@ -78,7 +76,7 @@ func TestMSC2836(t *testing.T) { }, }) eventB := mustCreateEvent(t, fledglingEvent{ - RoomID: roomIDB, + RoomID: roomID, Sender: bob, Type: "m.room.message", Content: map[string]interface{}{ @@ -90,7 +88,7 @@ func TestMSC2836(t *testing.T) { }, }) eventC := mustCreateEvent(t, fledglingEvent{ - RoomID: roomIDB, + RoomID: roomID, Sender: bob, Type: "m.room.message", Content: map[string]interface{}{ @@ -102,7 +100,7 @@ func TestMSC2836(t *testing.T) { }, }) eventD := mustCreateEvent(t, fledglingEvent{ - RoomID: roomIDA, + RoomID: roomID, Sender: alice, Type: "m.room.message", Content: map[string]interface{}{ @@ -114,7 +112,7 @@ func TestMSC2836(t *testing.T) { }, }) eventE := mustCreateEvent(t, fledglingEvent{ - RoomID: roomIDB, + RoomID: roomID, Sender: bob, Type: "m.room.message", Content: map[string]interface{}{ @@ -126,7 +124,7 @@ func TestMSC2836(t *testing.T) { }, }) eventF := mustCreateEvent(t, fledglingEvent{ - RoomID: roomIDC, + RoomID: roomID, Sender: charlie, Type: "m.room.message", Content: map[string]interface{}{ @@ -138,7 +136,7 @@ func TestMSC2836(t *testing.T) { }, }) eventG := mustCreateEvent(t, fledglingEvent{ - RoomID: roomIDA, + RoomID: roomID, Sender: alice, Type: "m.room.message", Content: map[string]interface{}{ @@ -150,7 +148,7 @@ func TestMSC2836(t *testing.T) { }, }) eventH := mustCreateEvent(t, fledglingEvent{ - RoomID: roomIDB, + RoomID: roomID, Sender: bob, Type: "m.room.message", Content: map[string]interface{}{ @@ -164,9 +162,9 @@ func TestMSC2836(t *testing.T) { // make everyone joined to each other's rooms nopRsAPI := &testRoomserverAPI{ userToJoinedRooms: map[string][]string{ - alice: []string{roomIDA, roomIDB, roomIDC}, - bob: []string{roomIDA, roomIDB, roomIDC}, - charlie: []string{roomIDA, roomIDB, roomIDC}, + alice: []string{roomID}, + bob: []string{roomID}, + charlie: []string{roomID}, }, events: map[string]*gomatrixserverlib.HeaderedEvent{ eventA.EventID(): eventA, @@ -202,21 +200,6 @@ func TestMSC2836(t *testing.T) { "include_parent": true, })) }) - t.Run("omits parent if not joined to the room of parent of event", func(t *testing.T) { - nopUserAPI.accessTokens["frank2"] = userapi.Device{ - AccessToken: "frank2", - DisplayName: "Frank2 Not In Room", - UserID: "@frank2:localhost", - } - // Event B is in roomB, Event A is in roomA, so make frank2 joined to roomB - nopRsAPI.userToJoinedRooms["@frank2:localhost"] = []string{roomIDB} - body := postRelationships(t, 200, "frank2", newReq(t, map[string]interface{}{ - "event_id": eventB.EventID(), - "limit": 1, - "include_parent": true, - })) - assertContains(t, body, []string{eventB.EventID()}) - }) t.Run("returns the parent if include_parent is true", func(t *testing.T) { body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ "event_id": eventB.EventID(),