From 8c2c55fbb0e92cab400dd7c603d6f1e8cebe094f Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Mon, 28 Nov 2022 13:32:38 +0100 Subject: [PATCH] Push rule tweaks --- internal/pushrules/condition.go | 2 +- internal/pushrules/default_content.go | 2 +- internal/pushrules/default_override.go | 31 ++++++++++---------- internal/pushrules/default_pushrules_test.go | 16 ++++++---- internal/pushrules/default_underride.go | 10 +++---- internal/pushrules/evaluate.go | 10 +++++-- internal/pushrules/evaluate_test.go | 5 ++-- internal/pushrules/pushrules.go | 14 +++++---- internal/pushrules/validate.go | 5 +++- 9 files changed, 57 insertions(+), 38 deletions(-) diff --git a/internal/pushrules/condition.go b/internal/pushrules/condition.go index 2d9773c0f..c7b30da8e 100644 --- a/internal/pushrules/condition.go +++ b/internal/pushrules/condition.go @@ -14,7 +14,7 @@ type Condition struct { // Pattern indicates the value pattern that must match. Required // for EventMatchCondition. - Pattern string `json:"pattern,omitempty"` + Pattern *string `json:"pattern,omitempty"` // Is indicates the condition that must be fulfilled. Required for // RoomMemberCountCondition. diff --git a/internal/pushrules/default_content.go b/internal/pushrules/default_content.go index ce78194c3..a055ba03c 100644 --- a/internal/pushrules/default_content.go +++ b/internal/pushrules/default_content.go @@ -15,7 +15,7 @@ func mRuleContainsUserNameDefinition(localpart string) *Rule { RuleID: MRuleContainsUserName, Default: true, Enabled: true, - Pattern: localpart, + Pattern: &localpart, Actions: []*Action{ {Kind: NotifyAction}, { diff --git a/internal/pushrules/default_override.go b/internal/pushrules/default_override.go index f8a916564..f97427b71 100644 --- a/internal/pushrules/default_override.go +++ b/internal/pushrules/default_override.go @@ -27,11 +27,10 @@ const ( var ( mRuleMasterDefinition = Rule{ - RuleID: MRuleMaster, - Default: true, - Enabled: false, - Conditions: []*Condition{}, - Actions: []*Action{{Kind: DontNotifyAction}}, + RuleID: MRuleMaster, + Default: true, + Enabled: false, + Actions: []*Action{{Kind: DontNotifyAction}}, } mRuleSuppressNoticesDefinition = Rule{ RuleID: MRuleSuppressNotices, @@ -41,7 +40,7 @@ var ( { Kind: EventMatchCondition, Key: "content.msgtype", - Pattern: "m.notice", + Pattern: pointer("m.notice"), }, }, Actions: []*Action{{Kind: DontNotifyAction}}, @@ -54,7 +53,7 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.member", + Pattern: pointer("m.room.member"), }, }, Actions: []*Action{{Kind: DontNotifyAction}}, @@ -85,12 +84,12 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.tombstone", + Pattern: pointer("m.room.tombstone"), }, { Kind: EventMatchCondition, Key: "state_key", - Pattern: "", + Pattern: pointer(""), }, }, Actions: []*Action{ @@ -109,12 +108,12 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.server_acl", + Pattern: pointer("m.room.server_acl"), }, { Kind: EventMatchCondition, Key: "state_key", - Pattern: "", + Pattern: pointer(""), }, }, Actions: []*Action{}, @@ -127,7 +126,7 @@ var ( { Kind: EventMatchCondition, Key: "content.body", - Pattern: "@room", + Pattern: pointer("@room"), }, { Kind: SenderNotificationPermissionCondition, @@ -150,7 +149,7 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.reaction", + Pattern: pointer("m.reaction"), }, }, Actions: []*Action{ @@ -168,17 +167,17 @@ func mRuleInviteForMeDefinition(userID string) *Rule { { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.member", + Pattern: pointer("m.room.member"), }, { Kind: EventMatchCondition, Key: "content.membership", - Pattern: "invite", + Pattern: pointer("invite"), }, { Kind: EventMatchCondition, Key: "state_key", - Pattern: userID, + Pattern: pointer(userID), }, }, Actions: []*Action{ diff --git a/internal/pushrules/default_pushrules_test.go b/internal/pushrules/default_pushrules_test.go index 362f90c08..dea829842 100644 --- a/internal/pushrules/default_pushrules_test.go +++ b/internal/pushrules/default_pushrules_test.go @@ -21,7 +21,7 @@ func TestDefaultRules(t *testing.T) { // Default override rules { name: ".m.rule.master", - inputBytes: []byte(`{"rule_id":".m.rule.master","default":true,"enabled":false,"conditions":[],"actions":["dont_notify"]}`), + inputBytes: []byte(`{"rule_id":".m.rule.master","default":true,"enabled":false,"actions":["dont_notify"]}`), want: mRuleMasterDefinition, }, { @@ -31,12 +31,12 @@ func TestDefaultRules(t *testing.T) { }, { name: ".m.rule.invite_for_me", - inputBytes: []byte(`{"rule_id":".m.rule.invite_for_me","default":true,"enabled":true,"conditions":[{"key":"type","kind":"event_match","pattern":"m.room.member"},{"key":"content.membership","kind":"event_match","pattern":"invite"},{"key":"state_key","kind":"event_match","pattern":"@test:localhost"}],"actions":["notify",{"set_tweak":"sound","value":"default"}]}`), + inputBytes: []byte(`{"rule_id":".m.rule.invite_for_me","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.member"},{"kind":"event_match","key":"content.membership","pattern":"invite"},{"kind":"event_match","key":"state_key","pattern":"@test:localhost"}],"actions":["notify",{"set_tweak":"sound","value":"default"}]}`), want: *mRuleInviteForMeDefinition("@test:localhost"), }, { name: ".m.rule.member_event", - inputBytes: []byte(`{"rule_id":".m.rule.member_event","default":true,"enabled":true,"conditions":[{"key":"type","kind":"event_match","pattern":"m.room.member"}],"actions":["dont_notify"]}`), + inputBytes: []byte(`{"rule_id":".m.rule.member_event","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.member"}],"actions":["dont_notify"]}`), want: mRuleMemberEventDefinition, }, { @@ -62,13 +62,13 @@ func TestDefaultRules(t *testing.T) { // Default content rules { name: ".m.rule.contains_user_name", - inputBytes: []byte(`{"rule_id":".m.rule.contains_user_name","default":true,"enabled":true,"pattern":"myLocalUser","actions":["notify",{"set_tweak":"sound","value":"default"},{"set_tweak":"highlight"}]}`), + inputBytes: []byte(`{"rule_id":".m.rule.contains_user_name","default":true,"enabled":true,"actions":["notify",{"set_tweak":"sound","value":"default"},{"set_tweak":"highlight"}],"pattern":"myLocalUser"}`), want: *mRuleContainsUserNameDefinition("myLocalUser"), }, // default underride rules { name: ".m.rule.call", - inputBytes: []byte(`{"rule_id":".m.rule.call","default":true,"enabled":true,"conditions":[{"key":"type","kind":"event_match","pattern":"m.call.invite"}],"actions":["notify",{"set_tweak":"sound","value":"ring"}]}`), + inputBytes: []byte(`{"rule_id":".m.rule.call","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.call.invite"}],"actions":["notify",{"set_tweak":"sound","value":"ring"}]}`), want: mRuleCallDefinition, }, { @@ -96,9 +96,15 @@ func TestDefaultRules(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { r := Rule{} + // unmarshal predefined push rules err := json.Unmarshal(tc.inputBytes, &r) assert.NoError(t, err) assert.Equal(t, tc.want, r) + + // and reverse it to check we get the expected result + got, err := json.Marshal(r) + assert.NoError(t, err) + assert.Equal(t, string(got), string(tc.inputBytes)) }) } diff --git a/internal/pushrules/default_underride.go b/internal/pushrules/default_underride.go index e63634527..118bfae59 100644 --- a/internal/pushrules/default_underride.go +++ b/internal/pushrules/default_underride.go @@ -25,7 +25,7 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.call.invite", + Pattern: pointer("m.call.invite"), }, }, Actions: []*Action{ @@ -49,7 +49,7 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.encrypted", + Pattern: pointer("m.room.encrypted"), }, }, Actions: []*Action{ @@ -73,7 +73,7 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.message", + Pattern: pointer("m.room.message"), }, }, Actions: []*Action{ @@ -93,7 +93,7 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.message", + Pattern: pointer("m.room.message"), }, }, Actions: []*Action{ @@ -108,7 +108,7 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.encrypted", + Pattern: pointer("m.room.encrypted"), }, }, Actions: []*Action{ diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go index 4ff9939a6..fc8e0f174 100644 --- a/internal/pushrules/evaluate.go +++ b/internal/pushrules/evaluate.go @@ -104,7 +104,10 @@ func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec Evalu case ContentKind: // TODO: "These configure behaviour for (unencrypted) messages // that match certain patterns." - Does that mean "content.body"? - return patternMatches("content.body", rule.Pattern, event) + if rule.Pattern == nil { + return false, nil + } + return patternMatches("content.body", *rule.Pattern, event) case RoomKind: return rule.RuleID == event.RoomID(), nil @@ -120,7 +123,10 @@ func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec Evalu func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) { switch cond.Kind { case EventMatchCondition: - return patternMatches(cond.Key, cond.Pattern, event) + if cond.Pattern == nil { + return false, fmt.Errorf("missing condition pattern") + } + return patternMatches(cond.Key, *cond.Pattern, event) case ContainsDisplayNameCondition: return patternMatches("content.body", ec.UserDisplayName(), event) diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go index c5d5abd2a..bc34bcab1 100644 --- a/internal/pushrules/evaluate_test.go +++ b/internal/pushrules/evaluate_test.go @@ -79,8 +79,8 @@ func TestRuleMatches(t *testing.T) { {"underrideConditionMatch", UnderrideKind, Rule{Enabled: true}, `{}`, true}, {"underrideConditionNoMatch", UnderrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false}, - {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: "b"}, `{"content":{"body":"abc"}}`, true}, - {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: "d"}, `{"content":{"body":"abc"}}`, false}, + {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("b")}, `{"content":{"body":"abc"}}`, true}, + {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("d")}, `{"content":{"body":"abc"}}`, false}, {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true}, {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false}, @@ -110,6 +110,7 @@ func TestConditionMatches(t *testing.T) { }{ {"empty", Condition{}, `{}`, false}, {"empty", Condition{Kind: "unknownstring"}, `{}`, false}, + {"eventMatch", Condition{Kind: EventMatchCondition, Key: "content", Pattern: pointer("")}, `{"content":{}}`, true}, // Neither of these should match because `content` is not a full string match, // and `content.body` is not a string value. diff --git a/internal/pushrules/pushrules.go b/internal/pushrules/pushrules.go index 3c778bf68..d79c7865d 100644 --- a/internal/pushrules/pushrules.go +++ b/internal/pushrules/pushrules.go @@ -36,18 +36,18 @@ type Rule struct { // around. Required. Enabled bool `json:"enabled"` - // Actions describe the desired outcome, should the rule - // match. Required. - Actions []*Action `json:"actions"` - // Conditions provide the rule's conditions for OverrideKind and // UnderrideKind. Not allowed for other kinds. Conditions []*Condition `json:"conditions,omitempty"` + // Actions describe the desired outcome, should the rule + // match. Required. + Actions []*Action `json:"actions"` + // Pattern is the body pattern to match for ContentKind. Required // for that kind. The interpretation is the same as that of // Condition.Pattern. - Pattern string `json:"pattern"` + Pattern *string `json:"pattern,omitempty"` } // Scope only has one valid value. See also AccountRuleSets. @@ -69,3 +69,7 @@ const ( SenderKind Kind = "sender" UnderrideKind Kind = "underride" ) + +func pointer(s string) *string { + return &s +} diff --git a/internal/pushrules/validate.go b/internal/pushrules/validate.go index 5d260f0b9..f50c51bd7 100644 --- a/internal/pushrules/validate.go +++ b/internal/pushrules/validate.go @@ -34,7 +34,10 @@ func ValidateRule(kind Kind, rule *Rule) []error { } case ContentKind: - if rule.Pattern == "" { + if rule.Pattern == nil { + errs = append(errs, fmt.Errorf("missing content rule pattern")) + } + if rule.Pattern != nil && *rule.Pattern == "" { errs = append(errs, fmt.Errorf("missing content rule pattern")) }