diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index 7b59e3704..9d723bed1 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -88,7 +88,16 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) } events := []*gomatrixserverlib.HeaderedEvent{output.NewRoomEvent.Event} - events = append(events, output.NewRoomEvent.AddStateEvents...) + if len(output.NewRoomEvent.AddsStateEventIDs) > 0 { + eventsReq := &api.QueryEventsByIDRequest{ + EventIDs: output.NewRoomEvent.AddsStateEventIDs, + } + eventsRes := &api.QueryEventsByIDResponse{} + if err := s.rsAPI.QueryEventsByID(s.ctx, eventsReq, eventsRes); err != nil { + return false + } + events = append(events, eventsRes.Events...) + } // Send event to any relevant application services if err := s.filterRoomserverEvents(context.TODO(), events); err != nil { diff --git a/clientapi/jsonerror/jsonerror.go b/clientapi/jsonerror/jsonerror.go index 1965fe3f8..1c2278392 100644 --- a/clientapi/jsonerror/jsonerror.go +++ b/clientapi/jsonerror/jsonerror.go @@ -65,6 +65,11 @@ func BadJSON(msg string) *MatrixError { return &MatrixError{"M_BAD_JSON", msg} } +// BadAlias is an error when the client supplies a bad alias. +func BadAlias(msg string) *MatrixError { + return &MatrixError{"M_BAD_ALIAS", msg} +} + // NotJSON is an error when the client supplies something that is not JSON // to a JSON endpoint. func NotJSON(msg string) *MatrixError { diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index e408c264f..ac355b5d4 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -139,11 +139,17 @@ func SetLocalAlias( // TODO: This code should eventually be refactored with: // 1. The new method for checking for things matching an AS's namespace // 2. Using an overall Regex object for all AS's just like we did for usernames - + reqUserID, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("User ID must be in the form '@localpart:domain'"), + } + } for _, appservice := range cfg.Derived.ApplicationServices { // Don't prevent AS from creating aliases in its own namespace // Note that Dendrite uses SenderLocalpart as UserID for AS users - if device.UserID != appservice.SenderLocalpart { + if reqUserID != appservice.SenderLocalpart { if aliasNamespaces, ok := appservice.NamespaceMap["aliases"]; ok { for _, namespace := range aliasNamespaces { if namespace.Exclusive && namespace.RegexpObject.MatchString(alias) { diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 23935b5d9..3d5993718 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -16,6 +16,8 @@ package routing import ( "context" + "encoding/json" + "fmt" "net/http" "sync" "time" @@ -120,6 +122,40 @@ func SendEvent( } timeToGenerateEvent := time.Since(startedGeneratingEvent) + // validate that the aliases exists + if eventType == gomatrixserverlib.MRoomCanonicalAlias && stateKey != nil && *stateKey == "" { + aliasReq := api.AliasEvent{} + if err = json.Unmarshal(e.Content(), &aliasReq); err != nil { + return util.ErrorResponse(fmt.Errorf("unable to parse alias event: %w", err)) + } + if !aliasReq.Valid() { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidParam("Request contains invalid aliases."), + } + } + aliasRes := &api.GetAliasesForRoomIDResponse{} + if err = rsAPI.GetAliasesForRoomID(req.Context(), &api.GetAliasesForRoomIDRequest{RoomID: roomID}, aliasRes); err != nil { + return jsonerror.InternalServerError() + } + var found int + requestAliases := append(aliasReq.AltAliases, aliasReq.Alias) + for _, alias := range aliasRes.Aliases { + for _, altAlias := range requestAliases { + if altAlias == alias { + found++ + } + } + } + // check that we found at least the same amount of existing aliases as are in the request + if aliasReq.Alias != "" && found < len(requestAliases) { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadAlias("No matching alias found."), + } + } + } + var txnAndSessionID *api.TransactionID if txnID != nil { txnAndSessionID = &api.TransactionID{ diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 173dcff01..989f7cf49 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -146,7 +146,28 @@ func (s *OutputRoomEventConsumer) processInboundPeek(orp api.OutputNewInboundPee // processMessage updates the list of currently joined hosts in the room // and then sends the event to the hosts that were joined before the event. func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) error { - addsJoinedHosts, err := joinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(ore.AddsState())) + eventsRes := &api.QueryEventsByIDResponse{} + if len(ore.AddsStateEventIDs) > 0 { + eventsReq := &api.QueryEventsByIDRequest{ + EventIDs: ore.AddsStateEventIDs, + } + if err := s.rsAPI.QueryEventsByID(s.ctx, eventsReq, eventsRes); err != nil { + return fmt.Errorf("s.rsAPI.QueryEventsByID: %w", err) + } + + found := false + for _, event := range eventsRes.Events { + if event.EventID() == ore.Event.EventID() { + found = true + break + } + } + if !found { + eventsRes.Events = append(eventsRes.Events, ore.Event) + } + } + + addsJoinedHosts, err := joinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(eventsRes.Events)) if err != nil { return err } diff --git a/roomserver/api/alias.go b/roomserver/api/alias.go index df69e5b4d..be37333b6 100644 --- a/roomserver/api/alias.go +++ b/roomserver/api/alias.go @@ -14,6 +14,8 @@ package api +import "regexp" + // SetRoomAliasRequest is a request to SetRoomAlias type SetRoomAliasRequest struct { // ID of the user setting the alias @@ -84,3 +86,20 @@ type RemoveRoomAliasResponse struct { // Did we remove it? Removed bool `json:"removed"` } + +type AliasEvent struct { + Alias string `json:"alias"` + AltAliases []string `json:"alt_aliases"` +} + +var validateAliasRegex = regexp.MustCompile("^#.*:.+$") + +func (a AliasEvent) Valid() bool { + for _, alias := range a.AltAliases { + if !validateAliasRegex.MatchString(alias) { + return false + } + } + return a.Alias == "" || validateAliasRegex.MatchString(a.Alias) +} + diff --git a/roomserver/api/alias_test.go b/roomserver/api/alias_test.go new file mode 100644 index 000000000..680493b7b --- /dev/null +++ b/roomserver/api/alias_test.go @@ -0,0 +1,62 @@ +package api + +import "testing" + +func TestAliasEvent_Valid(t *testing.T) { + type fields struct { + Alias string + AltAliases []string + } + tests := []struct { + name string + fields fields + want bool + }{ + { + name: "empty alias", + fields: fields{ + Alias: "", + }, + want: true, + }, + { + name: "empty alias, invalid alt aliases", + fields: fields{ + Alias: "", + AltAliases: []string{ "%not:valid.local"}, + }, + }, + { + name: "valid alias, invalid alt aliases", + fields: fields{ + Alias: "#valid:test.local", + AltAliases: []string{ "%not:valid.local"}, + }, + }, + { + name: "empty alias, invalid alt aliases", + fields: fields{ + Alias: "", + AltAliases: []string{ "%not:valid.local"}, + }, + }, + { + name: "invalid alias", + fields: fields{ + Alias: "%not:valid.local", + AltAliases: []string{ }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := AliasEvent{ + Alias: tt.fields.Alias, + AltAliases: tt.fields.AltAliases, + } + if got := a.Valid(); got != tt.want { + t.Errorf("Valid() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/roomserver/api/output.go b/roomserver/api/output.go index d60d1cc86..767611ec4 100644 --- a/roomserver/api/output.go +++ b/roomserver/api/output.go @@ -105,7 +105,7 @@ type OutputNewRoomEvent struct { Event *gomatrixserverlib.HeaderedEvent `json:"event"` // Does the event completely rewrite the room state? If so, then AddsStateEventIDs // will contain the entire room state. - RewritesState bool `json:"rewrites_state"` + RewritesState bool `json:"rewrites_state,omitempty"` // The latest events in the room after this event. // This can be used to set the prev events for new events in the room. // This also can be used to get the full current state after this event. @@ -113,16 +113,9 @@ type OutputNewRoomEvent struct { // The state event IDs that were added to the state of the room by this event. // Together with RemovesStateEventIDs this allows the receiver to keep an up to date // view of the current state of the room. - AddsStateEventIDs []string `json:"adds_state_event_ids"` - // All extra newly added state events. This is only set if there are *extra* events - // other than `Event`. This can happen when forks get merged because state resolution - // may decide a bunch of state events on one branch are now valid, so they will be - // present in this list. This is useful when trying to maintain the current state of a room - // as to do so you need to include both these events and `Event`. - AddStateEvents []*gomatrixserverlib.HeaderedEvent `json:"adds_state_events"` - + AddsStateEventIDs []string `json:"adds_state_event_ids,omitempty"` // The state event IDs that were removed from the state of the room by this event. - RemovesStateEventIDs []string `json:"removes_state_event_ids"` + RemovesStateEventIDs []string `json:"removes_state_event_ids,omitempty"` // The ID of the event that was output before this event. // Or the empty string if this is the first event output for this room. // This is used by consumers to check if they can safely update their @@ -145,10 +138,10 @@ type OutputNewRoomEvent struct { // // The state is given as a delta against the current state because they are // usually either the same state, or differ by just a couple of events. - StateBeforeAddsEventIDs []string `json:"state_before_adds_event_ids"` + StateBeforeAddsEventIDs []string `json:"state_before_adds_event_ids,omitempty"` // The state event IDs that are part of the current state, but not part // of the state at the event. - StateBeforeRemovesEventIDs []string `json:"state_before_removes_event_ids"` + StateBeforeRemovesEventIDs []string `json:"state_before_removes_event_ids,omitempty"` // The server name to use to push this event to other servers. // Or empty if this event shouldn't be pushed to other servers. // @@ -167,27 +160,7 @@ type OutputNewRoomEvent struct { SendAsServer string `json:"send_as_server"` // The transaction ID of the send request if sent by a local user and one // was specified - TransactionID *TransactionID `json:"transaction_id"` -} - -// AddsState returns all added state events from this event. -// -// This function is needed because `AddStateEvents` will not include a copy of -// the original event to save space, so you cannot use that slice alone. -// Instead, use this function which will add the original event if it is present -// in `AddsStateEventIDs`. -func (ore *OutputNewRoomEvent) AddsState() []*gomatrixserverlib.HeaderedEvent { - includeOutputEvent := false - for _, id := range ore.AddsStateEventIDs { - if id == ore.Event.EventID() { - includeOutputEvent = true - break - } - } - if !includeOutputEvent { - return ore.AddStateEvents - } - return append(ore.AddStateEvents, ore.Event) + TransactionID *TransactionID `json:"transaction_id,omitempty"` } // An OutputOldRoomEvent is written when the roomserver receives an old event. diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index 7995279d2..5c1c04f01 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -16,12 +16,18 @@ package internal import ( "context" + "database/sql" + "errors" "fmt" - - "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/gomatrixserverlib" + "time" asAPI "github.com/matrix-org/dendrite/appservice/api" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" + "github.com/matrix-org/gomatrixserverlib" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) // RoomserverInternalAPIDatabase has the storage APIs needed to implement the alias API. @@ -183,6 +189,57 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( } } + ev, err := r.DB.GetStateEvent(ctx, roomID, gomatrixserverlib.MRoomCanonicalAlias, "") + if err != nil && err != sql.ErrNoRows { + return err + } else if ev != nil { + stateAlias := gjson.GetBytes(ev.Content(), "alias").Str + // the alias to remove is currently set as the canonical alias, remove it + if stateAlias == request.Alias { + res, err := sjson.DeleteBytes(ev.Content(), "alias") + if err != nil { + return err + } + + sender := request.UserID + if request.UserID != ev.Sender() { + sender = ev.Sender() + } + + builder := &gomatrixserverlib.EventBuilder{ + Sender: sender, + RoomID: ev.RoomID(), + Type: ev.Type(), + StateKey: ev.StateKey(), + Content: res, + } + + eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) + if err != nil { + return fmt.Errorf("gomatrixserverlib.StateNeededForEventBuilder: %w", err) + } + if len(eventsNeeded.Tuples()) == 0 { + return errors.New("expecting state tuples for event builder, got none") + } + + stateRes := &api.QueryLatestEventsAndStateResponse{} + if err := helpers.QueryLatestEventsAndState(ctx, r.DB, &api.QueryLatestEventsAndStateRequest{RoomID: roomID, StateToFetch: eventsNeeded.Tuples()}, stateRes); err != nil { + return err + } + + newEvent, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, time.Now(), &eventsNeeded, stateRes) + if err != nil { + return err + } + + err = api.SendEvents(ctx, r.RSAPI, api.KindNew, []*gomatrixserverlib.HeaderedEvent{newEvent}, r.ServerName, r.ServerName, nil, false) + if err != nil { + return err + } + + } + } + // Remove the alias from the database if err := r.DB.RemoveRoomAlias(ctx, request.Alias); err != nil { return err diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index f4a52031a..7e58ef9d0 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -365,6 +365,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) LastSentEventID: u.lastEventIDSent, LatestEventIDs: latestEventIDs, TransactionID: u.transactionID, + SendAsServer: u.sendAsServer, } eventIDMap, err := u.stateEventMap() @@ -384,51 +385,17 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) ore.StateBeforeAddsEventIDs = append(ore.StateBeforeAddsEventIDs, eventIDMap[entry.EventNID]) } - ore.SendAsServer = u.sendAsServer - - // include extra state events if they were added as nearly every downstream component will care about it - // and we'd rather not have them all hit QueryEventsByID at the same time! - if len(ore.AddsStateEventIDs) > 0 { - var err error - if ore.AddStateEvents, err = u.extraEventsForIDs(u.roomInfo.RoomVersion, ore.AddsStateEventIDs); err != nil { - return nil, fmt.Errorf("failed to load add_state_events from db: %w", err) - } - } - return &api.OutputEvent{ Type: api.OutputTypeNewRoomEvent, NewRoomEvent: &ore, }, nil } -// extraEventsForIDs returns the full events for the event IDs given, but does not include the current event being -// updated. -func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.RoomVersion, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { - var extraEventIDs []string - for _, e := range eventIDs { - if e == u.event.EventID() { - continue - } - extraEventIDs = append(extraEventIDs, e) - } - if len(extraEventIDs) == 0 { - return nil, nil - } - extraEvents, err := u.updater.UnsentEventsFromIDs(u.ctx, extraEventIDs) - if err != nil { - return nil, err - } - var h []*gomatrixserverlib.HeaderedEvent - for _, e := range extraEvents { - h = append(h, e.Headered(roomVersion)) - } - return h, nil -} - // retrieve an event nid -> event ID map for all events that need updating func (u *latestEventsUpdater) stateEventMap() (map[types.EventNID]string, error) { - var stateEventNIDs []types.EventNID - var allStateEntries []types.StateEntry + cap := len(u.added) + len(u.removed) + len(u.stateBeforeEventRemoves) + len(u.stateBeforeEventAdds) + stateEventNIDs := make(types.EventNIDs, 0, cap) + allStateEntries := make([]types.StateEntry, 0, cap) allStateEntries = append(allStateEntries, u.added...) allStateEntries = append(allStateEntries, u.removed...) allStateEntries = append(allStateEntries, u.stateBeforeEventRemoves...) @@ -436,12 +403,6 @@ func (u *latestEventsUpdater) stateEventMap() (map[types.EventNID]string, error) for _, entry := range allStateEntries { stateEventNIDs = append(stateEventNIDs, entry.EventNID) } - stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))] + stateEventNIDs = stateEventNIDs[:util.SortAndUnique(stateEventNIDs)] return u.updater.EventIDs(u.ctx, stateEventNIDs) } - -type eventNIDSorter []types.EventNID - -func (s eventNIDSorter) Len() int { return len(s) } -func (s eventNIDSorter) Less(i, j int) bool { return s[i] < s[j] } -func (s eventNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/setup/jetstream/helpers.go b/setup/jetstream/helpers.go index 544b5f0c3..d444272d2 100644 --- a/setup/jetstream/helpers.go +++ b/setup/jetstream/helpers.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "github.com/getsentry/sentry-go" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" ) @@ -29,6 +30,7 @@ func JetStreamConsumer( name := durable + "Pull" sub, err := js.PullSubscribe(subj, name, opts...) if err != nil { + sentry.CaptureException(err) return fmt.Errorf("nats.SubscribeSync: %w", err) } go func() { @@ -55,6 +57,7 @@ func JetStreamConsumer( } } else { // Something else went wrong, so we'll panic. + sentry.CaptureException(err) logrus.WithContext(ctx).WithField("subject", subj).Fatal(err) } } @@ -64,15 +67,18 @@ func JetStreamConsumer( msg := msgs[0] if err = msg.InProgress(); err != nil { logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err)) + sentry.CaptureException(err) continue } if f(ctx, msg) { if err = msg.Ack(); err != nil { logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Ack: %w", err)) + sentry.CaptureException(err) } } else { if err = msg.Nak(); err != nil { logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err)) + sentry.CaptureException(err) } } } diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index 7ab50c32e..7fb043366 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -199,13 +199,14 @@ func (w *walker) storePaginationCache(paginationToken string, cache paginationIn } type roomVisit struct { - roomID string - depth int - vias []string // vias to query this room by + roomID string + parentRoomID string + depth int + vias []string // vias to query this room by } func (w *walker) walk() util.JSONResponse { - if !w.authorised(w.rootRoomID) { + if authorised, _ := w.authorised(w.rootRoomID, ""); !authorised { if w.caller != nil { // CS API format return util.JSONResponse{ @@ -238,8 +239,9 @@ func (w *walker) walk() util.JSONResponse { w.paginationToken = tok // Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms c.unvisited = append(c.unvisited, roomVisit{ - roomID: w.rootRoomID, - depth: 0, + roomID: w.rootRoomID, + parentRoomID: "", + depth: 0, }) } @@ -277,23 +279,8 @@ func (w *walker) walk() util.JSONResponse { // If we know about this room and the caller is authorised (joined/world_readable) then pull // events locally - if w.roomExists(rv.roomID) && w.authorised(rv.roomID) { - // Get all `m.space.child` state events for this room - events, err := w.childReferences(rv.roomID) - if err != nil { - util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Error("failed to extract references for room") - continue - } - discoveredChildEvents = events - - pubRoom := w.publicRoomsChunk(rv.roomID) - - discoveredRooms = append(discoveredRooms, gomatrixserverlib.MSC2946Room{ - PublicRoom: *pubRoom, - RoomType: roomType, - ChildrenState: events, - }) - } else { + roomExists := w.roomExists(rv.roomID) + if !roomExists { // attempt to query this room over federation, as either we've never heard of it before // or we've left it and hence are not authorised (but info may be exposed regardless) fedRes, err := w.federatedRoomInfo(rv.roomID, rv.vias) @@ -312,6 +299,29 @@ func (w *walker) walk() util.JSONResponse { // as these children may be rooms we do know about. roomType = ConstCreateEventContentValueSpace } + } else if authorised, isJoinedOrInvited := w.authorised(rv.roomID, rv.parentRoomID); authorised { + // Get all `m.space.child` state events for this room + events, err := w.childReferences(rv.roomID) + if err != nil { + util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Error("failed to extract references for room") + continue + } + discoveredChildEvents = events + + pubRoom := w.publicRoomsChunk(rv.roomID) + + discoveredRooms = append(discoveredRooms, gomatrixserverlib.MSC2946Room{ + PublicRoom: *pubRoom, + RoomType: roomType, + ChildrenState: events, + }) + // don't walk children if the user is not joined/invited to the space + if !isJoinedOrInvited { + continue + } + } else { + // room exists but user is not authorised + continue } // don't walk the children @@ -332,9 +342,10 @@ func (w *walker) walk() util.JSONResponse { ev := discoveredChildEvents[i] _ = json.Unmarshal(ev.Content, &spaceContent) unvisited = append(unvisited, roomVisit{ - roomID: ev.StateKey, - depth: rv.depth + 1, - vias: spaceContent.Via, + roomID: ev.StateKey, + parentRoomID: rv.roomID, + depth: rv.depth + 1, + vias: spaceContent.Via, }) } } @@ -465,25 +476,29 @@ func (w *walker) roomExists(roomID string) bool { } // authorised returns true iff the user is joined this room or the room is world_readable -func (w *walker) authorised(roomID string) bool { +func (w *walker) authorised(roomID, parentRoomID string) (authed, isJoinedOrInvited bool) { if w.caller != nil { - return w.authorisedUser(roomID) + return w.authorisedUser(roomID, parentRoomID) } - return w.authorisedServer(roomID) + return w.authorisedServer(roomID), false } // authorisedServer returns true iff the server is joined this room or the room is world_readable func (w *walker) authorisedServer(roomID string) bool { - // Check history visibility first + // Check history visibility / join rules first hisVisTuple := gomatrixserverlib.StateKeyTuple{ EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: "", } + joinRuleTuple := gomatrixserverlib.StateKeyTuple{ + EventType: gomatrixserverlib.MRoomJoinRules, + StateKey: "", + } var queryRoomRes roomserver.QueryCurrentStateResponse err := w.rsAPI.QueryCurrentState(w.ctx, &roomserver.QueryCurrentStateRequest{ RoomID: roomID, StateTuples: []gomatrixserverlib.StateKeyTuple{ - hisVisTuple, + hisVisTuple, joinRuleTuple, }, }, &queryRoomRes) if err != nil { @@ -497,29 +512,46 @@ func (w *walker) authorisedServer(roomID string) bool { return true } } - // check if server is joined to the room - var queryRes fs.QueryJoinedHostServerNamesInRoomResponse - err = w.fsAPI.QueryJoinedHostServerNamesInRoom(w.ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{ - RoomID: roomID, - }, &queryRes) - if err != nil { - util.GetLogger(w.ctx).WithError(err).Error("failed to QueryJoinedHostServerNamesInRoom") - return false + + // check if this room is a restricted room and if so, we need to check if the server is joined to an allowed room ID + // in addition to the actual room ID (but always do the actual one first as it's quicker in the common case) + allowJoinedToRoomIDs := []string{roomID} + joinRuleEv := queryRoomRes.StateEvents[joinRuleTuple] + if joinRuleEv != nil { + allowJoinedToRoomIDs = append(allowJoinedToRoomIDs, w.restrictedJoinRuleAllowedRooms(joinRuleEv, "m.room_membership")...) } - for _, srv := range queryRes.ServerNames { - if srv == w.serverName { - return true + + // check if server is joined to any allowed room + for _, allowedRoomID := range allowJoinedToRoomIDs { + var queryRes fs.QueryJoinedHostServerNamesInRoomResponse + err = w.fsAPI.QueryJoinedHostServerNamesInRoom(w.ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{ + RoomID: allowedRoomID, + }, &queryRes) + if err != nil { + util.GetLogger(w.ctx).WithError(err).Error("failed to QueryJoinedHostServerNamesInRoom") + continue + } + for _, srv := range queryRes.ServerNames { + if srv == w.serverName { + return true + } } } + return false } -// authorisedUser returns true iff the user is joined this room or the room is world_readable -func (w *walker) authorisedUser(roomID string) bool { +// authorisedUser returns true iff the user is invited/joined this room or the room is world_readable. +// Failing that, if the room has a restricted join rule and belongs to the space parent listed, it will return true. +func (w *walker) authorisedUser(roomID, parentRoomID string) (authed bool, isJoinedOrInvited bool) { hisVisTuple := gomatrixserverlib.StateKeyTuple{ EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: "", } + joinRuleTuple := gomatrixserverlib.StateKeyTuple{ + EventType: gomatrixserverlib.MRoomJoinRules, + StateKey: "", + } roomMemberTuple := gomatrixserverlib.StateKeyTuple{ EventType: gomatrixserverlib.MRoomMember, StateKey: w.caller.UserID, @@ -528,28 +560,79 @@ func (w *walker) authorisedUser(roomID string) bool { err := w.rsAPI.QueryCurrentState(w.ctx, &roomserver.QueryCurrentStateRequest{ RoomID: roomID, StateTuples: []gomatrixserverlib.StateKeyTuple{ - hisVisTuple, roomMemberTuple, + hisVisTuple, joinRuleTuple, roomMemberTuple, }, }, &queryRes) if err != nil { util.GetLogger(w.ctx).WithError(err).Error("failed to QueryCurrentState") - return false + return false, false } memberEv := queryRes.StateEvents[roomMemberTuple] - hisVisEv := queryRes.StateEvents[hisVisTuple] if memberEv != nil { membership, _ := memberEv.Membership() if membership == gomatrixserverlib.Join || membership == gomatrixserverlib.Invite { - return true + return true, true } } + hisVisEv := queryRes.StateEvents[hisVisTuple] if hisVisEv != nil { hisVis, _ := hisVisEv.HistoryVisibility() if hisVis == "world_readable" { - return true + return true, false } } - return false + joinRuleEv := queryRes.StateEvents[joinRuleTuple] + if parentRoomID != "" && joinRuleEv != nil { + allowedRoomIDs := w.restrictedJoinRuleAllowedRooms(joinRuleEv, "m.room_membership") + // check parent is in the allowed set + var allowed bool + for _, a := range allowedRoomIDs { + if parentRoomID == a { + allowed = true + break + } + } + if allowed { + // ensure caller is joined to the parent room + var queryRes2 roomserver.QueryCurrentStateResponse + err = w.rsAPI.QueryCurrentState(w.ctx, &roomserver.QueryCurrentStateRequest{ + RoomID: parentRoomID, + StateTuples: []gomatrixserverlib.StateKeyTuple{ + roomMemberTuple, + }, + }, &queryRes2) + if err != nil { + util.GetLogger(w.ctx).WithError(err).WithField("parent_room_id", parentRoomID).Warn("failed to check user is joined to parent room") + } else { + memberEv = queryRes2.StateEvents[roomMemberTuple] + if memberEv != nil { + membership, _ := memberEv.Membership() + if membership == gomatrixserverlib.Join { + return true, false + } + } + } + } + } + return false, false +} + +func (w *walker) restrictedJoinRuleAllowedRooms(joinRuleEv *gomatrixserverlib.HeaderedEvent, allowType string) (allows []string) { + rule, _ := joinRuleEv.JoinRule() + if rule != "restricted" { + return nil + } + var jrContent gomatrixserverlib.JoinRuleContent + if err := json.Unmarshal(joinRuleEv.Content(), &jrContent); err != nil { + util.GetLogger(w.ctx).Warnf("failed to check join_rule on room %s: %s", joinRuleEv.RoomID(), err) + return nil + } + for _, allow := range jrContent.Allow { + if allow.Type == allowType { + allows = append(allows, allow.RoomID) + } + } + return } // references returns all child references pointing to or from this room. diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 159657f9f..640c505c2 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -154,7 +154,42 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( ctx context.Context, msg api.OutputNewRoomEvent, ) error { ev := msg.Event - addsStateEvents := msg.AddsState() + + addsStateEvents := []*gomatrixserverlib.HeaderedEvent{} + foundEventIDs := map[string]bool{} + if len(msg.AddsStateEventIDs) > 0 { + for _, eventID := range msg.AddsStateEventIDs { + foundEventIDs[eventID] = false + } + foundEvents, err := s.db.Events(ctx, msg.AddsStateEventIDs) + if err != nil { + return fmt.Errorf("s.db.Events: %w", err) + } + for _, event := range foundEvents { + foundEventIDs[event.EventID()] = true + } + eventsReq := &api.QueryEventsByIDRequest{} + eventsRes := &api.QueryEventsByIDResponse{} + for eventID, found := range foundEventIDs { + if !found { + eventsReq.EventIDs = append(eventsReq.EventIDs, eventID) + } + } + if err = s.rsAPI.QueryEventsByID(ctx, eventsReq, eventsRes); err != nil { + return fmt.Errorf("s.rsAPI.QueryEventsByID: %w", err) + } + for _, event := range eventsRes.Events { + eventID := event.EventID() + foundEvents = append(foundEvents, event) + foundEventIDs[eventID] = true + } + for eventID, found := range foundEventIDs { + if !found { + return fmt.Errorf("event %s is missing", eventID) + } + } + addsStateEvents = foundEvents + } ev, err := s.updateStateEvent(ev) if err != nil { diff --git a/sytest-blacklist b/sytest-blacklist index becc500ec..cee2406e5 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -24,6 +24,7 @@ Local device key changes get to remote servers with correct prev_id # Flakey Local device key changes appear in /keys/changes +/context/ with lazy_load_members filter works # we don't support groups Remove group category @@ -31,8 +32,11 @@ Remove group role # Flakey AS-ghosted users can use rooms themselves +/context/ with lazy_load_members filter works +AS-ghosted users can use rooms via AS +Events in rooms with AS-hosted room aliases are sent to AS server # Flakey, need additional investigation Messages that notify from another user increment notification_count Messages that highlight from another user increment unread highlight count -Notifications can be viewed with GET /notifications \ No newline at end of file +Notifications can be viewed with GET /notifications diff --git a/sytest-whitelist b/sytest-whitelist index 63d779bf1..6c4745b32 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -648,7 +648,6 @@ Device list doesn't change if remote server is down /context/ on joined room works /context/ on non world readable room does not work /context/ returns correct number of events -/context/ with lazy_load_members filter works GET /rooms/:room_id/messages lazy loads members correctly Can query remote device keys using POST after notification Device deletion propagates over federation @@ -659,4 +658,9 @@ registration accepts non-ascii passwords registration with inhibit_login inhibits login The operation must be consistent through an interactive authentication session Multiple calls to /sync should not cause 500 errors -/context/ with lazy_load_members filter works +Canonical alias can be set +Canonical alias can include alt_aliases +Can delete canonical alias +Multiple calls to /sync should not cause 500 errors +AS can make room aliases +Accesing an AS-hosted room alias asks the AS server