diff --git a/clientapi/routing/memberships.go b/clientapi/routing/memberships.go index 6ddcf1be3..129c28cfd 100644 --- a/clientapi/routing/memberships.go +++ b/clientapi/routing/memberships.go @@ -51,16 +51,47 @@ type databaseJoinedMember struct { AvatarURL string `json:"avatar_url"` } +// // MembershipFilter is an enum representing kinds of membership to a room +// type MembershipFilter string + +// const ( +// MembershipDefault MembershipFilter = "none" +// MembershipJoin MembershipFilter = "join" +// MembershipInvite MembershipFilter = "invite" +// MembershipLeave MembershipFilter = "leave" +// MembershipBan MembershipFilter = "ban" +// ) + // GetMemberships implements GET /rooms/{roomId}/members func GetMemberships( req *http.Request, device *userapi.Device, roomID string, joinedOnly bool, _ *config.ClientAPI, rsAPI api.RoomserverInternalAPI, ) util.JSONResponse { + // how would i unpack this nicely into + // type queryParams struct { + // membership string + // notMembership string + // } + membership := req.URL.Query().Get("membership") + notMembership := req.URL.Query().Get("not_membership") + membershipStatusFilter := []string{"join", "invite", "leave", "ban"} + if len(notMembership) > 0 { + for i, v := range membershipStatusFilter { + if v == notMembership { + membershipStatusFilter = append(membershipStatusFilter[:i], membershipStatusFilter[i+1:]...) + } + } + } else if len(membership) > 0 { + membershipStatusFilter = []string{membership} + } else { + membershipStatusFilter = []string{} + } queryReq := api.QueryMembershipsForRoomRequest{ - JoinedOnly: joinedOnly, - RoomID: roomID, - Sender: device.UserID, + JoinedOnly: joinedOnly, + MembershipStatusFilter: membershipStatusFilter, + RoomID: roomID, + Sender: device.UserID, } var queryRes api.QueryMembershipsForRoomResponse if err := rsAPI.QueryMembershipsForRoom(req.Context(), &queryReq, &queryRes); err != nil { diff --git a/roomserver/api/query.go b/roomserver/api/query.go index af35f7e72..95df541e1 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -148,7 +148,8 @@ type QueryMembershipForUserResponse struct { // QueryMembershipsForRoomRequest is a request to QueryMembershipsForRoom type QueryMembershipsForRoomRequest struct { // If true, only returns the membership events of "join" membership - JoinedOnly bool `json:"joined_only"` + JoinedOnly bool `json:"joined_only"` + MembershipStatusFilter []string `json:"membership_status_filter"` // ID of the room to fetch memberships from RoomID string `json:"room_id"` // Optional - ID of the user sending the request, for checking if the diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index a829bffca..a01bb44b7 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -59,7 +59,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam return false, fmt.Errorf("unknown room %s", roomID) } - eventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false) + eventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, []string{"join"}, false) if err != nil { return false, err } diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 2a558c483..b7116272a 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -256,7 +256,7 @@ func (r *Inputer) calculateAndSetState( stateAtEvent.Overwrite = true var joinEventNIDs []types.EventNID // Request join memberships only for local users only. - if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil { + if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, []string{"join"}, true); err == nil { // If we have no local users that are joined to the room then any state about // the room that we have is quite possibly out of date. Therefore in that case // we should overwrite it rather than merge it. diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index d9d720f26..33ec3c48e 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -537,7 +537,7 @@ func joinEventsFromHistoryVisibility( if err != nil { return nil, err } - joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false) + joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, []string{"join"}, false) if err != nil { return nil, err } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 408f9766e..de59e6411 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -251,7 +251,7 @@ func (r *Queryer) QueryMembershipsForRoom( if request.Sender == "" { var events []types.Event var eventNIDs []types.EventNID - eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, false) + eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.MembershipStatusFilter, false) if err != nil { return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err) } @@ -286,7 +286,7 @@ func (r *Queryer) QueryMembershipsForRoom( var stateEntries []types.StateEntry if stillInRoom { var eventNIDs []types.EventNID - eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, false) + eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.MembershipStatusFilter, false) if err != nil { return err } @@ -328,7 +328,7 @@ func (r *Queryer) QueryServerJoinedToRoom( } response.RoomExists = true - eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false) + eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, []string{"join"}, false) if err != nil { return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err) } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index d2b0e75c9..10da796c6 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -131,7 +131,7 @@ type Database interface { // been members of a given room. Only lookup events of "join" membership if // joinOnly is set to true. // Returns an error if there was a problem talking to the database. - GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error) + GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, membershipStatusFilter []string, localOnly bool) ([]types.EventNID, error) // EventsFromIDs looks up the Events for a list of event IDs. Does not error if event was // not found. // Returns an error if the retrieval went wrong. diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 096d5d7a8..ab42a2457 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -343,12 +343,40 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req } func (d *Database) GetMembershipEventNIDsForRoom( - ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, + ctx context.Context, roomNID types.RoomNID, membershipStatusFilter []string, localOnly bool, ) ([]types.EventNID, error) { - if joinOnly { - return d.MembershipTable.SelectMembershipsFromRoomAndMembership( - ctx, roomNID, tables.MembershipStateJoin, localOnly, - ) + var eventNIDs []types.EventNID + var filteredEventNIDs []types.EventNID + var err error + if len(membershipStatusFilter) > 0 { + for _, filter := range membershipStatusFilter { + if filter == "join" { + filteredEventNIDs, err = d.MembershipTable.SelectMembershipsFromRoomAndMembership( + ctx, roomNID, tables.MembershipStateJoin, localOnly, + ) + if err != nil { + return eventNIDs, err + } + eventNIDs = append(eventNIDs, filteredEventNIDs...) + } else if filter == "invite" { + filteredEventNIDs, err = d.MembershipTable.SelectMembershipsFromRoomAndMembership( + ctx, roomNID, tables.MembershipStateInvite, localOnly, + ) + if err != nil { + return eventNIDs, err + } + eventNIDs = append(eventNIDs, filteredEventNIDs...) + } else { + filteredEventNIDs, err = d.MembershipTable.SelectMembershipsFromRoomAndMembership( + ctx, roomNID, tables.MembershipStateLeaveOrBan, localOnly, + ) + if err != nil { + return eventNIDs, err + } + eventNIDs = append(eventNIDs, filteredEventNIDs...) + } + } + return eventNIDs, err } return d.MembershipTable.SelectMembershipsFromRoom(ctx, roomNID, localOnly)