Merge branch 'main' into devon/event-refactor

This commit is contained in:
Devon Hudson 2023-04-14 08:29:23 -06:00
commit bd89392052
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
12 changed files with 464 additions and 103 deletions

View file

@ -8,10 +8,12 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -1235,3 +1237,396 @@ func Test3PID(t *testing.T) {
} }
}) })
} }
func TestPushRules(t *testing.T) {
alice := test.NewUser(t)
// create the default push rules, used when validating responses
localpart, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
defaultRules, err := json.Marshal(pushRuleSets)
assert.NoError(t, err)
ruleID1 := "myrule"
ruleID2 := "myrule2"
ruleID3 := "myrule3"
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
cfg.ClientAPI.RateLimiting.Enabled = false
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
natsInstance := jetstream.NATSInstance{}
defer close()
routers := httputil.NewRouters()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
alice: {},
}
createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers)
testCases := []struct {
name string
request *http.Request
wantStatusCode int
validateFunc func(t *testing.T, respBody *bytes.Buffer) // used when updating rules, otherwise wantStatusCode should be enough
queryAttr map[string]string
}{
{
name: "can not get rules without trailing slash",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can get default rules",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/", strings.NewReader("")),
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
assert.Equal(t, defaultRules, respBody.Bytes())
},
},
{
name: "can get rules by scope",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/", strings.NewReader("")),
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
assert.Equal(t, gjson.GetBytes(defaultRules, "global").Raw, respBody.String())
},
},
{
name: "can not get invalid rules by scope",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/doesnotexist/", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not get rules for invalid scope and kind",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/doesnotexist/invalid/", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not get rules for invalid kind",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/invalid/", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can get rules by scope and kind",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/", strings.NewReader("")),
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
assert.Equal(t, gjson.GetBytes(defaultRules, "global.override").Raw, respBody.String())
},
},
{
name: "can get rules by scope and content kind",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader("")),
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
assert.Equal(t, gjson.GetBytes(defaultRules, "global.content").Raw, respBody.String())
},
},
{
name: "can not get rules by scope and room kind",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/room/", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not get rules by scope and sender kind",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/sender/", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can get rules by scope and underride kind",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/underride/", strings.NewReader("")),
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
assert.Equal(t, gjson.GetBytes(defaultRules, "global.underride").Raw, respBody.String())
},
},
{
name: "can not get rules by scope, kind and ID for invalid scope",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/doesnotexist/doesnotexist/.m.rule.master", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not get rules by scope, kind and ID for invalid kind",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/doesnotexist/.m.rule.master", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can get rules by scope, kind and ID",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master", strings.NewReader("")),
wantStatusCode: http.StatusOK,
},
{
name: "can not get rules by scope, kind and ID for invalid ID",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.doesnotexist", strings.NewReader("")),
wantStatusCode: http.StatusNotFound,
},
{
name: "can not get status for invalid attribute",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/invalid", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not get status for invalid kind",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/invalid/.m.rule.master/enabled", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not get enabled status for invalid scope",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/invalid/override/.m.rule.master/enabled", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not get enabled status for invalid rule",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/doesnotexist/enabled", strings.NewReader("")),
wantStatusCode: http.StatusNotFound,
},
{
name: "can get enabled rules by scope, kind and ID",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader("")),
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
assert.False(t, gjson.GetBytes(respBody.Bytes(), "enabled").Bool(), "expected master rule to be disabled")
},
},
{
name: "can get actions scope, kind and ID",
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/actions", strings.NewReader("")),
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
actions := gjson.GetBytes(respBody.Bytes(), "actions").Array()
// only a basic check
assert.Equal(t, 1, len(actions))
},
},
{
name: "can not set enabled status with invalid JSON",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not set attribute for invalid attribute",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/doesnotexist", strings.NewReader("{}")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not set attribute for invalid scope",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/invalid/override/.m.rule.master/enabled", strings.NewReader("{}")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not set attribute for invalid kind",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/invalid/.m.rule.master/enabled", strings.NewReader("{}")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not set attribute for invalid rule",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/invalid/enabled", strings.NewReader("{}")),
wantStatusCode: http.StatusNotFound,
},
{
name: "can set enabled status with valid JSON",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader(`{"enabled":true}`)),
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader(""))
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
routers.Client.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
assert.True(t, gjson.GetBytes(rec.Body.Bytes(), "enabled").Bool(), "expected master rule to be enabled: %s", rec.Body.String())
},
},
{
name: "can set actions with valid JSON",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/actions", strings.NewReader(`{"actions":["dont_notify","notify"]}`)),
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/actions", strings.NewReader(""))
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
routers.Client.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
assert.Equal(t, 2, len(gjson.GetBytes(rec.Body.Bytes(), "actions").Array()), "expected 2 actions %s", rec.Body.String())
},
},
{
name: "can not create new push rule with invalid JSON",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not create new push rule with invalid rule content",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule", strings.NewReader("{}")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not create new push rule with invalid scope",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/invalid/content/myrule", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can create new push rule with valid rule content",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/myrule/actions", strings.NewReader(""))
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
routers.Client.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
assert.Equal(t, 1, len(gjson.GetBytes(rec.Body.Bytes(), "actions").Array()), "expected 1 action %s", rec.Body.String())
},
},
{
name: "can not create new push starting with a dot",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/.myrule", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can create new push rule after existing",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
queryAttr: map[string]string{
"after": ruleID1,
},
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader(""))
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
routers.Client.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
rules := gjson.ParseBytes(rec.Body.Bytes())
for i, rule := range rules.Array() {
if rule.Get("rule_id").Str == ruleID1 && i != 0 {
t.Fatalf("expected '%s' to be the first, but wasn't", ruleID1)
}
if rule.Get("rule_id").Str == ruleID2 && i != 1 {
t.Fatalf("expected '%s' to be the second, but wasn't", ruleID2)
}
}
},
},
{
name: "can create new push rule before existing",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule3", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
queryAttr: map[string]string{
"before": ruleID1,
},
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader(""))
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
routers.Client.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
rules := gjson.ParseBytes(rec.Body.Bytes())
for i, rule := range rules.Array() {
if rule.Get("rule_id").Str == ruleID3 && i != 0 {
t.Fatalf("expected '%s' to be the first, but wasn't", ruleID3)
}
if rule.Get("rule_id").Str == ruleID1 && i != 1 {
t.Fatalf("expected '%s' to be the second, but wasn't", ruleID1)
}
if rule.Get("rule_id").Str == ruleID2 && i != 2 {
t.Fatalf("expected '%s' to be the third, but wasn't", ruleID1)
}
}
},
},
{
name: "can modify existing push rule",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader(`{"actions":["dont_notify"],"pattern":"world"}`)),
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/myrule2/actions", strings.NewReader(""))
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
routers.Client.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
actions := gjson.GetBytes(rec.Body.Bytes(), "actions").Array()
// there should only be one action
assert.Equal(t, "dont_notify", actions[0].Str)
},
},
{
name: "can move existing push rule to the front",
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader(`{"actions":["dont_notify"],"pattern":"world"}`)),
queryAttr: map[string]string{
"before": ruleID3,
},
wantStatusCode: http.StatusOK,
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader(""))
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
routers.Client.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
rules := gjson.ParseBytes(rec.Body.Bytes())
for i, rule := range rules.Array() {
if rule.Get("rule_id").Str == ruleID2 && i != 0 {
t.Fatalf("expected '%s' to be the first, but wasn't", ruleID2)
}
if rule.Get("rule_id").Str == ruleID3 && i != 1 {
t.Fatalf("expected '%s' to be the second, but wasn't", ruleID3)
}
if rule.Get("rule_id").Str == ruleID1 && i != 2 {
t.Fatalf("expected '%s' to be the third, but wasn't", ruleID1)
}
}
},
},
{
name: "can not delete push rule with invalid scope",
request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/invalid/content/myrule2", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not delete push rule with invalid kind",
request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/global/invalid/myrule2", strings.NewReader("")),
wantStatusCode: http.StatusBadRequest,
},
{
name: "can not delete push rule with non-existent rule",
request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/global/content/doesnotexist", strings.NewReader("")),
wantStatusCode: http.StatusNotFound,
},
{
name: "can delete existing push rule",
request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader("")),
wantStatusCode: http.StatusOK,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
rec := httptest.NewRecorder()
if tc.queryAttr != nil {
params := url.Values{}
for k, v := range tc.queryAttr {
params.Set(k, v)
}
tc.request = httptest.NewRequest(tc.request.Method, tc.request.URL.String()+"?"+params.Encode(), tc.request.Body)
}
tc.request.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
routers.Client.ServeHTTP(rec, tc.request)
assert.Equal(t, tc.wantStatusCode, rec.Code, rec.Body.String())
if tc.validateFunc != nil {
tc.validateFunc(t, rec.Body)
}
t.Logf("%s", rec.Body.String())
})
}
})
}

View file

@ -31,7 +31,7 @@ func errorResponse(ctx context.Context, err error, msg string, args ...interface
} }
func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil { if err != nil {
return errorResponse(ctx, err, "queryPushRulesJSON failed") return errorResponse(ctx, err, "queryPushRulesJSON failed")
} }
@ -42,7 +42,7 @@ func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userap
} }
func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil { if err != nil {
return errorResponse(ctx, err, "queryPushRulesJSON failed") return errorResponse(ctx, err, "queryPushRulesJSON failed")
} }
@ -57,7 +57,7 @@ func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Devi
} }
func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil { if err != nil {
return errorResponse(ctx, err, "queryPushRules failed") return errorResponse(ctx, err, "queryPushRules failed")
} }
@ -66,7 +66,8 @@ func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
} }
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
if rulesPtr == nil { // Even if rulesPtr is not nil, there may not be any rules for this kind
if rulesPtr == nil || (rulesPtr != nil && len(*rulesPtr) == 0) {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
} }
return util.JSONResponse{ return util.JSONResponse{
@ -76,7 +77,7 @@ func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi
} }
func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil { if err != nil {
return errorResponse(ctx, err, "queryPushRules failed") return errorResponse(ctx, err, "queryPushRules failed")
} }
@ -101,7 +102,10 @@ func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device
func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, beforeRuleID string, body io.Reader, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, beforeRuleID string, body io.Reader, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
var newRule pushrules.Rule var newRule pushrules.Rule
if err := json.NewDecoder(body).Decode(&newRule); err != nil { if err := json.NewDecoder(body).Decode(&newRule); err != nil {
return errorResponse(ctx, err, "JSON Decode failed") return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(err.Error()),
}
} }
newRule.RuleID = ruleID newRule.RuleID = ruleID
@ -110,7 +114,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID,
return errorResponse(ctx, jsonerror.InvalidArgumentValue(errs[0].Error()), "rule sanity check failed: %v", errs) return errorResponse(ctx, jsonerror.InvalidArgumentValue(errs[0].Error()), "rule sanity check failed: %v", errs)
} }
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil { if err != nil {
return errorResponse(ctx, err, "queryPushRules failed") return errorResponse(ctx, err, "queryPushRules failed")
} }
@ -120,6 +124,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID,
} }
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
if rulesPtr == nil { if rulesPtr == nil {
// while this should be impossible (ValidateRule would already return an error), better keep it around
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
} }
i := pushRuleIndexByID(*rulesPtr, ruleID) i := pushRuleIndexByID(*rulesPtr, ruleID)
@ -144,7 +149,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID,
} }
// Add new rule. // Add new rule.
i, err := findPushRuleInsertionIndex(*rulesPtr, afterRuleID, beforeRuleID) i, err = findPushRuleInsertionIndex(*rulesPtr, afterRuleID, beforeRuleID)
if err != nil { if err != nil {
return errorResponse(ctx, err, "findPushRuleInsertionIndex failed") return errorResponse(ctx, err, "findPushRuleInsertionIndex failed")
} }
@ -153,7 +158,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID,
util.GetLogger(ctx).WithField("after", afterRuleID).WithField("before", beforeRuleID).Infof("Added new push rule at %d", i) util.GetLogger(ctx).WithField("after", afterRuleID).WithField("before", beforeRuleID).Infof("Added new push rule at %d", i)
} }
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil { if err = userAPI.PerformPushRulesPut(ctx, device.UserID, ruleSets); err != nil {
return errorResponse(ctx, err, "putPushRules failed") return errorResponse(ctx, err, "putPushRules failed")
} }
@ -161,7 +166,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID,
} }
func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil { if err != nil {
return errorResponse(ctx, err, "queryPushRules failed") return errorResponse(ctx, err, "queryPushRules failed")
} }
@ -180,7 +185,7 @@ func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, dev
*rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...) *rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...)
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil { if err = userAPI.PerformPushRulesPut(ctx, device.UserID, ruleSets); err != nil {
return errorResponse(ctx, err, "putPushRules failed") return errorResponse(ctx, err, "putPushRules failed")
} }
@ -192,7 +197,7 @@ func GetPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri
if err != nil { if err != nil {
return errorResponse(ctx, err, "pushRuleAttrGetter failed") return errorResponse(ctx, err, "pushRuleAttrGetter failed")
} }
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil { if err != nil {
return errorResponse(ctx, err, "queryPushRules failed") return errorResponse(ctx, err, "queryPushRules failed")
} }
@ -238,7 +243,7 @@ func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri
return errorResponse(ctx, err, "pushRuleAttrSetter failed") return errorResponse(ctx, err, "pushRuleAttrSetter failed")
} }
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil { if err != nil {
return errorResponse(ctx, err, "queryPushRules failed") return errorResponse(ctx, err, "queryPushRules failed")
} }
@ -258,7 +263,7 @@ func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri
if !reflect.DeepEqual(attrGet((*rulesPtr)[i]), attrGet(&newPartialRule)) { if !reflect.DeepEqual(attrGet((*rulesPtr)[i]), attrGet(&newPartialRule)) {
attrSet((*rulesPtr)[i], &newPartialRule) attrSet((*rulesPtr)[i], &newPartialRule)
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil { if err = userAPI.PerformPushRulesPut(ctx, device.UserID, ruleSets); err != nil {
return errorResponse(ctx, err, "putPushRules failed") return errorResponse(ctx, err, "putPushRules failed")
} }
} }
@ -266,28 +271,6 @@ func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}} return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
} }
func queryPushRules(ctx context.Context, userID string, userAPI userapi.ClientUserAPI) (*pushrules.AccountRuleSets, error) {
var res userapi.QueryPushRulesResponse
if err := userAPI.QueryPushRules(ctx, &userapi.QueryPushRulesRequest{UserID: userID}, &res); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.QueryPushRules failed")
return nil, err
}
return res.RuleSets, nil
}
func putPushRules(ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets, userAPI userapi.ClientUserAPI) error {
req := userapi.PerformPushRulesPutRequest{
UserID: userID,
RuleSets: ruleSets,
}
var res struct{}
if err := userAPI.PerformPushRulesPut(ctx, &req, &res); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.PerformPushRulesPut failed")
return err
}
return nil
}
func pushRuleSetByScope(ruleSets *pushrules.AccountRuleSets, scope pushrules.Scope) *pushrules.RuleSet { func pushRuleSetByScope(ruleSets *pushrules.AccountRuleSets, scope pushrules.Scope) *pushrules.RuleSet {
switch scope { switch scope {
case pushrules.GlobalScope: case pushrules.GlobalScope:

View file

@ -14,7 +14,7 @@ GEM
execjs execjs
coffee-script-source (1.11.1) coffee-script-source (1.11.1)
colorator (1.1.0) colorator (1.1.0)
commonmarker (0.23.7) commonmarker (0.23.9)
concurrent-ruby (1.2.0) concurrent-ruby (1.2.0)
dnsruby (1.61.9) dnsruby (1.61.9)
simpleidn (~> 0.1) simpleidn (~> 0.1)
@ -231,9 +231,9 @@ GEM
jekyll-seo-tag (~> 2.1) jekyll-seo-tag (~> 2.1)
minitest (5.17.0) minitest (5.17.0)
multipart-post (2.1.1) multipart-post (2.1.1)
nokogiri (1.13.10-arm64-darwin) nokogiri (1.14.3-arm64-darwin)
racc (~> 1.4) racc (~> 1.4)
nokogiri (1.13.10-x86_64-linux) nokogiri (1.14.3-x86_64-linux)
racc (~> 1.4) racc (~> 1.4)
octokit (4.22.0) octokit (4.22.0)
faraday (>= 0.9) faraday (>= 0.9)
@ -241,7 +241,7 @@ GEM
pathutil (0.16.2) pathutil (0.16.2)
forwardable-extended (~> 2.6) forwardable-extended (~> 2.6)
public_suffix (4.0.7) public_suffix (4.0.7)
racc (1.6.1) racc (1.6.2)
rb-fsevent (0.11.1) rb-fsevent (0.11.1)
rb-inotify (0.10.1) rb-inotify (0.10.1)
ffi (~> 1.0) ffi (~> 1.0)

View file

@ -256,10 +256,10 @@ func (r *FederationInternalAPI) performJoinUsingServer(
// waste the effort. // waste the effort.
// TODO: Can we expand Check here to return a list of missing auth // TODO: Can we expand Check here to return a list of missing auth
// events rather than failing one at a time? // events rather than failing one at a time?
var respState *fclient.RespState var respState gomatrixserverlib.StateResponse
respState, err = respSendJoin.Check( respState, err = gomatrixserverlib.CheckSendJoinResponse(
context.Background(), context.Background(),
respMakeJoin.RoomVersion, respMakeJoin.RoomVersion, &respSendJoin,
r.keyRing, r.keyRing,
event, event,
federatedAuthProvider(ctx, r.federation, r.keyRing, origin, serverName), federatedAuthProvider(ctx, r.federation, r.keyRing, origin, serverName),
@ -274,7 +274,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
// joining a room, waiting for 200 OK then changing device keys and have those keys not be sent // joining a room, waiting for 200 OK then changing device keys and have those keys not be sent
// to other servers (this was a cause of a flakey sytest "Local device key changes get to remote servers") // to other servers (this was a cause of a flakey sytest "Local device key changes get to remote servers")
// The events are trusted now as we performed auth checks above. // The events are trusted now as we performed auth checks above.
joinedHosts, err := consumers.JoinedHostsFromEvents(respState.StateEvents.TrustedEvents(respMakeJoin.RoomVersion, false)) joinedHosts, err := consumers.JoinedHostsFromEvents(respState.GetStateEvents().TrustedEvents(respMakeJoin.RoomVersion, false))
if err != nil { if err != nil {
return fmt.Errorf("JoinedHostsFromEvents: failed to get joined hosts: %s", err) return fmt.Errorf("JoinedHostsFromEvents: failed to get joined hosts: %s", err)
} }
@ -469,11 +469,12 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
// we have the peek state now so let's process regardless of whether upstream gives up // we have the peek state now so let's process regardless of whether upstream gives up
ctx = context.Background() ctx = context.Background()
respState := respPeek.ToRespState()
// authenticate the state returned (check its auth events etc) // authenticate the state returned (check its auth events etc)
// the equivalent of CheckSendJoinResponse() // the equivalent of CheckSendJoinResponse()
authEvents, _, err := respState.Check(ctx, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName)) authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(
ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName),
)
if err != nil { if err != nil {
return fmt.Errorf("error checking state returned from peeking: %w", err) return fmt.Errorf("error checking state returned from peeking: %w", err)
} }
@ -497,7 +498,11 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
if err = roomserverAPI.SendEventWithState( if err = roomserverAPI.SendEventWithState(
ctx, r.rsAPI, r.cfg.Matrix.ServerName, ctx, r.rsAPI, r.cfg.Matrix.ServerName,
roomserverAPI.KindNew, roomserverAPI.KindNew,
&respState, // use the authorized state from CheckStateResponse
&fclient.RespState{
StateEvents: gomatrixserverlib.NewEventJSONsFromEvents(stateEvents),
AuthEvents: gomatrixserverlib.NewEventJSONsFromEvents(authEvents),
},
respPeek.LatestEvent.Headered(respPeek.RoomVersion), respPeek.LatestEvent.Headered(respPeek.RoomVersion),
serverName, serverName,
nil, nil,

2
go.mod
View file

@ -22,7 +22,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20230405171344-5f597d85ba4f github.com/matrix-org/gomatrixserverlib v0.0.0-20230414110520-bcb28309fbcf
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
github.com/mattn/go-sqlite3 v1.14.16 github.com/mattn/go-sqlite3 v1.14.16

6
go.sum
View file

@ -325,6 +325,12 @@ github.com/matrix-org/gomatrixserverlib v0.0.0-20230320105331-4dd7ff2f0e3a h1:F6
github.com/matrix-org/gomatrixserverlib v0.0.0-20230320105331-4dd7ff2f0e3a/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU= github.com/matrix-org/gomatrixserverlib v0.0.0-20230320105331-4dd7ff2f0e3a/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230405171344-5f597d85ba4f h1:D7IgZA2DxBroqCTxo2uXEmjj8eCI1OzqqKRE4SAgmBU= github.com/matrix-org/gomatrixserverlib v0.0.0-20230405171344-5f597d85ba4f h1:D7IgZA2DxBroqCTxo2uXEmjj8eCI1OzqqKRE4SAgmBU=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230405171344-5f597d85ba4f/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU= github.com/matrix-org/gomatrixserverlib v0.0.0-20230405171344-5f597d85ba4f/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230406115722-88763372cbd9 h1:bnIeD+u46GkPzfBcH9Su4sUBRZ9jqHcKER8hmaNnGD8=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230406115722-88763372cbd9/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230406130834-e85c3bfeed05 h1:s4JL2PPTgB6pR1ouDwq6KpqmV5hUTAP+HB1ajlemHWA=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230406130834-e85c3bfeed05/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230414110520-bcb28309fbcf h1:bfjLsvf1M7IFJuwmFJAsgUi4XaaHs07e6tub0tb3Fpw=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230414110520-bcb28309fbcf/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=

View file

@ -10,6 +10,10 @@ import (
func ValidateRule(kind Kind, rule *Rule) []error { func ValidateRule(kind Kind, rule *Rule) []error {
var errs []error var errs []error
if len(rule.RuleID) > 0 && rule.RuleID[:1] == "." {
errs = append(errs, fmt.Errorf("invalid rule ID: rule can not start with a dot"))
}
if !validRuleIDRE.MatchString(rule.RuleID) { if !validRuleIDRE.MatchString(rule.RuleID) {
errs = append(errs, fmt.Errorf("invalid rule ID: %s", rule.RuleID)) errs = append(errs, fmt.Errorf("invalid rule ID: %s", rule.RuleID))
} }

View file

@ -50,10 +50,10 @@ func SendEvents(
func SendEventWithState( func SendEventWithState(
ctx context.Context, rsAPI InputRoomEventsAPI, ctx context.Context, rsAPI InputRoomEventsAPI,
virtualHost gomatrixserverlib.ServerName, kind Kind, virtualHost gomatrixserverlib.ServerName, kind Kind,
state *fclient.RespState, event *gomatrixserverlib.HeaderedEvent, state gomatrixserverlib.StateResponse, event *gomatrixserverlib.HeaderedEvent,
origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool, origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool,
) error { ) error {
outliers := state.Events(event.RoomVersion) outliers := gomatrixserverlib.LineariseStateResponse(event.RoomVersion, state)
ires := make([]InputRoomEvent, 0, len(outliers)) ires := make([]InputRoomEvent, 0, len(outliers))
for _, outlier := range outliers { for _, outlier := range outliers {
if haveEventIDs[outlier.EventID()] { if haveEventIDs[outlier.EventID()] {
@ -66,7 +66,7 @@ func SendEventWithState(
}) })
} }
stateEvents := state.StateEvents.UntrustedEvents(event.RoomVersion) stateEvents := state.GetStateEvents().UntrustedEvents(event.RoomVersion)
stateEventIDs := make([]string, len(stateEvents)) stateEventIDs := make([]string, len(stateEvents))
for i := range stateEvents { for i := range stateEvents {
stateEventIDs[i] = stateEvents[i].EventID() stateEventIDs[i] = stateEvents[i].EventID()

View file

@ -641,31 +641,19 @@ func (t *missingStateReq) lookupMissingStateViaState(
if err != nil { if err != nil {
return nil, err return nil, err
} }
s := fclient.RespState{
// Check that the returned state is valid.
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{
StateEvents: state.GetStateEvents(), StateEvents: state.GetStateEvents(),
AuthEvents: state.GetAuthEvents(), AuthEvents: state.GetAuthEvents(),
} }, roomVersion, t.keys, nil)
// Check that the returned state is valid.
authEvents, stateEvents, err := s.Check(ctx, roomVersion, t.keys, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
parsedState := &parsedRespState{ return &parsedRespState{
AuthEvents: authEvents, AuthEvents: authEvents,
StateEvents: stateEvents, StateEvents: stateEvents,
} }, nil
// Cache the results of this state lookup and deduplicate anything we already
// have in the cache, freeing up memory.
// We load these as trusted as we called state.Check before which loaded them as untrusted.
for i, evJSON := range s.AuthEvents {
ev, _ := gomatrixserverlib.NewEventFromTrustedJSON(evJSON, false, roomVersion)
parsedState.AuthEvents[i] = t.cacheAndReturn(ev)
}
for i, evJSON := range s.StateEvents {
ev, _ := gomatrixserverlib.NewEventFromTrustedJSON(evJSON, false, roomVersion)
parsedState.StateEvents[i] = t.cacheAndReturn(ev)
}
return parsedState, nil
} }
func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) (

View file

@ -654,11 +654,11 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo
messageEvents = append(messageEvents, ev) messageEvents = append(messageEvents, ev)
} }
} }
respState := fclient.RespState{ respState := &fclient.RespState{
AuthEvents: res.AuthChain, AuthEvents: res.AuthChain,
StateEvents: stateEvents, StateEvents: stateEvents,
} }
eventsInOrder := respState.Events(rc.roomVersion) eventsInOrder := gomatrixserverlib.LineariseStateResponse(rc.roomVersion, respState)
// everything gets sent as an outlier because auth chain events may be disjoint from the DAG // everything gets sent as an outlier because auth chain events may be disjoint from the DAG
// as may the threaded events. // as may the threaded events.
var ires []roomserver.InputRoomEvent var ires []roomserver.InputRoomEvent

View file

@ -90,7 +90,7 @@ type ClientUserAPI interface {
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error
QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error)
QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
@ -99,7 +99,7 @@ type ClientUserAPI interface {
PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error
PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error
PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error
PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error PerformPushRulesPut(ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets) error
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
@ -555,19 +555,6 @@ const (
HTTPKind PusherKind = "http" HTTPKind PusherKind = "http"
) )
type PerformPushRulesPutRequest struct {
UserID string `json:"user_id"`
RuleSets *pushrules.AccountRuleSets `json:"rule_sets"`
}
type QueryPushRulesRequest struct {
UserID string `json:"user_id"`
}
type QueryPushRulesResponse struct {
RuleSets *pushrules.AccountRuleSets `json:"rule_sets"`
}
type QueryNotificationsRequest struct { type QueryNotificationsRequest struct {
Localpart string `json:"localpart"` // Required. Localpart string `json:"localpart"` // Required.
ServerName gomatrixserverlib.ServerName `json:"server_name"` // Required. ServerName gomatrixserverlib.ServerName `json:"server_name"` // Required.

View file

@ -26,6 +26,7 @@ import (
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -872,36 +873,28 @@ func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPusher
func (a *UserInternalAPI) PerformPushRulesPut( func (a *UserInternalAPI) PerformPushRulesPut(
ctx context.Context, ctx context.Context,
req *api.PerformPushRulesPutRequest, userID string,
_ *struct{}, ruleSets *pushrules.AccountRuleSets,
) error { ) error {
bs, err := json.Marshal(&req.RuleSets) bs, err := json.Marshal(ruleSets)
if err != nil { if err != nil {
return err return err
} }
userReq := api.InputAccountDataRequest{ userReq := api.InputAccountDataRequest{
UserID: req.UserID, UserID: userID,
DataType: pushRulesAccountDataType, DataType: pushRulesAccountDataType,
AccountData: json.RawMessage(bs), AccountData: json.RawMessage(bs),
} }
var userRes api.InputAccountDataResponse // empty var userRes api.InputAccountDataResponse // empty
if err := a.InputAccountData(ctx, &userReq, &userRes); err != nil { return a.InputAccountData(ctx, &userReq, &userRes)
return err
}
return nil
} }
func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error { func (a *UserInternalAPI) QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error) {
localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return fmt.Errorf("failed to split user ID %q for push rules", req.UserID) return nil, fmt.Errorf("failed to split user ID %q for push rules", userID)
} }
pushRules, err := a.DB.QueryPushRules(ctx, localpart, domain) return a.DB.QueryPushRules(ctx, localpart, domain)
if err != nil {
return fmt.Errorf("failed to query push rules: %w", err)
}
res.RuleSets = pushRules
return nil
} }
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error) { func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error) {