Refactor invite interfaces to be more ergonomic

This commit is contained in:
Devon Hudson 2023-05-23 11:03:32 -06:00
parent 586ee2c349
commit c753aa455a
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
2 changed files with 61 additions and 62 deletions

View file

@ -88,6 +88,8 @@ func InviteV2(
PrivateKey: cfg.Matrix.PrivateKey, PrivateKey: cfg.Matrix.PrivateKey,
Verifier: keys, Verifier: keys,
InviteQuerier: rsAPI, InviteQuerier: rsAPI,
MembershipQuerier: MembershipQuerier{roomserver: rsAPI},
GenerateStrippedState: rsAPI.GenerateInviteStrippedState,
InviteEvent: inviteReq.Event(), InviteEvent: inviteReq.Event(),
StrippedState: inviteReq.InviteRoomState(), StrippedState: inviteReq.InviteRoomState(),
} }
@ -171,6 +173,8 @@ func InviteV1(
PrivateKey: cfg.Matrix.PrivateKey, PrivateKey: cfg.Matrix.PrivateKey,
Verifier: keys, Verifier: keys,
InviteQuerier: rsAPI, InviteQuerier: rsAPI,
MembershipQuerier: MembershipQuerier{roomserver: rsAPI},
GenerateStrippedState: rsAPI.GenerateInviteStrippedState,
InviteEvent: event, InviteEvent: event,
StrippedState: strippedState, StrippedState: strippedState,
} }
@ -231,12 +235,8 @@ func handleInvite(input HandleInviteInput, rsAPI api.FederationRoomserverAPI) (g
// TODO: Migrate to GMSL // TODO: Migrate to GMSL
type InviteQuerier interface { type RoomQuerier interface {
IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, error) IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, error)
// TODO: get rid of fclient references if possible
GenerateInviteStrippedState(ctx context.Context, roomID spec.RoomID, stateWanted []gomatrixserverlib.StateKeyTuple, inviteEvent gomatrixserverlib.PDU) ([]fclient.InviteV2StrippedState, error)
// TODO: Change this api shape dramatically for gmsl
QueryMembershipForUser(ctx context.Context, req *api.QueryMembershipForUserRequest, res *api.QueryMembershipForUserResponse) error
} }
type HandleInviteInput struct { type HandleInviteInput struct {
@ -248,7 +248,10 @@ type HandleInviteInput struct {
KeyID gomatrixserverlib.KeyID KeyID gomatrixserverlib.KeyID
PrivateKey ed25519.PrivateKey PrivateKey ed25519.PrivateKey
Verifier gomatrixserverlib.JSONVerifier Verifier gomatrixserverlib.JSONVerifier
InviteQuerier InviteQuerier InviteQuerier RoomQuerier
MembershipQuerier MembershipQuerier
// TODO: get rid of fclient references if possible
GenerateStrippedState func(ctx context.Context, roomID spec.RoomID, stateWanted []gomatrixserverlib.StateKeyTuple, inviteEvent gomatrixserverlib.PDU) ([]fclient.InviteV2StrippedState, error)
InviteEvent gomatrixserverlib.PDU InviteEvent gomatrixserverlib.PDU
StrippedState []fclient.InviteV2StrippedState StrippedState []fclient.InviteV2StrippedState
@ -333,7 +336,7 @@ func processInvite(
StateKey: "", StateKey: "",
}) })
} }
if is, err := input.InviteQuerier.GenerateInviteStrippedState(input.Context, input.RoomID, stateWanted, inviteEvent); err == nil { if is, err := input.GenerateStrippedState(input.Context, input.RoomID, stateWanted, inviteEvent); err == nil {
inviteState = is inviteState = is
} else { } else {
util.GetLogger(input.Context).WithError(err).Error("failed querying known room") util.GetLogger(input.Context).WithError(err).Error("failed querying known room")
@ -365,17 +368,13 @@ func processInvite(
} }
if isKnownRoom { if isKnownRoom {
req := api.QueryMembershipForUserRequest{ membership, err := input.MembershipQuerier.CurrentMembership(input.Context, input.RoomID, input.InvitedUser)
RoomID: input.RoomID.String(),
UserID: input.InvitedUser.String(),
}
res := api.QueryMembershipForUserResponse{}
err = input.InviteQuerier.QueryMembershipForUser(input.Context, &req, &res)
if err != nil { if err != nil {
util.GetLogger(input.Context).WithError(err).Error("failed getting user membership") util.GetLogger(input.Context).WithError(err).Error("failed getting user membership")
return nil, spec.InternalServerError{} return nil, spec.InternalServerError{}
} }
isAlreadyJoined := (res.Membership == spec.Join) isAlreadyJoined := (membership == spec.Join)
if isAlreadyJoined { if isAlreadyJoined {
// If the user is joined to the room then that takes precedence over this // If the user is joined to the room then that takes precedence over this

View file

@ -50,24 +50,6 @@ func (r *Inviter) IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, er
return (info != nil && !info.IsStub()), nil return (info != nil && !info.IsStub()), nil
} }
func (r *Inviter) generateInviteStrippedState(
ctx context.Context, roomID spec.RoomID, inviteEvent *types.HeaderedEvent, inviteState []fclient.InviteV2StrippedState,
) (*types.RoomInfo, []fclient.InviteV2StrippedState, error) {
info, err := r.DB.RoomInfo(ctx, roomID.String())
if err != nil {
return nil, nil, fmt.Errorf("failed to load RoomInfo: %w", err)
}
strippedState := inviteState
if len(strippedState) == 0 && info != nil {
var is []fclient.InviteV2StrippedState
if is, err = buildInviteStrippedState(ctx, r.DB, info, inviteEvent); err == nil {
strippedState = is
}
}
return info, strippedState, nil
}
func (r *Inviter) GenerateInviteStrippedState( func (r *Inviter) GenerateInviteStrippedState(
ctx context.Context, roomID spec.RoomID, stateWanted []gomatrixserverlib.StateKeyTuple, inviteEvent gomatrixserverlib.PDU, ctx context.Context, roomID spec.RoomID, stateWanted []gomatrixserverlib.StateKeyTuple, inviteEvent gomatrixserverlib.PDU,
) ([]fclient.InviteV2StrippedState, error) { ) ([]fclient.InviteV2StrippedState, error) {
@ -286,6 +268,24 @@ func (r *Inviter) PerformInvite(
return outputUpdates, nil return outputUpdates, nil
} }
func (r *Inviter) generateInviteStrippedState(
ctx context.Context, roomID spec.RoomID, inviteEvent *types.HeaderedEvent, inviteState []fclient.InviteV2StrippedState,
) (*types.RoomInfo, []fclient.InviteV2StrippedState, error) {
info, err := r.DB.RoomInfo(ctx, roomID.String())
if err != nil {
return nil, nil, fmt.Errorf("failed to load RoomInfo: %w", err)
}
strippedState := inviteState
if len(strippedState) == 0 && info != nil {
var is []fclient.InviteV2StrippedState
if is, err = buildInviteStrippedState(ctx, r.DB, info, inviteEvent); err == nil {
strippedState = is
}
}
return info, strippedState, nil
}
func buildInviteStrippedState( func buildInviteStrippedState(
ctx context.Context, ctx context.Context,
db storage.Database, db storage.Database,