mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-26 00:03:09 -06:00
Add auto_join for local rooms
This commit is contained in:
parent
c320261ca2
commit
8a7ccd5d1d
|
|
@ -44,6 +44,7 @@ type EventRelationshipRequest struct {
|
|||
IncludeChildren *bool `json:"include_children"`
|
||||
Direction string `json:"direction"`
|
||||
Batch string `json:"batch"`
|
||||
AutoJoin bool `json:"auto_join"`
|
||||
}
|
||||
|
||||
func (r *EventRelationshipRequest) applyDefaults() {
|
||||
|
|
@ -104,6 +105,13 @@ func Enable(base *setup.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, us
|
|||
return nil
|
||||
}
|
||||
|
||||
type reqCtx struct {
|
||||
ctx context.Context
|
||||
rsAPI roomserver.RoomserverInternalAPI
|
||||
req *EventRelationshipRequest
|
||||
userID string
|
||||
}
|
||||
|
||||
func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAPI) func(*http.Request, *userapi.Device) util.JSONResponse {
|
||||
return func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
var relation EventRelationshipRequest
|
||||
|
|
@ -118,9 +126,15 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
|
|||
relation.applyDefaults()
|
||||
var res EventRelationshipResponse
|
||||
var returnEvents []*gomatrixserverlib.HeaderedEvent
|
||||
rc := reqCtx{
|
||||
ctx: req.Context(),
|
||||
req: &relation,
|
||||
userID: device.UserID,
|
||||
rsAPI: rsAPI,
|
||||
}
|
||||
|
||||
// Can the user see (according to history visibility) event_id? If no, reject the request, else continue.
|
||||
event := getEventIfVisible(req.Context(), rsAPI, relation.EventID, device.UserID)
|
||||
event := rc.getEventIfVisible(relation.EventID)
|
||||
if event == nil {
|
||||
return util.JSONResponse{
|
||||
Code: 403,
|
||||
|
|
@ -132,7 +146,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
|
|||
returnEvents = append(returnEvents, event)
|
||||
|
||||
if *relation.IncludeParent {
|
||||
if parentEvent := includeParent(req.Context(), rsAPI, event, device.UserID); parentEvent != nil {
|
||||
if parentEvent := rc.includeParent(event); parentEvent != nil {
|
||||
returnEvents = append(returnEvents, parentEvent)
|
||||
}
|
||||
}
|
||||
|
|
@ -140,7 +154,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
|
|||
if *relation.IncludeChildren {
|
||||
remaining := relation.Limit - len(returnEvents)
|
||||
if remaining > 0 {
|
||||
children, resErr := includeChildren(req.Context(), rsAPI, db, event.EventID(), remaining, *relation.RecentFirst, device.UserID)
|
||||
children, resErr := rc.includeChildren(db, event.EventID(), remaining, *relation.RecentFirst)
|
||||
if resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
|
|
@ -157,7 +171,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
|
|||
}
|
||||
var events []*gomatrixserverlib.HeaderedEvent
|
||||
events, walkLimited = walkThread(
|
||||
req.Context(), db, rsAPI, device.UserID, &relation, included, remaining,
|
||||
req.Context(), db, &rc, included, remaining,
|
||||
)
|
||||
returnEvents = append(returnEvents, events...)
|
||||
}
|
||||
|
|
@ -176,27 +190,27 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
|
|||
|
||||
// If include_parent: true and there is a valid m.relationship field in the event,
|
||||
// retrieve the referenced event. Apply history visibility check to that event and if it passes, add it to the response array.
|
||||
func includeParent(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, event *gomatrixserverlib.HeaderedEvent, userID string) (parent *gomatrixserverlib.HeaderedEvent) {
|
||||
func (rc *reqCtx) includeParent(event *gomatrixserverlib.HeaderedEvent) (parent *gomatrixserverlib.HeaderedEvent) {
|
||||
parentID, _, _ := parentChildEventIDs(event)
|
||||
if parentID == "" {
|
||||
return nil
|
||||
}
|
||||
return getEventIfVisible(ctx, rsAPI, parentID, userID)
|
||||
return rc.getEventIfVisible(parentID)
|
||||
}
|
||||
|
||||
// If include_children: true, lookup all events which have event_id as an m.relationship
|
||||
// Apply history visibility checks to all these events and add the ones which pass into the response array,
|
||||
// honouring the recent_first flag and the limit.
|
||||
func includeChildren(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, db Database, parentID string, limit int, recentFirst bool, userID string) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) {
|
||||
children, err := db.ChildrenForParent(ctx, parentID, constRelType, recentFirst)
|
||||
func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recentFirst bool) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) {
|
||||
children, err := db.ChildrenForParent(rc.ctx, parentID, constRelType, recentFirst)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("failed to get ChildrenForParent")
|
||||
util.GetLogger(rc.ctx).WithError(err).Error("failed to get ChildrenForParent")
|
||||
resErr := jsonerror.InternalServerError()
|
||||
return nil, &resErr
|
||||
}
|
||||
var childEvents []*gomatrixserverlib.HeaderedEvent
|
||||
for _, child := range children {
|
||||
childEvent := getEventIfVisible(ctx, rsAPI, child.EventID, userID)
|
||||
childEvent := rc.getEventIfVisible(child.EventID)
|
||||
if childEvent != nil {
|
||||
childEvents = append(childEvents, childEvent)
|
||||
}
|
||||
|
|
@ -211,16 +225,16 @@ func includeChildren(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI
|
|||
// honouring the limit, max_depth and max_breadth values according to the following rules
|
||||
// nolint: unparam
|
||||
func walkThread(
|
||||
ctx context.Context, db Database, rsAPI roomserver.RoomserverInternalAPI, userID string, req *EventRelationshipRequest, included map[string]bool, limit int,
|
||||
ctx context.Context, db Database, rc *reqCtx, included map[string]bool, limit int,
|
||||
) ([]*gomatrixserverlib.HeaderedEvent, bool) {
|
||||
if req.Direction != "down" {
|
||||
if rc.req.Direction != "down" {
|
||||
util.GetLogger(ctx).Error("not implemented: direction=up")
|
||||
return nil, false
|
||||
}
|
||||
var result []*gomatrixserverlib.HeaderedEvent
|
||||
eventWalker := walker{
|
||||
ctx: ctx,
|
||||
req: req,
|
||||
req: rc.req,
|
||||
db: db,
|
||||
fn: func(wi *walkInfo) bool {
|
||||
// If already processed event, skip.
|
||||
|
|
@ -234,7 +248,7 @@ func walkThread(
|
|||
}
|
||||
|
||||
// Process the event.
|
||||
event := getEventIfVisible(ctx, rsAPI, wi.EventID, userID)
|
||||
event := rc.getEventIfVisible(wi.EventID)
|
||||
if event != nil {
|
||||
result = append(result, event)
|
||||
}
|
||||
|
|
@ -242,24 +256,24 @@ func walkThread(
|
|||
return false
|
||||
},
|
||||
}
|
||||
limited, err := eventWalker.WalkFrom(req.EventID)
|
||||
limited, err := eventWalker.WalkFrom(rc.req.EventID)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Errorf("Failed to WalkFrom %s", req.EventID)
|
||||
util.GetLogger(ctx).WithError(err).Errorf("Failed to WalkFrom %s", rc.req.EventID)
|
||||
}
|
||||
return result, limited
|
||||
}
|
||||
|
||||
func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, eventID, userID string) *gomatrixserverlib.HeaderedEvent {
|
||||
func (rc *reqCtx) getEventIfVisible(eventID string) *gomatrixserverlib.HeaderedEvent {
|
||||
var queryEventsRes roomserver.QueryEventsByIDResponse
|
||||
err := rsAPI.QueryEventsByID(ctx, &roomserver.QueryEventsByIDRequest{
|
||||
err := rc.rsAPI.QueryEventsByID(rc.ctx, &roomserver.QueryEventsByIDRequest{
|
||||
EventIDs: []string{eventID},
|
||||
}, &queryEventsRes)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryEventsByID")
|
||||
util.GetLogger(rc.ctx).WithError(err).Error("getEventIfVisible: failed to QueryEventsByID")
|
||||
return nil
|
||||
}
|
||||
if len(queryEventsRes.Events) == 0 {
|
||||
util.GetLogger(ctx).Infof("event does not exist")
|
||||
util.GetLogger(rc.ctx).Infof("event does not exist")
|
||||
return nil // event does not exist
|
||||
}
|
||||
event := queryEventsRes.Events[0]
|
||||
|
|
@ -268,20 +282,34 @@ func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalA
|
|||
// TODO: This does not honour history_visibility
|
||||
// TODO: This does not honour m.room.create content
|
||||
var queryMembershipRes roomserver.QueryMembershipForUserResponse
|
||||
err = rsAPI.QueryMembershipForUser(ctx, &roomserver.QueryMembershipForUserRequest{
|
||||
err = rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{
|
||||
RoomID: event.RoomID(),
|
||||
UserID: userID,
|
||||
UserID: rc.userID,
|
||||
}, &queryMembershipRes)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryMembershipForUser")
|
||||
util.GetLogger(rc.ctx).WithError(err).Error("getEventIfVisible: failed to QueryMembershipForUser")
|
||||
return nil
|
||||
}
|
||||
if !queryMembershipRes.IsInRoom {
|
||||
util.GetLogger(ctx).Infof("user not in room")
|
||||
if queryMembershipRes.IsInRoom {
|
||||
return &event
|
||||
}
|
||||
if rc.req.AutoJoin {
|
||||
var joinRes roomserver.PerformJoinResponse
|
||||
rc.rsAPI.PerformJoin(rc.ctx, &roomserver.PerformJoinRequest{
|
||||
UserID: rc.userID,
|
||||
Content: map[string]interface{}{},
|
||||
RoomIDOrAlias: event.RoomID(),
|
||||
// TODO: Add server_names from linked room, currently this join will only work if the HS is already in the room
|
||||
}, &joinRes)
|
||||
if joinRes.Error != nil {
|
||||
util.GetLogger(rc.ctx).WithError(joinRes.Error).WithField("room_id", event.RoomID()).Error("Failed to auto-join room")
|
||||
return nil
|
||||
}
|
||||
return &event
|
||||
}
|
||||
util.GetLogger(rc.ctx).Infof("user not in room and auto_join disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
type walkInfo struct {
|
||||
eventInfo
|
||||
|
|
|
|||
|
|
@ -353,6 +353,10 @@ func TestMSC2836(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
// TODO: TestMSC2836TerminatesLoops (short and long)
|
||||
// TODO: TestMSC2836UnknownEventsSkipped
|
||||
// TODO: TestMSC2836SkipEventIfNotInRoom
|
||||
|
||||
func runServer(t *testing.T, router *mux.Router) func() {
|
||||
t.Helper()
|
||||
externalServ := &http.Server{
|
||||
|
|
|
|||
Loading…
Reference in a new issue