Refine SenderID/UserID usage

This commit is contained in:
Devon Hudson 2023-06-02 18:00:43 -06:00
parent 00e719f7e7
commit 8a3c01fa19
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
3 changed files with 14 additions and 3 deletions

View file

@ -68,6 +68,7 @@ type InputRoomEventsAPI interface {
type QuerySenderIDAPI interface { type QuerySenderIDAPI interface {
// Accepts either roomID or alias // Accepts either roomID or alias
QuerySenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) QuerySenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error)
QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (spec.UserID, error)
} }
// Query the latest events and state for a room from the room server. // Query the latest events and state for a room from the room server.

View file

@ -97,11 +97,12 @@ func (r *Inviter) ProcessInviteMembership(
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
var outputUpdates []api.OutputEvent var outputUpdates []api.OutputEvent
var updater *shared.MembershipUpdater var updater *shared.MembershipUpdater
_, domain, err := gomatrixserverlib.SplitID('@', *inviteEvent.StateKey())
userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey())
if err != nil { if err != nil {
return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())} return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())}
} }
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(domain) isTargetLocal := r.Cfg.Matrix.IsLocalServerName(userID.Domain())
if updater, err = r.DB.MembershipUpdater(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey(), isTargetLocal, inviteEvent.Version()); err != nil { if updater, err = r.DB.MembershipUpdater(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey(), isTargetLocal, inviteEvent.Version()); err != nil {
return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err) return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err)
} }

View file

@ -1035,7 +1035,16 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
return nil return nil
} }
func (r *Queryer) QuerySenderIDForRoom(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) { func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) {
// TODO: implement this properly with pseudoIDs // TODO: implement this properly with pseudoIDs
return userID.String(), nil return userID.String(), nil
} }
func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (spec.UserID, error) {
// TODO: implement this properly with pseudoIDs
userID, err := spec.NewUserID(senderID, true)
if err != nil {
return spec.UserID{}, err
}
return *userID, err
}