Add pushrules tests (#3044)

partly takes care of https://github.com/matrix-org/dendrite/issues/2870
by making sure that rule IDs don't start with a dot.

Co-authored-by: kegsay <kegan@matrix.org>
This commit is contained in:
Till 2023-04-14 13:35:27 +02:00 committed by GitHub
parent ca63b414da
commit c45d8cd688
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 430 additions and 68 deletions

View file

@ -8,10 +8,12 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -1235,3 +1237,396 @@ func Test3PID(t *testing.T) {
} }
}) })
} }
func TestPushRules(t *testing.T) {
alice := test.NewUser(t)
// create the default push rules, used when validating responses
localpart, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
defaultRules, err := json.Marshal(pushRuleSets)
assert.NoError(t, err)
ruleID1 := "myrule"
ruleID2 := "myrule2"
ruleID3 := "myrule3"
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
cfg.ClientAPI.RateLimiting.Enabled = false
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
natsInstance := jetstream.NATSInstance{}
defer close()
routers := httputil.NewRouters()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
alice: {},
}
createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers)
testCases := []struct {
name string
request *http.Request
wantStatusCode int
validateFunc func(t *testing.T, respBody *bytes.Buffer) // used when updating rules, otherwise wantStatusCode should be enough
queryAttr map[string]string
}{
{
name: "can not get rules without trailing slash",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can get default rules",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/", strings.NewReader("")),
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
assert.Equal(t, defaultRules, respBody.Bytes())
},
},
{
name: "can get rules by scope",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/", strings.NewReader("")),
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
assert.Equal(t, gjson.GetBytes(defaultRules, "global").Raw, respBody.String())
},
},
{
name: "can not get invalid rules by scope",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/doesnotexist/", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not get rules for invalid scope and kind",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/doesnotexist/invalid/", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not get rules for invalid kind",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/invalid/", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can get rules by scope and kind",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/", strings.NewReader("")),
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
assert.Equal(t, gjson.GetBytes(defaultRules, "global.override").Raw, respBody.String())
},
},
{
name: "can get rules by scope and content kind",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader("")),
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
assert.Equal(t, gjson.GetBytes(defaultRules, "global.content").Raw, respBody.String())
},
},
{
name: "can not get rules by scope and room kind",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/room/", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not get rules by scope and sender kind",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/sender/", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can get rules by scope and underride kind",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/underride/", strings.NewReader("")),
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
assert.Equal(t, gjson.GetBytes(defaultRules, "global.underride").Raw, respBody.String())
},
},
{
name: "can not get rules by scope, kind and ID for invalid scope",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/doesnotexist/doesnotexist/.m.rule.master", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not get rules by scope, kind and ID for invalid kind",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/doesnotexist/.m.rule.master", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can get rules by scope, kind and ID",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master", strings.NewReader("")),
wantStatusCode: http.StatusOK,
},
{
name: "can not get rules by scope, kind and ID for invalid ID",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.doesnotexist", strings.NewReader("")),
wantStatusCode: http.StatusNotFound,
},
{
name: "can not get status for invalid attribute",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/invalid", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not get status for invalid kind",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/invalid/.m.rule.master/enabled", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not get enabled status for invalid scope",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/invalid/override/.m.rule.master/enabled", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not get enabled status for invalid rule",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/doesnotexist/enabled", strings.NewReader("")),
wantStatusCode: http.StatusNotFound,
},
{
name: "can get enabled rules by scope, kind and ID",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader("")),
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
assert.False(t, gjson.GetBytes(respBody.Bytes(), "enabled").Bool(), "expected master rule to be disabled")
},
},
{
name: "can get actions scope, kind and ID",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/actions", strings.NewReader("")),
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
actions := gjson.GetBytes(respBody.Bytes(), "actions").Array()
// only a basic check
assert.Equal(t, 1, len(actions))
},
},
{
name: "can not set enabled status with invalid JSON",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not set attribute for invalid attribute",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/doesnotexist", strings.NewReader("{}")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not set attribute for invalid scope",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/invalid/override/.m.rule.master/enabled", strings.NewReader("{}")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not set attribute for invalid kind",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/invalid/.m.rule.master/enabled", strings.NewReader("{}")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not set attribute for invalid rule",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/invalid/enabled", strings.NewReader("{}")),
wantStatusCode: http.StatusNotFound,
},
{
name: "can set enabled status with valid JSON",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader(`{"enabled":true}`)),
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader(""))
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
routers.Client.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
assert.True(t, gjson.GetBytes(rec.Body.Bytes(), "enabled").Bool(), "expected master rule to be enabled: %s", rec.Body.String())
},
},
{
name: "can set actions with valid JSON",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/actions", strings.NewReader(`{"actions":["dont_notify","notify"]}`)),
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/actions", strings.NewReader(""))
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
routers.Client.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
assert.Equal(t, 2, len(gjson.GetBytes(rec.Body.Bytes(), "actions").Array()), "expected 2 actions %s", rec.Body.String())
},
},
{
name: "can not create new push rule with invalid JSON",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not create new push rule with invalid rule content",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule", strings.NewReader("{}")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not create new push rule with invalid scope",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/invalid/content/myrule", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can create new push rule with valid rule content",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/myrule/actions", strings.NewReader(""))
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
routers.Client.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
assert.Equal(t, 1, len(gjson.GetBytes(rec.Body.Bytes(), "actions").Array()), "expected 1 action %s", rec.Body.String())
},
},
{
name: "can not create new push starting with a dot",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/.myrule", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can create new push rule after existing",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
queryAttr: map[string]string{
"after": ruleID1,
},
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader(""))
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
routers.Client.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
rules := gjson.ParseBytes(rec.Body.Bytes())
for i, rule := range rules.Array() {
if rule.Get("rule_id").Str == ruleID1 && i != 0 {
t.Fatalf("expected '%s' to be the first, but wasn't", ruleID1)
}
if rule.Get("rule_id").Str == ruleID2 && i != 1 {
t.Fatalf("expected '%s' to be the second, but wasn't", ruleID2)
}
}
},
},
{
name: "can create new push rule before existing",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule3", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
queryAttr: map[string]string{
"before": ruleID1,
},
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader(""))
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
routers.Client.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
rules := gjson.ParseBytes(rec.Body.Bytes())
for i, rule := range rules.Array() {
if rule.Get("rule_id").Str == ruleID3 && i != 0 {
t.Fatalf("expected '%s' to be the first, but wasn't", ruleID3)
}
if rule.Get("rule_id").Str == ruleID1 && i != 1 {
t.Fatalf("expected '%s' to be the second, but wasn't", ruleID1)
}
if rule.Get("rule_id").Str == ruleID2 && i != 2 {
t.Fatalf("expected '%s' to be the third, but wasn't", ruleID1)
}
}
},
},
{
name: "can modify existing push rule",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader(`{"actions":["dont_notify"],"pattern":"world"}`)),
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/myrule2/actions", strings.NewReader(""))
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
routers.Client.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
actions := gjson.GetBytes(rec.Body.Bytes(), "actions").Array()
// there should only be one action
assert.Equal(t, "dont_notify", actions[0].Str)
},
},
{
name: "can move existing push rule to the front",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader(`{"actions":["dont_notify"],"pattern":"world"}`)),
queryAttr: map[string]string{
"before": ruleID3,
},
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader(""))
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
routers.Client.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
rules := gjson.ParseBytes(rec.Body.Bytes())
for i, rule := range rules.Array() {
if rule.Get("rule_id").Str == ruleID2 && i != 0 {
t.Fatalf("expected '%s' to be the first, but wasn't", ruleID2)
}
if rule.Get("rule_id").Str == ruleID3 && i != 1 {
t.Fatalf("expected '%s' to be the second, but wasn't", ruleID3)
}
if rule.Get("rule_id").Str == ruleID1 && i != 2 {
t.Fatalf("expected '%s' to be the third, but wasn't", ruleID1)
}
}
},
},
{
name: "can not delete push rule with invalid scope",
request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/invalid/content/myrule2", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not delete push rule with invalid kind",
request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/global/invalid/myrule2", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not delete push rule with non-existent rule",
request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/global/content/doesnotexist", strings.NewReader("")),
wantStatusCode: http.StatusNotFound,
},
{
name: "can delete existing push rule",
request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader("")),
wantStatusCode: http.StatusOK,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
rec := httptest.NewRecorder()
if tc.queryAttr != nil {
params := url.Values{}
for k, v := range tc.queryAttr {
params.Set(k, v)
}
tc.request = httptest.NewRequest(tc.request.Method, tc.request.URL.String()+"?"+params.Encode(), tc.request.Body)
}
tc.request.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
routers.Client.ServeHTTP(rec, tc.request)
assert.Equal(t, tc.wantStatusCode, rec.Code, rec.Body.String())
if tc.validateFunc != nil {
tc.validateFunc(t, rec.Body)
}
t.Logf("%s", rec.Body.String())
})
}
})
}

View file

@ -31,7 +31,7 @@ func errorResponse(ctx context.Context, err error, msg string, args ...interface
} }
func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil { if err != nil {
return errorResponse(ctx, err, "queryPushRulesJSON failed") return errorResponse(ctx, err, "queryPushRulesJSON failed")
} }
@ -42,7 +42,7 @@ func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userap
} }
func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil { if err != nil {
return errorResponse(ctx, err, "queryPushRulesJSON failed") return errorResponse(ctx, err, "queryPushRulesJSON failed")
} }
@ -57,7 +57,7 @@ func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Devi
} }
func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil { if err != nil {
return errorResponse(ctx, err, "queryPushRules failed") return errorResponse(ctx, err, "queryPushRules failed")
} }
@ -66,7 +66,8 @@ func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
} }
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
if rulesPtr == nil { // Even if rulesPtr is not nil, there may not be any rules for this kind
if rulesPtr == nil || (rulesPtr != nil && len(*rulesPtr) == 0) {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
} }
return util.JSONResponse{ return util.JSONResponse{
@ -76,7 +77,7 @@ func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi
} }
func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil { if err != nil {
return errorResponse(ctx, err, "queryPushRules failed") return errorResponse(ctx, err, "queryPushRules failed")
} }
@ -101,7 +102,10 @@ func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device
func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, beforeRuleID string, body io.Reader, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, beforeRuleID string, body io.Reader, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
var newRule pushrules.Rule var newRule pushrules.Rule
if err := json.NewDecoder(body).Decode(&newRule); err != nil { if err := json.NewDecoder(body).Decode(&newRule); err != nil {
return errorResponse(ctx, err, "JSON Decode failed") return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(err.Error()),
}
} }
newRule.RuleID = ruleID newRule.RuleID = ruleID
@ -110,7 +114,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID,
return errorResponse(ctx, jsonerror.InvalidArgumentValue(errs[0].Error()), "rule sanity check failed: %v", errs) return errorResponse(ctx, jsonerror.InvalidArgumentValue(errs[0].Error()), "rule sanity check failed: %v", errs)
} }
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil { if err != nil {
return errorResponse(ctx, err, "queryPushRules failed") return errorResponse(ctx, err, "queryPushRules failed")
} }
@ -120,6 +124,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID,
} }
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
if rulesPtr == nil { if rulesPtr == nil {
// while this should be impossible (ValidateRule would already return an error), better keep it around
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
} }
i := pushRuleIndexByID(*rulesPtr, ruleID) i := pushRuleIndexByID(*rulesPtr, ruleID)
@ -144,7 +149,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID,
} }
// Add new rule. // Add new rule.
i, err := findPushRuleInsertionIndex(*rulesPtr, afterRuleID, beforeRuleID) i, err = findPushRuleInsertionIndex(*rulesPtr, afterRuleID, beforeRuleID)
if err != nil { if err != nil {
return errorResponse(ctx, err, "findPushRuleInsertionIndex failed") return errorResponse(ctx, err, "findPushRuleInsertionIndex failed")
} }
@ -153,7 +158,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID,
util.GetLogger(ctx).WithField("after", afterRuleID).WithField("before", beforeRuleID).Infof("Added new push rule at %d", i) util.GetLogger(ctx).WithField("after", afterRuleID).WithField("before", beforeRuleID).Infof("Added new push rule at %d", i)
} }
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil { if err = userAPI.PerformPushRulesPut(ctx, device.UserID, ruleSets); err != nil {
return errorResponse(ctx, err, "putPushRules failed") return errorResponse(ctx, err, "putPushRules failed")
} }
@ -161,7 +166,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID,
} }
func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil { if err != nil {
return errorResponse(ctx, err, "queryPushRules failed") return errorResponse(ctx, err, "queryPushRules failed")
} }
@ -180,7 +185,7 @@ func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, dev
*rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...) *rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...)
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil { if err = userAPI.PerformPushRulesPut(ctx, device.UserID, ruleSets); err != nil {
return errorResponse(ctx, err, "putPushRules failed") return errorResponse(ctx, err, "putPushRules failed")
} }
@ -192,7 +197,7 @@ func GetPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri
if err != nil { if err != nil {
return errorResponse(ctx, err, "pushRuleAttrGetter failed") return errorResponse(ctx, err, "pushRuleAttrGetter failed")
} }
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil { if err != nil {
return errorResponse(ctx, err, "queryPushRules failed") return errorResponse(ctx, err, "queryPushRules failed")
} }
@ -238,7 +243,7 @@ func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri
return errorResponse(ctx, err, "pushRuleAttrSetter failed") return errorResponse(ctx, err, "pushRuleAttrSetter failed")
} }
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil { if err != nil {
return errorResponse(ctx, err, "queryPushRules failed") return errorResponse(ctx, err, "queryPushRules failed")
} }
@ -258,7 +263,7 @@ func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri
if !reflect.DeepEqual(attrGet((*rulesPtr)[i]), attrGet(&newPartialRule)) { if !reflect.DeepEqual(attrGet((*rulesPtr)[i]), attrGet(&newPartialRule)) {
attrSet((*rulesPtr)[i], &newPartialRule) attrSet((*rulesPtr)[i], &newPartialRule)
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil { if err = userAPI.PerformPushRulesPut(ctx, device.UserID, ruleSets); err != nil {
return errorResponse(ctx, err, "putPushRules failed") return errorResponse(ctx, err, "putPushRules failed")
} }
} }
@ -266,28 +271,6 @@ func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}} return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
} }
func queryPushRules(ctx context.Context, userID string, userAPI userapi.ClientUserAPI) (*pushrules.AccountRuleSets, error) {
var res userapi.QueryPushRulesResponse
if err := userAPI.QueryPushRules(ctx, &userapi.QueryPushRulesRequest{UserID: userID}, &res); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.QueryPushRules failed")
return nil, err
}
return res.RuleSets, nil
}
func putPushRules(ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets, userAPI userapi.ClientUserAPI) error {
req := userapi.PerformPushRulesPutRequest{
UserID: userID,
RuleSets: ruleSets,
}
var res struct{}
if err := userAPI.PerformPushRulesPut(ctx, &req, &res); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.PerformPushRulesPut failed")
return err
}
return nil
}
func pushRuleSetByScope(ruleSets *pushrules.AccountRuleSets, scope pushrules.Scope) *pushrules.RuleSet { func pushRuleSetByScope(ruleSets *pushrules.AccountRuleSets, scope pushrules.Scope) *pushrules.RuleSet {
switch scope { switch scope {
case pushrules.GlobalScope: case pushrules.GlobalScope:

View file

@ -10,6 +10,10 @@ import (
func ValidateRule(kind Kind, rule *Rule) []error { func ValidateRule(kind Kind, rule *Rule) []error {
var errs []error var errs []error
if len(rule.RuleID) > 0 && rule.RuleID[:1] == "." {
errs = append(errs, fmt.Errorf("invalid rule ID: rule can not start with a dot"))
}
if !validRuleIDRE.MatchString(rule.RuleID) { if !validRuleIDRE.MatchString(rule.RuleID) {
errs = append(errs, fmt.Errorf("invalid rule ID: %s", rule.RuleID)) errs = append(errs, fmt.Errorf("invalid rule ID: %s", rule.RuleID))
} }

View file

@ -90,7 +90,7 @@ type ClientUserAPI interface {
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error
QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error)
QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
@ -99,7 +99,7 @@ type ClientUserAPI interface {
PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error
PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error
PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error
PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error PerformPushRulesPut(ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets) error
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
@ -555,19 +555,6 @@ const (
HTTPKind PusherKind = "http" HTTPKind PusherKind = "http"
) )
type PerformPushRulesPutRequest struct {
UserID string `json:"user_id"`
RuleSets *pushrules.AccountRuleSets `json:"rule_sets"`
}
type QueryPushRulesRequest struct {
UserID string `json:"user_id"`
}
type QueryPushRulesResponse struct {
RuleSets *pushrules.AccountRuleSets `json:"rule_sets"`
}
type QueryNotificationsRequest struct { type QueryNotificationsRequest struct {
Localpart string `json:"localpart"` // Required. Localpart string `json:"localpart"` // Required.
ServerName gomatrixserverlib.ServerName `json:"server_name"` // Required. ServerName gomatrixserverlib.ServerName `json:"server_name"` // Required.

View file

@ -26,6 +26,7 @@ import (
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -872,36 +873,28 @@ func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPusher
func (a *UserInternalAPI) PerformPushRulesPut( func (a *UserInternalAPI) PerformPushRulesPut(
ctx context.Context, ctx context.Context,
req *api.PerformPushRulesPutRequest, userID string,
_ *struct{}, ruleSets *pushrules.AccountRuleSets,
) error { ) error {
bs, err := json.Marshal(&req.RuleSets) bs, err := json.Marshal(ruleSets)
if err != nil { if err != nil {
return err return err
} }
userReq := api.InputAccountDataRequest{ userReq := api.InputAccountDataRequest{
UserID: req.UserID, UserID: userID,
DataType: pushRulesAccountDataType, DataType: pushRulesAccountDataType,
AccountData: json.RawMessage(bs), AccountData: json.RawMessage(bs),
} }
var userRes api.InputAccountDataResponse // empty var userRes api.InputAccountDataResponse // empty
if err := a.InputAccountData(ctx, &userReq, &userRes); err != nil { return a.InputAccountData(ctx, &userReq, &userRes)
return err
}
return nil
} }
func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error { func (a *UserInternalAPI) QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error) {
localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return fmt.Errorf("failed to split user ID %q for push rules", req.UserID) return nil, fmt.Errorf("failed to split user ID %q for push rules", userID)
} }
pushRules, err := a.DB.QueryPushRules(ctx, localpart, domain) return a.DB.QueryPushRules(ctx, localpart, domain)
if err != nil {
return fmt.Errorf("failed to query push rules: %w", err)
}
res.RuleSets = pushRules
return nil
} }
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error) { func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error) {