diff --git a/clientapi/routing/memberships.go b/clientapi/routing/memberships.go index 129c28cfd..9647f7e83 100644 --- a/clientapi/routing/memberships.go +++ b/clientapi/routing/memberships.go @@ -17,6 +17,7 @@ package routing import ( "encoding/json" "net/http" + "regexp" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" @@ -68,30 +69,45 @@ func GetMemberships( _ *config.ClientAPI, rsAPI api.RoomserverInternalAPI, ) util.JSONResponse { - // how would i unpack this nicely into - // type queryParams struct { - // membership string - // notMembership string - // } membership := req.URL.Query().Get("membership") notMembership := req.URL.Query().Get("not_membership") - membershipStatusFilter := []string{"join", "invite", "leave", "ban"} - if len(notMembership) > 0 { - for i, v := range membershipStatusFilter { - if v == notMembership { - membershipStatusFilter = append(membershipStatusFilter[:i], membershipStatusFilter[i+1:]...) + + regexpMembershipFilter, _ := regexp.Compile("join|invite|leave|ban") + if membership != "" && !regexpMembershipFilter.MatchString(membership) { + return util.JSONResponse{ + Code: http.StatusBadRequest, + } + } + if notMembership != "" && !regexpMembershipFilter.MatchString(notMembership) { + return util.JSONResponse{ + Code: http.StatusBadRequest, + } + } + + membershipFilter := []string{"join", "invite", "leave", "ban"} + + if notMembership != "" { + if membership != "" && membership != notMembership { + for idx, val := range membershipFilter { + if val == notMembership { + membershipFilter = append(membershipFilter[:idx], membershipFilter[idx+1:]...) + break + } } } - } else if len(membership) > 0 { - membershipStatusFilter = []string{membership} + // If membership and not_membership are both specified and they are the same, + // then we do no filtering at all because they create an OR conditional } else { - membershipStatusFilter = []string{} + if membership != "" { + membershipFilter = []string{membership} + } } + queryReq := api.QueryMembershipsForRoomRequest{ - JoinedOnly: joinedOnly, - MembershipStatusFilter: membershipStatusFilter, - RoomID: roomID, - Sender: device.UserID, + JoinedOnly: joinedOnly, + MembershipFilter: membershipFilter, + RoomID: roomID, + Sender: device.UserID, } var queryRes api.QueryMembershipsForRoomResponse if err := rsAPI.QueryMembershipsForRoom(req.Context(), &queryReq, &queryRes); err != nil { diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 95df541e1..378b7c9f1 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -148,8 +148,10 @@ type QueryMembershipForUserResponse struct { // QueryMembershipsForRoomRequest is a request to QueryMembershipsForRoom type QueryMembershipsForRoomRequest struct { // If true, only returns the membership events of "join" membership - JoinedOnly bool `json:"joined_only"` - MembershipStatusFilter []string `json:"membership_status_filter"` + JoinedOnly bool `json:"joined_only"` + // The kinds of membership to filter for - returns the membership + // events with the appropriate filter (ie. join, invite, leave, ban) + MembershipFilter []string `json:"membership_filter"` // ID of the room to fetch memberships from RoomID string `json:"room_id"` // Optional - ID of the user sending the request, for checking if the diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index a01bb44b7..72879eade 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -59,7 +59,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam return false, fmt.Errorf("unknown room %s", roomID) } - eventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, []string{"join"}, false) + eventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, []string{}, false) if err != nil { return false, err } diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index b7116272a..297fc8681 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -256,7 +256,7 @@ func (r *Inputer) calculateAndSetState( stateAtEvent.Overwrite = true var joinEventNIDs []types.EventNID // Request join memberships only for local users only. - if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, []string{"join"}, true); err == nil { + if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, []string{}, true); err == nil { // If we have no local users that are joined to the room then any state about // the room that we have is quite possibly out of date. Therefore in that case // we should overwrite it rather than merge it. diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 33ec3c48e..9fc67c0f3 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -537,7 +537,7 @@ func joinEventsFromHistoryVisibility( if err != nil { return nil, err } - joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, []string{"join"}, false) + joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, []string{}, false) if err != nil { return nil, err } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index de59e6411..6e8479ff3 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -251,7 +251,7 @@ func (r *Queryer) QueryMembershipsForRoom( if request.Sender == "" { var events []types.Event var eventNIDs []types.EventNID - eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.MembershipStatusFilter, false) + eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, request.MembershipFilter, false) if err != nil { return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err) } @@ -286,7 +286,7 @@ func (r *Queryer) QueryMembershipsForRoom( var stateEntries []types.StateEntry if stillInRoom { var eventNIDs []types.EventNID - eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.MembershipStatusFilter, false) + eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, request.MembershipFilter, false) if err != nil { return err } @@ -328,7 +328,7 @@ func (r *Queryer) QueryServerJoinedToRoom( } response.RoomExists = true - eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, []string{"join"}, false) + eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, []string{}, false) if err != nil { return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err) } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 10da796c6..d38905ab3 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -131,7 +131,7 @@ type Database interface { // been members of a given room. Only lookup events of "join" membership if // joinOnly is set to true. // Returns an error if there was a problem talking to the database. - GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, membershipStatusFilter []string, localOnly bool) ([]types.EventNID, error) + GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, membershipFilter []string, localOnly bool) ([]types.EventNID, error) // EventsFromIDs looks up the Events for a list of event IDs. Does not error if event was // not found. // Returns an error if the retrieval went wrong. diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index ab42a2457..dbf81701d 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -343,13 +343,20 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req } func (d *Database) GetMembershipEventNIDsForRoom( - ctx context.Context, roomNID types.RoomNID, membershipStatusFilter []string, localOnly bool, + ctx context.Context, roomNID types.RoomNID, joinOnly bool, membershipFilter []string, localOnly bool, ) ([]types.EventNID, error) { - var eventNIDs []types.EventNID - var filteredEventNIDs []types.EventNID - var err error - if len(membershipStatusFilter) > 0 { - for _, filter := range membershipStatusFilter { + if joinOnly { + return d.MembershipTable.SelectMembershipsFromRoomAndMembership( + ctx, roomNID, tables.MembershipStateJoin, localOnly, + ) + } + + if len(membershipFilter) > 0 { + var eventNIDs []types.EventNID + var filteredEventNIDs []types.EventNID + var err error + leaveOrBan := false + for _, filter := range membershipFilter { if filter == "join" { filteredEventNIDs, err = d.MembershipTable.SelectMembershipsFromRoomAndMembership( ctx, roomNID, tables.MembershipStateJoin, localOnly, @@ -358,7 +365,8 @@ func (d *Database) GetMembershipEventNIDsForRoom( return eventNIDs, err } eventNIDs = append(eventNIDs, filteredEventNIDs...) - } else if filter == "invite" { + } + if filter == "invite" { filteredEventNIDs, err = d.MembershipTable.SelectMembershipsFromRoomAndMembership( ctx, roomNID, tables.MembershipStateInvite, localOnly, ) @@ -366,7 +374,8 @@ func (d *Database) GetMembershipEventNIDsForRoom( return eventNIDs, err } eventNIDs = append(eventNIDs, filteredEventNIDs...) - } else { + } + if (filter == "ban" || filter == "leave") && !leaveOrBan { filteredEventNIDs, err = d.MembershipTable.SelectMembershipsFromRoomAndMembership( ctx, roomNID, tables.MembershipStateLeaveOrBan, localOnly, ) @@ -374,6 +383,7 @@ func (d *Database) GetMembershipEventNIDsForRoom( return eventNIDs, err } eventNIDs = append(eventNIDs, filteredEventNIDs...) + leaveOrBan = true } } return eventNIDs, err diff --git a/sytest-whitelist b/sytest-whitelist index 8c4585716..fecd8857f 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -520,3 +520,4 @@ Inviting an AS-hosted user asks the AS server Can generate a openid access_token that can be exchanged for information about a user Invalid openid access tokens are rejected Requests to userinfo without access tokens are rejected +Can get rooms/{roomId}/members at a given point