Fix statekey usage in roomserver/input_membership

This commit is contained in:
Devon Hudson 2023-06-08 14:15:19 -06:00
parent 5e9b7d714f
commit d2bbf9e315
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628

View file

@ -18,7 +18,6 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
@ -72,7 +71,7 @@ func (r *Inputer) updateMemberships(
if change.addedEventNID != 0 { if change.addedEventNID != 0 {
ae, _ = helpers.EventMap(events).Lookup(change.addedEventNID) ae, _ = helpers.EventMap(events).Lookup(change.addedEventNID)
} }
if updates, err = r.updateMembership(updater, targetUserNID, re, ae, updates); err != nil { if updates, err = r.updateMembership(ctx, updater, targetUserNID, re, ae, updates); err != nil {
return nil, err return nil, err
} }
} }
@ -80,6 +79,7 @@ func (r *Inputer) updateMemberships(
} }
func (r *Inputer) updateMembership( func (r *Inputer) updateMembership(
ctx context.Context,
updater *shared.RoomUpdater, updater *shared.RoomUpdater,
targetUserNID types.EventStateKeyNID, targetUserNID types.EventStateKeyNID,
remove, add *types.Event, remove, add *types.Event,
@ -97,7 +97,7 @@ func (r *Inputer) updateMembership(
var targetLocal bool var targetLocal bool
if add != nil { if add != nil {
targetLocal = r.isLocalTarget(add) targetLocal = r.isLocalTarget(ctx, add)
} }
mu, err := updater.MembershipUpdater(targetUserNID, targetLocal) mu, err := updater.MembershipUpdater(targetUserNID, targetLocal)
@ -136,11 +136,14 @@ func (r *Inputer) updateMembership(
} }
} }
func (r *Inputer) isLocalTarget(event *types.Event) bool { func (r *Inputer) isLocalTarget(ctx context.Context, event *types.Event) bool {
isTargetLocalUser := false isTargetLocalUser := false
if statekey := event.StateKey(); statekey != nil { if statekey := event.StateKey(); statekey != nil {
_, domain, _ := gomatrixserverlib.SplitID('@', *statekey) userID, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*statekey))
isTargetLocalUser = domain == r.ServerName if err != nil || userID == nil {
return isTargetLocalUser
}
isTargetLocalUser = userID.Domain() == r.ServerName
} }
return isTargetLocalUser return isTargetLocalUser
} }