diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/membership.go b/src/github.com/matrix-org/dendrite/roomserver/input/membership.go index 7651e7c3c..d693e40fe 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/membership.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/membership.go @@ -32,11 +32,11 @@ func updateMemberships( changes := membershipChanges(removed, added) var eventNIDs []types.EventNID for _, change := range changes { - if change.added.EventNID != 0 { - eventNIDs = append(eventNIDs, change.added.EventNID) + if change.addedEventNID != 0 { + eventNIDs = append(eventNIDs, change.addedEventNID) } - if change.removed.EventNID != 0 { - eventNIDs = append(eventNIDs, change.removed.EventNID) + if change.removedEventNID != 0 { + eventNIDs = append(eventNIDs, change.removedEventNID) } } @@ -53,20 +53,18 @@ func updateMemberships( for _, change := range changes { var ae *gomatrixserverlib.Event var re *gomatrixserverlib.Event - var targetUserNID types.EventStateKeyNID - if change.removed.EventNID != 0 { - ev, _ := eventMap(events).lookup(change.removed.EventNID) + targetUserNID := change.EventStateKeyNID + if change.removedEventNID != 0 { + ev, _ := eventMap(events).lookup(change.removedEventNID) if ev != nil { re = &ev.Event } - targetUserNID = change.removed.EventStateKeyNID } - if change.added.EventNID != 0 { - ev, _ := eventMap(events).lookup(change.added.EventNID) + if change.addedEventNID != 0 { + ev, _ := eventMap(events).lookup(change.addedEventNID) if ev != nil { ae = &ev.Event } - targetUserNID = change.added.EventStateKeyNID } if updates, err = updateMembership(updater, targetUserNID, re, ae, updates); err != nil { return nil, err @@ -222,8 +220,7 @@ func membershipChanges(removed, added []types.StateEntry) []stateChange { changes := pairUpChanges(removed, added) var result []stateChange for _, c := range changes { - if c.added.EventTypeNID == types.MRoomMemberNID || - c.removed.EventTypeNID == types.MRoomMemberNID { + if c.EventTypeNID == types.MRoomMemberNID { result = append(result, c) } } @@ -231,8 +228,9 @@ func membershipChanges(removed, added []types.StateEntry) []stateChange { } type stateChange struct { - removed types.StateEntry - added types.StateEntry + types.StateKeyTuple + removedEventNID types.EventNID + addedEventNID types.EventNID } // pairUpChanges pairs up the state events added and removed for each type, @@ -245,26 +243,40 @@ func pairUpChanges(removed, added []types.StateEntry) []stateChange { switch { case ai == len(added): for _, s := range removed[ri:] { - result = append(result, stateChange{removed: s}) + result = append(result, stateChange{ + StateKeyTuple: s.StateKeyTuple, + removedEventNID: s.EventNID, + }) } return result case ri == len(removed): for _, s := range added[ai:] { - result = append(result, stateChange{added: s}) + result = append(result, stateChange{ + StateKeyTuple: s.StateKeyTuple, + addedEventNID: s.EventNID, + }) } return result case added[ai].StateKeyTuple == removed[ri].StateKeyTuple: result = append(result, stateChange{ - removed: removed[ri], - added: added[ai], + StateKeyTuple: added[ai].StateKeyTuple, + removedEventNID: removed[ri].EventNID, + addedEventNID: added[ai].EventNID, }) ai++ ri++ case added[ai].StateKeyTuple.LessThan(removed[ri].StateKeyTuple): - result = append(result, stateChange{added: added[ai]}) + result = append(result, stateChange{ + + StateKeyTuple: added[ai].StateKeyTuple, + addedEventNID: added[ai].EventNID, + }) ai++ default: - result = append(result, stateChange{removed: removed[ri]}) + result = append(result, stateChange{ + StateKeyTuple: removed[ai].StateKeyTuple, + removedEventNID: removed[ri].EventNID, + }) ri++ } }