Clean up changes

This commit is contained in:
Alex Chan 2021-07-29 03:07:38 -04:00
parent e50dbe61be
commit cd78100adb
9 changed files with 63 additions and 34 deletions

View file

@ -17,6 +17,7 @@ package routing
import ( import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"regexp"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -68,28 +69,43 @@ func GetMemberships(
_ *config.ClientAPI, _ *config.ClientAPI,
rsAPI api.RoomserverInternalAPI, rsAPI api.RoomserverInternalAPI,
) util.JSONResponse { ) util.JSONResponse {
// how would i unpack this nicely into
// type queryParams struct {
// membership string
// notMembership string
// }
membership := req.URL.Query().Get("membership") membership := req.URL.Query().Get("membership")
notMembership := req.URL.Query().Get("not_membership") notMembership := req.URL.Query().Get("not_membership")
membershipStatusFilter := []string{"join", "invite", "leave", "ban"}
if len(notMembership) > 0 { regexpMembershipFilter, _ := regexp.Compile("join|invite|leave|ban")
for i, v := range membershipStatusFilter { if membership != "" && !regexpMembershipFilter.MatchString(membership) {
if v == notMembership { return util.JSONResponse{
membershipStatusFilter = append(membershipStatusFilter[:i], membershipStatusFilter[i+1:]...) Code: http.StatusBadRequest,
} }
} }
} else if len(membership) > 0 { if notMembership != "" && !regexpMembershipFilter.MatchString(notMembership) {
membershipStatusFilter = []string{membership} return util.JSONResponse{
Code: http.StatusBadRequest,
}
}
membershipFilter := []string{"join", "invite", "leave", "ban"}
if notMembership != "" {
if membership != "" && membership != notMembership {
for idx, val := range membershipFilter {
if val == notMembership {
membershipFilter = append(membershipFilter[:idx], membershipFilter[idx+1:]...)
break
}
}
}
// If membership and not_membership are both specified and they are the same,
// then we do no filtering at all because they create an OR conditional
} else { } else {
membershipStatusFilter = []string{} if membership != "" {
membershipFilter = []string{membership}
} }
}
queryReq := api.QueryMembershipsForRoomRequest{ queryReq := api.QueryMembershipsForRoomRequest{
JoinedOnly: joinedOnly, JoinedOnly: joinedOnly,
MembershipStatusFilter: membershipStatusFilter, MembershipFilter: membershipFilter,
RoomID: roomID, RoomID: roomID,
Sender: device.UserID, Sender: device.UserID,
} }

View file

@ -149,7 +149,9 @@ type QueryMembershipForUserResponse struct {
type QueryMembershipsForRoomRequest struct { type QueryMembershipsForRoomRequest struct {
// If true, only returns the membership events of "join" membership // If true, only returns the membership events of "join" membership
JoinedOnly bool `json:"joined_only"` JoinedOnly bool `json:"joined_only"`
MembershipStatusFilter []string `json:"membership_status_filter"` // The kinds of membership to filter for - returns the membership
// events with the appropriate filter (ie. join, invite, leave, ban)
MembershipFilter []string `json:"membership_filter"`
// ID of the room to fetch memberships from // ID of the room to fetch memberships from
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
// Optional - ID of the user sending the request, for checking if the // Optional - ID of the user sending the request, for checking if the

View file

@ -59,7 +59,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam
return false, fmt.Errorf("unknown room %s", roomID) return false, fmt.Errorf("unknown room %s", roomID)
} }
eventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, []string{"join"}, false) eventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, []string{}, false)
if err != nil { if err != nil {
return false, err return false, err
} }

View file

@ -256,7 +256,7 @@ func (r *Inputer) calculateAndSetState(
stateAtEvent.Overwrite = true stateAtEvent.Overwrite = true
var joinEventNIDs []types.EventNID var joinEventNIDs []types.EventNID
// Request join memberships only for local users only. // Request join memberships only for local users only.
if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, []string{"join"}, true); err == nil { if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, []string{}, true); err == nil {
// If we have no local users that are joined to the room then any state about // If we have no local users that are joined to the room then any state about
// the room that we have is quite possibly out of date. Therefore in that case // the room that we have is quite possibly out of date. Therefore in that case
// we should overwrite it rather than merge it. // we should overwrite it rather than merge it.

View file

@ -537,7 +537,7 @@ func joinEventsFromHistoryVisibility(
if err != nil { if err != nil {
return nil, err return nil, err
} }
joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, []string{"join"}, false) joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, []string{}, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -251,7 +251,7 @@ func (r *Queryer) QueryMembershipsForRoom(
if request.Sender == "" { if request.Sender == "" {
var events []types.Event var events []types.Event
var eventNIDs []types.EventNID var eventNIDs []types.EventNID
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.MembershipStatusFilter, false) eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, request.MembershipFilter, false)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err) return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err)
} }
@ -286,7 +286,7 @@ func (r *Queryer) QueryMembershipsForRoom(
var stateEntries []types.StateEntry var stateEntries []types.StateEntry
if stillInRoom { if stillInRoom {
var eventNIDs []types.EventNID var eventNIDs []types.EventNID
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.MembershipStatusFilter, false) eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, request.MembershipFilter, false)
if err != nil { if err != nil {
return err return err
} }
@ -328,7 +328,7 @@ func (r *Queryer) QueryServerJoinedToRoom(
} }
response.RoomExists = true response.RoomExists = true
eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, []string{"join"}, false) eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, []string{}, false)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err) return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err)
} }

View file

@ -131,7 +131,7 @@ type Database interface {
// been members of a given room. Only lookup events of "join" membership if // been members of a given room. Only lookup events of "join" membership if
// joinOnly is set to true. // joinOnly is set to true.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, membershipStatusFilter []string, localOnly bool) ([]types.EventNID, error) GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, membershipFilter []string, localOnly bool) ([]types.EventNID, error)
// EventsFromIDs looks up the Events for a list of event IDs. Does not error if event was // EventsFromIDs looks up the Events for a list of event IDs. Does not error if event was
// not found. // not found.
// Returns an error if the retrieval went wrong. // Returns an error if the retrieval went wrong.

View file

@ -343,13 +343,20 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req
} }
func (d *Database) GetMembershipEventNIDsForRoom( func (d *Database) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, membershipStatusFilter []string, localOnly bool, ctx context.Context, roomNID types.RoomNID, joinOnly bool, membershipFilter []string, localOnly bool,
) ([]types.EventNID, error) { ) ([]types.EventNID, error) {
if joinOnly {
return d.MembershipTable.SelectMembershipsFromRoomAndMembership(
ctx, roomNID, tables.MembershipStateJoin, localOnly,
)
}
if len(membershipFilter) > 0 {
var eventNIDs []types.EventNID var eventNIDs []types.EventNID
var filteredEventNIDs []types.EventNID var filteredEventNIDs []types.EventNID
var err error var err error
if len(membershipStatusFilter) > 0 { leaveOrBan := false
for _, filter := range membershipStatusFilter { for _, filter := range membershipFilter {
if filter == "join" { if filter == "join" {
filteredEventNIDs, err = d.MembershipTable.SelectMembershipsFromRoomAndMembership( filteredEventNIDs, err = d.MembershipTable.SelectMembershipsFromRoomAndMembership(
ctx, roomNID, tables.MembershipStateJoin, localOnly, ctx, roomNID, tables.MembershipStateJoin, localOnly,
@ -358,7 +365,8 @@ func (d *Database) GetMembershipEventNIDsForRoom(
return eventNIDs, err return eventNIDs, err
} }
eventNIDs = append(eventNIDs, filteredEventNIDs...) eventNIDs = append(eventNIDs, filteredEventNIDs...)
} else if filter == "invite" { }
if filter == "invite" {
filteredEventNIDs, err = d.MembershipTable.SelectMembershipsFromRoomAndMembership( filteredEventNIDs, err = d.MembershipTable.SelectMembershipsFromRoomAndMembership(
ctx, roomNID, tables.MembershipStateInvite, localOnly, ctx, roomNID, tables.MembershipStateInvite, localOnly,
) )
@ -366,7 +374,8 @@ func (d *Database) GetMembershipEventNIDsForRoom(
return eventNIDs, err return eventNIDs, err
} }
eventNIDs = append(eventNIDs, filteredEventNIDs...) eventNIDs = append(eventNIDs, filteredEventNIDs...)
} else { }
if (filter == "ban" || filter == "leave") && !leaveOrBan {
filteredEventNIDs, err = d.MembershipTable.SelectMembershipsFromRoomAndMembership( filteredEventNIDs, err = d.MembershipTable.SelectMembershipsFromRoomAndMembership(
ctx, roomNID, tables.MembershipStateLeaveOrBan, localOnly, ctx, roomNID, tables.MembershipStateLeaveOrBan, localOnly,
) )
@ -374,6 +383,7 @@ func (d *Database) GetMembershipEventNIDsForRoom(
return eventNIDs, err return eventNIDs, err
} }
eventNIDs = append(eventNIDs, filteredEventNIDs...) eventNIDs = append(eventNIDs, filteredEventNIDs...)
leaveOrBan = true
} }
} }
return eventNIDs, err return eventNIDs, err

View file

@ -520,3 +520,4 @@ Inviting an AS-hosted user asks the AS server
Can generate a openid access_token that can be exchanged for information about a user Can generate a openid access_token that can be exchanged for information about a user
Invalid openid access tokens are rejected Invalid openid access tokens are rejected
Requests to userinfo without access tokens are rejected Requests to userinfo without access tokens are rejected
Can get rooms/{roomId}/members at a given point