Remove cross-room threading impl

This commit is contained in:
Kegan Dougal 2020-12-01 18:27:00 +00:00
parent 4a31e5bba7
commit fcf0451975
2 changed files with 59 additions and 109 deletions

View file

@ -54,7 +54,6 @@ type EventRelationshipRequest struct {
IncludeChildren bool `json:"include_children"` IncludeChildren bool `json:"include_children"`
Direction string `json:"direction"` Direction string `json:"direction"`
Batch string `json:"batch"` Batch string `json:"batch"`
AutoJoin bool `json:"auto_join"`
} }
func NewEventRelationshipRequest(body io.Reader) (*EventRelationshipRequest, error) { func NewEventRelationshipRequest(body io.Reader) (*EventRelationshipRequest, error) {
@ -93,7 +92,6 @@ func toClientResponse(res *gomatrixserverlib.MSC2836EventRelationshipsResponse)
} }
// Enable this MSC // Enable this MSC
// nolint:gocyclo
func Enable( func Enable(
base *setup.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationSenderInternalAPI, base *setup.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationSenderInternalAPI,
userAPI userapi.UserInternalAPI, keyRing gomatrixserverlib.JSONVerifier, userAPI userapi.UserInternalAPI, keyRing gomatrixserverlib.JSONVerifier,
@ -145,7 +143,7 @@ type reqCtx struct {
db Database db Database
req *EventRelationshipRequest req *EventRelationshipRequest
userID string userID string
authorisedRoomIDs map[string]gomatrixserverlib.RoomVersion // events from these rooms can be returned TODO remove roomVersion gomatrixserverlib.RoomVersion
// federated request args // federated request args
isFederatedRequest bool isFederatedRequest bool
@ -171,7 +169,6 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
fsAPI: fsAPI, fsAPI: fsAPI,
isFederatedRequest: false, isFederatedRequest: false,
db: db, db: db,
authorisedRoomIDs: make(map[string]gomatrixserverlib.RoomVersion),
} }
res, resErr := rc.process() res, resErr := rc.process()
if resErr != nil { if resErr != nil {
@ -201,7 +198,6 @@ func federatedEventRelationship(
req: relation, req: relation,
rsAPI: rsAPI, rsAPI: rsAPI,
db: db, db: db,
authorisedRoomIDs: make(map[string]gomatrixserverlib.RoomVersion),
// federation args // federation args
isFederatedRequest: true, isFederatedRequest: true,
fsAPI: fsAPI, fsAPI: fsAPI,
@ -242,6 +238,7 @@ func federatedEventRelationship(
} }
} }
// nolint:gocyclo
func (rc *reqCtx) process() (*gomatrixserverlib.MSC2836EventRelationshipsResponse, *util.JSONResponse) { func (rc *reqCtx) process() (*gomatrixserverlib.MSC2836EventRelationshipsResponse, *util.JSONResponse) {
var res gomatrixserverlib.MSC2836EventRelationshipsResponse var res gomatrixserverlib.MSC2836EventRelationshipsResponse
var returnEvents []*gomatrixserverlib.HeaderedEvent var returnEvents []*gomatrixserverlib.HeaderedEvent
@ -250,12 +247,16 @@ func (rc *reqCtx) process() (*gomatrixserverlib.MSC2836EventRelationshipsRespons
if event == nil { if event == nil {
event = rc.fetchUnknownEvent(rc.req.EventID, rc.req.RoomID) event = rc.fetchUnknownEvent(rc.req.EventID, rc.req.RoomID)
} }
if rc.req.RoomID == "" && event != nil {
rc.req.RoomID = event.RoomID()
}
if event == nil || !rc.authorisedToSeeEvent(event) { if event == nil || !rc.authorisedToSeeEvent(event) {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: 403, Code: 403,
JSON: jsonerror.Forbidden("Event does not exist or you are not authorised to see it"), JSON: jsonerror.Forbidden("Event does not exist or you are not authorised to see it"),
} }
} }
rc.roomVersion = event.Version()
// Retrieve the event. Add it to response array. // Retrieve the event. Add it to response array.
returnEvents = append(returnEvents, event) returnEvents = append(returnEvents, event)
@ -383,7 +384,7 @@ func (rc *reqCtx) includeParent(childEvent *gomatrixserverlib.HeaderedEvent) (pa
func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recentFirst bool) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) { func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recentFirst bool) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) {
if rc.hasUnexploredChildren(parentID) { if rc.hasUnexploredChildren(parentID) {
// we need to do a remote request to pull in the children as we are missing them locally. // we need to do a remote request to pull in the children as we are missing them locally.
_, roomVer, serversToQuery := rc.getServersForEventID(parentID) serversToQuery := rc.getServersForEventID(parentID)
var result *gomatrixserverlib.MSC2836EventRelationshipsResponse var result *gomatrixserverlib.MSC2836EventRelationshipsResponse
for _, srv := range serversToQuery { for _, srv := range serversToQuery {
res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{
@ -393,7 +394,7 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen
MaxBreadth: -1, MaxBreadth: -1,
MaxDepth: 1, // we just want the children from this parent MaxDepth: 1, // we just want the children from this parent
RecentFirst: true, RecentFirst: true,
}, roomVer) }, rc.roomVersion)
if err != nil { if err != nil {
util.GetLogger(rc.ctx).WithError(err).WithField("server", srv).Error("includeChildren: failed to call MSC2836EventRelationships") util.GetLogger(rc.ctx).WithError(err).WithField("server", srv).Error("includeChildren: failed to call MSC2836EventRelationships")
} else { } else {
@ -427,7 +428,6 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen
// Begin to walk the thread DAG in the direction specified, either depth or breadth first according to the depth_first flag, // Begin to walk the thread DAG in the direction specified, either depth or breadth first according to the depth_first flag,
// honouring the limit, max_depth and max_breadth values according to the following rules // honouring the limit, max_depth and max_breadth values according to the following rules
// nolint: unparam
func walkThread( func walkThread(
ctx context.Context, db Database, rc *reqCtx, included map[string]bool, limit int, ctx context.Context, db Database, rc *reqCtx, included map[string]bool, limit int,
) ([]*gomatrixserverlib.HeaderedEvent, bool) { ) ([]*gomatrixserverlib.HeaderedEvent, bool) {
@ -486,13 +486,9 @@ func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverli
} }
// authorisedToSeeEvent authenticates that the user or server is allowed to see this event. Returns true if allowed to // authorisedToSeeEvent checks that the user or server is allowed to see this event. Returns true if allowed to
// see this request. // see this request. This only needs to be done once per room at present as we just check for joined status.
func (rc *reqCtx) authorisedToSeeEvent(event *gomatrixserverlib.HeaderedEvent) bool { func (rc *reqCtx) authorisedToSeeEvent(event *gomatrixserverlib.HeaderedEvent) bool {
authorised, ok := rc.authorisedRoomIDs[event.RoomID()]
if ok {
return len(authorised) > 0
}
if rc.isFederatedRequest { if rc.isFederatedRequest {
// make sure the server is in this room // make sure the server is in this room
var res fs.QueryJoinedHostServerNamesInRoomResponse var res fs.QueryJoinedHostServerNamesInRoomResponse
@ -500,64 +496,70 @@ func (rc *reqCtx) authorisedToSeeEvent(event *gomatrixserverlib.HeaderedEvent) b
RoomID: event.RoomID(), RoomID: event.RoomID(),
}, &res) }, &res)
if err != nil { if err != nil {
util.GetLogger(rc.ctx).WithError(err).Error("authenticateEvent: failed to QueryJoinedHostServerNamesInRoom") util.GetLogger(rc.ctx).WithError(err).Error("authorisedToSeeEvent: failed to QueryJoinedHostServerNamesInRoom")
return false return false
} }
for _, srv := range res.ServerNames { for _, srv := range res.ServerNames {
if srv == rc.serverName { if srv == rc.serverName {
rc.authorisedRoomIDs[event.RoomID()] = event.Version()
return true return true
} }
} }
return false return false
} }
// make sure the user is in this room // make sure the user is in this room
joinedToRoom, err := rc.allowedToSeeEvent(event.RoomID(), rc.userID) // Allow events if the member is in the room
if err != nil || !joinedToRoom { // TODO: This does not honour history_visibility
// TODO: This does not honour m.room.create content
var queryMembershipRes roomserver.QueryMembershipForUserResponse
err := rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{
RoomID: event.RoomID(),
UserID: rc.userID,
}, &queryMembershipRes)
if err != nil {
util.GetLogger(rc.ctx).WithError(err).Error("authorisedToSeeEvent: failed to QueryMembershipForUser")
return false return false
} }
rc.authorisedRoomIDs[event.RoomID()] = event.Version() return queryMembershipRes.IsInRoom
return true
} }
func (rc *reqCtx) getServersForEventID(eventID string) (string, gomatrixserverlib.RoomVersion, []gomatrixserverlib.ServerName) { func (rc *reqCtx) getServersForEventID(eventID string) []gomatrixserverlib.ServerName {
if len(rc.authorisedRoomIDs) != 1 { if rc.req.RoomID == "" {
util.GetLogger(rc.ctx).WithField("event_id", eventID).Error( util.GetLogger(rc.ctx).WithField("event_id", eventID).Error(
"getServersForEventID: thread exists over multiple rooms and reached unknown event, cannot determine room and hence which servers to query", "getServersForEventID: event exists in unknown room",
) )
return "", "", nil return nil
} }
var roomID string if rc.roomVersion == "" {
var roomVer gomatrixserverlib.RoomVersion util.GetLogger(rc.ctx).WithField("event_id", eventID).Errorf(
for r, v := range rc.authorisedRoomIDs { "getServersForEventID: event exists in %s with unknown room version", rc.req.RoomID,
roomID = r )
roomVer = v return nil
} }
var queryRes fs.QueryJoinedHostServerNamesInRoomResponse var queryRes fs.QueryJoinedHostServerNamesInRoomResponse
err := rc.fsAPI.QueryJoinedHostServerNamesInRoom(rc.ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{ err := rc.fsAPI.QueryJoinedHostServerNamesInRoom(rc.ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{
RoomID: roomID, RoomID: rc.req.RoomID,
}, &queryRes) }, &queryRes)
if err != nil { if err != nil {
util.GetLogger(rc.ctx).WithError(err).Error("getServersForEventID: failed to QueryJoinedHostServerNamesInRoom") util.GetLogger(rc.ctx).WithError(err).Error("getServersForEventID: failed to QueryJoinedHostServerNamesInRoom")
return "", "", nil return nil
} }
// query up to 5 servers // query up to 5 servers
serversToQuery := queryRes.ServerNames serversToQuery := queryRes.ServerNames
if len(serversToQuery) > 5 { if len(serversToQuery) > 5 {
serversToQuery = serversToQuery[:5] serversToQuery = serversToQuery[:5]
} }
return roomID, roomVer, serversToQuery return serversToQuery
} }
func (rc *reqCtx) remoteEventRelationships(eventID string) *gomatrixserverlib.MSC2836EventRelationshipsResponse { func (rc *reqCtx) remoteEventRelationships(eventID string) *gomatrixserverlib.MSC2836EventRelationshipsResponse {
if rc.isFederatedRequest { if rc.isFederatedRequest {
return nil // we don't query remote servers for remote requests return nil // we don't query remote servers for remote requests
} }
_, roomVer, serversToQuery := rc.getServersForEventID(eventID) serversToQuery := rc.getServersForEventID(eventID)
var res *gomatrixserverlib.MSC2836EventRelationshipsResponse var res *gomatrixserverlib.MSC2836EventRelationshipsResponse
var err error var err error
for _, srv := range serversToQuery { for _, srv := range serversToQuery {
res, err = rc.MSC2836EventRelationships(eventID, srv, roomVer) res, err = rc.MSC2836EventRelationships(eventID, srv, rc.roomVersion)
if err != nil { if err != nil {
util.GetLogger(rc.ctx).WithError(err).WithField("server", srv).Error("remoteEventRelationships: failed to call MSC2836EventRelationships") util.GetLogger(rc.ctx).WithError(err).WithField("server", srv).Error("remoteEventRelationships: failed to call MSC2836EventRelationships")
} else { } else {
@ -569,7 +571,6 @@ func (rc *reqCtx) remoteEventRelationships(eventID string) *gomatrixserverlib.MS
// lookForEvent returns the event for the event ID given, by trying to query remote servers // lookForEvent returns the event for the event ID given, by trying to query remote servers
// if the event ID is unknown via /event_relationships. // if the event ID is unknown via /event_relationships.
// nolint:gocyclo
func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent { func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent {
event := rc.getLocalEvent(eventID) event := rc.getLocalEvent(eventID)
if event == nil { if event == nil {
@ -578,7 +579,7 @@ func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent
// inject all the events into the roomserver then return the event in question // inject all the events into the roomserver then return the event in question
rc.injectResponseToRoomserver(queryRes) rc.injectResponseToRoomserver(queryRes)
for _, ev := range queryRes.Events { for _, ev := range queryRes.Events {
if ev.EventID() == eventID { if ev.EventID() == eventID && rc.req.RoomID == ev.RoomID() {
return ev.Headered(ev.Version()) return ev.Headered(ev.Version())
} }
} }
@ -595,45 +596,11 @@ func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent
} }
} }
} }
if rc.req.RoomID == event.RoomID() {
if rc.authorisedToSeeEvent(event) {
return event return event
} }
if !rc.isFederatedRequest && rc.req.AutoJoin {
// attempt to join the room then recheck auth, but only for local users
var joinRes roomserver.PerformJoinResponse
rc.rsAPI.PerformJoin(rc.ctx, &roomserver.PerformJoinRequest{
UserID: rc.userID,
Content: map[string]interface{}{},
RoomIDOrAlias: event.RoomID(),
}, &joinRes)
if joinRes.Error != nil {
util.GetLogger(rc.ctx).WithError(joinRes.Error).WithField("room_id", event.RoomID()).Error("Failed to auto-join room")
return nil return nil
} }
delete(rc.authorisedRoomIDs, event.RoomID())
if rc.authorisedToSeeEvent(event) {
return event
}
}
return nil
}
func (rc *reqCtx) allowedToSeeEvent(roomID, userID string) (bool, error) {
// Allow events if the member is in the room
// TODO: This does not honour history_visibility
// TODO: This does not honour m.room.create content
var queryMembershipRes roomserver.QueryMembershipForUserResponse
err := rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{
RoomID: roomID,
UserID: userID,
}, &queryMembershipRes)
if err != nil {
util.GetLogger(rc.ctx).WithError(err).Error("allowedToSeeEvent: failed to QueryMembershipForUser")
return false, err
}
return queryMembershipRes.IsInRoom, nil
}
func (rc *reqCtx) getLocalEvent(eventID string) *gomatrixserverlib.HeaderedEvent { func (rc *reqCtx) getLocalEvent(eventID string) *gomatrixserverlib.HeaderedEvent {
var queryEventsRes roomserver.QueryEventsByIDResponse var queryEventsRes roomserver.QueryEventsByIDResponse

View file

@ -47,9 +47,7 @@ func TestMSC2836(t *testing.T) {
alice := "@alice:localhost" alice := "@alice:localhost"
bob := "@bob:localhost" bob := "@bob:localhost"
charlie := "@charlie:localhost" charlie := "@charlie:localhost"
roomIDA := "!alice:localhost" roomID := "!alice:localhost"
roomIDB := "!bob:localhost"
roomIDC := "!charlie:localhost"
// give access tokens to all three users // give access tokens to all three users
nopUserAPI := &testUserAPI{ nopUserAPI := &testUserAPI{
accessTokens: make(map[string]userapi.Device), accessTokens: make(map[string]userapi.Device),
@ -70,7 +68,7 @@ func TestMSC2836(t *testing.T) {
UserID: charlie, UserID: charlie,
} }
eventA := mustCreateEvent(t, fledglingEvent{ eventA := mustCreateEvent(t, fledglingEvent{
RoomID: roomIDA, RoomID: roomID,
Sender: alice, Sender: alice,
Type: "m.room.message", Type: "m.room.message",
Content: map[string]interface{}{ Content: map[string]interface{}{
@ -78,7 +76,7 @@ func TestMSC2836(t *testing.T) {
}, },
}) })
eventB := mustCreateEvent(t, fledglingEvent{ eventB := mustCreateEvent(t, fledglingEvent{
RoomID: roomIDB, RoomID: roomID,
Sender: bob, Sender: bob,
Type: "m.room.message", Type: "m.room.message",
Content: map[string]interface{}{ Content: map[string]interface{}{
@ -90,7 +88,7 @@ func TestMSC2836(t *testing.T) {
}, },
}) })
eventC := mustCreateEvent(t, fledglingEvent{ eventC := mustCreateEvent(t, fledglingEvent{
RoomID: roomIDB, RoomID: roomID,
Sender: bob, Sender: bob,
Type: "m.room.message", Type: "m.room.message",
Content: map[string]interface{}{ Content: map[string]interface{}{
@ -102,7 +100,7 @@ func TestMSC2836(t *testing.T) {
}, },
}) })
eventD := mustCreateEvent(t, fledglingEvent{ eventD := mustCreateEvent(t, fledglingEvent{
RoomID: roomIDA, RoomID: roomID,
Sender: alice, Sender: alice,
Type: "m.room.message", Type: "m.room.message",
Content: map[string]interface{}{ Content: map[string]interface{}{
@ -114,7 +112,7 @@ func TestMSC2836(t *testing.T) {
}, },
}) })
eventE := mustCreateEvent(t, fledglingEvent{ eventE := mustCreateEvent(t, fledglingEvent{
RoomID: roomIDB, RoomID: roomID,
Sender: bob, Sender: bob,
Type: "m.room.message", Type: "m.room.message",
Content: map[string]interface{}{ Content: map[string]interface{}{
@ -126,7 +124,7 @@ func TestMSC2836(t *testing.T) {
}, },
}) })
eventF := mustCreateEvent(t, fledglingEvent{ eventF := mustCreateEvent(t, fledglingEvent{
RoomID: roomIDC, RoomID: roomID,
Sender: charlie, Sender: charlie,
Type: "m.room.message", Type: "m.room.message",
Content: map[string]interface{}{ Content: map[string]interface{}{
@ -138,7 +136,7 @@ func TestMSC2836(t *testing.T) {
}, },
}) })
eventG := mustCreateEvent(t, fledglingEvent{ eventG := mustCreateEvent(t, fledglingEvent{
RoomID: roomIDA, RoomID: roomID,
Sender: alice, Sender: alice,
Type: "m.room.message", Type: "m.room.message",
Content: map[string]interface{}{ Content: map[string]interface{}{
@ -150,7 +148,7 @@ func TestMSC2836(t *testing.T) {
}, },
}) })
eventH := mustCreateEvent(t, fledglingEvent{ eventH := mustCreateEvent(t, fledglingEvent{
RoomID: roomIDB, RoomID: roomID,
Sender: bob, Sender: bob,
Type: "m.room.message", Type: "m.room.message",
Content: map[string]interface{}{ Content: map[string]interface{}{
@ -164,9 +162,9 @@ func TestMSC2836(t *testing.T) {
// make everyone joined to each other's rooms // make everyone joined to each other's rooms
nopRsAPI := &testRoomserverAPI{ nopRsAPI := &testRoomserverAPI{
userToJoinedRooms: map[string][]string{ userToJoinedRooms: map[string][]string{
alice: []string{roomIDA, roomIDB, roomIDC}, alice: []string{roomID},
bob: []string{roomIDA, roomIDB, roomIDC}, bob: []string{roomID},
charlie: []string{roomIDA, roomIDB, roomIDC}, charlie: []string{roomID},
}, },
events: map[string]*gomatrixserverlib.HeaderedEvent{ events: map[string]*gomatrixserverlib.HeaderedEvent{
eventA.EventID(): eventA, eventA.EventID(): eventA,
@ -202,21 +200,6 @@ func TestMSC2836(t *testing.T) {
"include_parent": true, "include_parent": true,
})) }))
}) })
t.Run("omits parent if not joined to the room of parent of event", func(t *testing.T) {
nopUserAPI.accessTokens["frank2"] = userapi.Device{
AccessToken: "frank2",
DisplayName: "Frank2 Not In Room",
UserID: "@frank2:localhost",
}
// Event B is in roomB, Event A is in roomA, so make frank2 joined to roomB
nopRsAPI.userToJoinedRooms["@frank2:localhost"] = []string{roomIDB}
body := postRelationships(t, 200, "frank2", newReq(t, map[string]interface{}{
"event_id": eventB.EventID(),
"limit": 1,
"include_parent": true,
}))
assertContains(t, body, []string{eventB.EventID()})
})
t.Run("returns the parent if include_parent is true", func(t *testing.T) { t.Run("returns the parent if include_parent is true", func(t *testing.T) {
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
"event_id": eventB.EventID(), "event_id": eventB.EventID(),