mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-26 00:03:09 -06:00
Add auto_join for local rooms
This commit is contained in:
parent
c320261ca2
commit
8a7ccd5d1d
|
|
@ -44,6 +44,7 @@ type EventRelationshipRequest struct {
|
||||||
IncludeChildren *bool `json:"include_children"`
|
IncludeChildren *bool `json:"include_children"`
|
||||||
Direction string `json:"direction"`
|
Direction string `json:"direction"`
|
||||||
Batch string `json:"batch"`
|
Batch string `json:"batch"`
|
||||||
|
AutoJoin bool `json:"auto_join"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *EventRelationshipRequest) applyDefaults() {
|
func (r *EventRelationshipRequest) applyDefaults() {
|
||||||
|
|
@ -104,6 +105,13 @@ func Enable(base *setup.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, us
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type reqCtx struct {
|
||||||
|
ctx context.Context
|
||||||
|
rsAPI roomserver.RoomserverInternalAPI
|
||||||
|
req *EventRelationshipRequest
|
||||||
|
userID string
|
||||||
|
}
|
||||||
|
|
||||||
func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAPI) func(*http.Request, *userapi.Device) util.JSONResponse {
|
func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAPI) func(*http.Request, *userapi.Device) util.JSONResponse {
|
||||||
return func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
return func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
var relation EventRelationshipRequest
|
var relation EventRelationshipRequest
|
||||||
|
|
@ -118,9 +126,15 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
|
||||||
relation.applyDefaults()
|
relation.applyDefaults()
|
||||||
var res EventRelationshipResponse
|
var res EventRelationshipResponse
|
||||||
var returnEvents []*gomatrixserverlib.HeaderedEvent
|
var returnEvents []*gomatrixserverlib.HeaderedEvent
|
||||||
|
rc := reqCtx{
|
||||||
|
ctx: req.Context(),
|
||||||
|
req: &relation,
|
||||||
|
userID: device.UserID,
|
||||||
|
rsAPI: rsAPI,
|
||||||
|
}
|
||||||
|
|
||||||
// Can the user see (according to history visibility) event_id? If no, reject the request, else continue.
|
// Can the user see (according to history visibility) event_id? If no, reject the request, else continue.
|
||||||
event := getEventIfVisible(req.Context(), rsAPI, relation.EventID, device.UserID)
|
event := rc.getEventIfVisible(relation.EventID)
|
||||||
if event == nil {
|
if event == nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 403,
|
Code: 403,
|
||||||
|
|
@ -132,7 +146,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
|
||||||
returnEvents = append(returnEvents, event)
|
returnEvents = append(returnEvents, event)
|
||||||
|
|
||||||
if *relation.IncludeParent {
|
if *relation.IncludeParent {
|
||||||
if parentEvent := includeParent(req.Context(), rsAPI, event, device.UserID); parentEvent != nil {
|
if parentEvent := rc.includeParent(event); parentEvent != nil {
|
||||||
returnEvents = append(returnEvents, parentEvent)
|
returnEvents = append(returnEvents, parentEvent)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -140,7 +154,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
|
||||||
if *relation.IncludeChildren {
|
if *relation.IncludeChildren {
|
||||||
remaining := relation.Limit - len(returnEvents)
|
remaining := relation.Limit - len(returnEvents)
|
||||||
if remaining > 0 {
|
if remaining > 0 {
|
||||||
children, resErr := includeChildren(req.Context(), rsAPI, db, event.EventID(), remaining, *relation.RecentFirst, device.UserID)
|
children, resErr := rc.includeChildren(db, event.EventID(), remaining, *relation.RecentFirst)
|
||||||
if resErr != nil {
|
if resErr != nil {
|
||||||
return *resErr
|
return *resErr
|
||||||
}
|
}
|
||||||
|
|
@ -157,7 +171,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
|
||||||
}
|
}
|
||||||
var events []*gomatrixserverlib.HeaderedEvent
|
var events []*gomatrixserverlib.HeaderedEvent
|
||||||
events, walkLimited = walkThread(
|
events, walkLimited = walkThread(
|
||||||
req.Context(), db, rsAPI, device.UserID, &relation, included, remaining,
|
req.Context(), db, &rc, included, remaining,
|
||||||
)
|
)
|
||||||
returnEvents = append(returnEvents, events...)
|
returnEvents = append(returnEvents, events...)
|
||||||
}
|
}
|
||||||
|
|
@ -176,27 +190,27 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
|
||||||
|
|
||||||
// If include_parent: true and there is a valid m.relationship field in the event,
|
// If include_parent: true and there is a valid m.relationship field in the event,
|
||||||
// retrieve the referenced event. Apply history visibility check to that event and if it passes, add it to the response array.
|
// retrieve the referenced event. Apply history visibility check to that event and if it passes, add it to the response array.
|
||||||
func includeParent(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, event *gomatrixserverlib.HeaderedEvent, userID string) (parent *gomatrixserverlib.HeaderedEvent) {
|
func (rc *reqCtx) includeParent(event *gomatrixserverlib.HeaderedEvent) (parent *gomatrixserverlib.HeaderedEvent) {
|
||||||
parentID, _, _ := parentChildEventIDs(event)
|
parentID, _, _ := parentChildEventIDs(event)
|
||||||
if parentID == "" {
|
if parentID == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return getEventIfVisible(ctx, rsAPI, parentID, userID)
|
return rc.getEventIfVisible(parentID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If include_children: true, lookup all events which have event_id as an m.relationship
|
// If include_children: true, lookup all events which have event_id as an m.relationship
|
||||||
// Apply history visibility checks to all these events and add the ones which pass into the response array,
|
// Apply history visibility checks to all these events and add the ones which pass into the response array,
|
||||||
// honouring the recent_first flag and the limit.
|
// honouring the recent_first flag and the limit.
|
||||||
func includeChildren(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, db Database, parentID string, limit int, recentFirst bool, userID string) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) {
|
func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recentFirst bool) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) {
|
||||||
children, err := db.ChildrenForParent(ctx, parentID, constRelType, recentFirst)
|
children, err := db.ChildrenForParent(rc.ctx, parentID, constRelType, recentFirst)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Error("failed to get ChildrenForParent")
|
util.GetLogger(rc.ctx).WithError(err).Error("failed to get ChildrenForParent")
|
||||||
resErr := jsonerror.InternalServerError()
|
resErr := jsonerror.InternalServerError()
|
||||||
return nil, &resErr
|
return nil, &resErr
|
||||||
}
|
}
|
||||||
var childEvents []*gomatrixserverlib.HeaderedEvent
|
var childEvents []*gomatrixserverlib.HeaderedEvent
|
||||||
for _, child := range children {
|
for _, child := range children {
|
||||||
childEvent := getEventIfVisible(ctx, rsAPI, child.EventID, userID)
|
childEvent := rc.getEventIfVisible(child.EventID)
|
||||||
if childEvent != nil {
|
if childEvent != nil {
|
||||||
childEvents = append(childEvents, childEvent)
|
childEvents = append(childEvents, childEvent)
|
||||||
}
|
}
|
||||||
|
|
@ -211,16 +225,16 @@ func includeChildren(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI
|
||||||
// honouring the limit, max_depth and max_breadth values according to the following rules
|
// honouring the limit, max_depth and max_breadth values according to the following rules
|
||||||
// nolint: unparam
|
// nolint: unparam
|
||||||
func walkThread(
|
func walkThread(
|
||||||
ctx context.Context, db Database, rsAPI roomserver.RoomserverInternalAPI, userID string, req *EventRelationshipRequest, included map[string]bool, limit int,
|
ctx context.Context, db Database, rc *reqCtx, included map[string]bool, limit int,
|
||||||
) ([]*gomatrixserverlib.HeaderedEvent, bool) {
|
) ([]*gomatrixserverlib.HeaderedEvent, bool) {
|
||||||
if req.Direction != "down" {
|
if rc.req.Direction != "down" {
|
||||||
util.GetLogger(ctx).Error("not implemented: direction=up")
|
util.GetLogger(ctx).Error("not implemented: direction=up")
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
var result []*gomatrixserverlib.HeaderedEvent
|
var result []*gomatrixserverlib.HeaderedEvent
|
||||||
eventWalker := walker{
|
eventWalker := walker{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
req: req,
|
req: rc.req,
|
||||||
db: db,
|
db: db,
|
||||||
fn: func(wi *walkInfo) bool {
|
fn: func(wi *walkInfo) bool {
|
||||||
// If already processed event, skip.
|
// If already processed event, skip.
|
||||||
|
|
@ -234,7 +248,7 @@ func walkThread(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process the event.
|
// Process the event.
|
||||||
event := getEventIfVisible(ctx, rsAPI, wi.EventID, userID)
|
event := rc.getEventIfVisible(wi.EventID)
|
||||||
if event != nil {
|
if event != nil {
|
||||||
result = append(result, event)
|
result = append(result, event)
|
||||||
}
|
}
|
||||||
|
|
@ -242,24 +256,24 @@ func walkThread(
|
||||||
return false
|
return false
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
limited, err := eventWalker.WalkFrom(req.EventID)
|
limited, err := eventWalker.WalkFrom(rc.req.EventID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Errorf("Failed to WalkFrom %s", req.EventID)
|
util.GetLogger(ctx).WithError(err).Errorf("Failed to WalkFrom %s", rc.req.EventID)
|
||||||
}
|
}
|
||||||
return result, limited
|
return result, limited
|
||||||
}
|
}
|
||||||
|
|
||||||
func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, eventID, userID string) *gomatrixserverlib.HeaderedEvent {
|
func (rc *reqCtx) getEventIfVisible(eventID string) *gomatrixserverlib.HeaderedEvent {
|
||||||
var queryEventsRes roomserver.QueryEventsByIDResponse
|
var queryEventsRes roomserver.QueryEventsByIDResponse
|
||||||
err := rsAPI.QueryEventsByID(ctx, &roomserver.QueryEventsByIDRequest{
|
err := rc.rsAPI.QueryEventsByID(rc.ctx, &roomserver.QueryEventsByIDRequest{
|
||||||
EventIDs: []string{eventID},
|
EventIDs: []string{eventID},
|
||||||
}, &queryEventsRes)
|
}, &queryEventsRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryEventsByID")
|
util.GetLogger(rc.ctx).WithError(err).Error("getEventIfVisible: failed to QueryEventsByID")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if len(queryEventsRes.Events) == 0 {
|
if len(queryEventsRes.Events) == 0 {
|
||||||
util.GetLogger(ctx).Infof("event does not exist")
|
util.GetLogger(rc.ctx).Infof("event does not exist")
|
||||||
return nil // event does not exist
|
return nil // event does not exist
|
||||||
}
|
}
|
||||||
event := queryEventsRes.Events[0]
|
event := queryEventsRes.Events[0]
|
||||||
|
|
@ -268,20 +282,34 @@ func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalA
|
||||||
// TODO: This does not honour history_visibility
|
// TODO: This does not honour history_visibility
|
||||||
// TODO: This does not honour m.room.create content
|
// TODO: This does not honour m.room.create content
|
||||||
var queryMembershipRes roomserver.QueryMembershipForUserResponse
|
var queryMembershipRes roomserver.QueryMembershipForUserResponse
|
||||||
err = rsAPI.QueryMembershipForUser(ctx, &roomserver.QueryMembershipForUserRequest{
|
err = rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{
|
||||||
RoomID: event.RoomID(),
|
RoomID: event.RoomID(),
|
||||||
UserID: userID,
|
UserID: rc.userID,
|
||||||
}, &queryMembershipRes)
|
}, &queryMembershipRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryMembershipForUser")
|
util.GetLogger(rc.ctx).WithError(err).Error("getEventIfVisible: failed to QueryMembershipForUser")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if !queryMembershipRes.IsInRoom {
|
if queryMembershipRes.IsInRoom {
|
||||||
util.GetLogger(ctx).Infof("user not in room")
|
return &event
|
||||||
|
}
|
||||||
|
if rc.req.AutoJoin {
|
||||||
|
var joinRes roomserver.PerformJoinResponse
|
||||||
|
rc.rsAPI.PerformJoin(rc.ctx, &roomserver.PerformJoinRequest{
|
||||||
|
UserID: rc.userID,
|
||||||
|
Content: map[string]interface{}{},
|
||||||
|
RoomIDOrAlias: event.RoomID(),
|
||||||
|
// TODO: Add server_names from linked room, currently this join will only work if the HS is already in the room
|
||||||
|
}, &joinRes)
|
||||||
|
if joinRes.Error != nil {
|
||||||
|
util.GetLogger(rc.ctx).WithError(joinRes.Error).WithField("room_id", event.RoomID()).Error("Failed to auto-join room")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &event
|
return &event
|
||||||
}
|
}
|
||||||
|
util.GetLogger(rc.ctx).Infof("user not in room and auto_join disabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type walkInfo struct {
|
type walkInfo struct {
|
||||||
eventInfo
|
eventInfo
|
||||||
|
|
|
||||||
|
|
@ -353,6 +353,10 @@ func TestMSC2836(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: TestMSC2836TerminatesLoops (short and long)
|
||||||
|
// TODO: TestMSC2836UnknownEventsSkipped
|
||||||
|
// TODO: TestMSC2836SkipEventIfNotInRoom
|
||||||
|
|
||||||
func runServer(t *testing.T, router *mux.Router) func() {
|
func runServer(t *testing.T, router *mux.Router) func() {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
externalServ := &http.Server{
|
externalServ := &http.Server{
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue