mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-21 04:53:14 -06:00
Merge branch 'main' into devon/event-refactor
This commit is contained in:
commit
bd89392052
|
|
@ -8,10 +8,12 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||||
"github.com/matrix-org/gomatrix"
|
"github.com/matrix-org/gomatrix"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
|
|
@ -1235,3 +1237,396 @@ func Test3PID(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPushRules(t *testing.T) {
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
|
||||||
|
// create the default push rules, used when validating responses
|
||||||
|
localpart, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
|
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
|
||||||
|
defaultRules, err := json.Marshal(pushRuleSets)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
ruleID1 := "myrule"
|
||||||
|
ruleID2 := "myrule2"
|
||||||
|
ruleID3 := "myrule3"
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||||
|
cfg.ClientAPI.RateLimiting.Enabled = false
|
||||||
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||||
|
natsInstance := jetstream.NATSInstance{}
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
routers := httputil.NewRouters()
|
||||||
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||||
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||||
|
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
|
||||||
|
|
||||||
|
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
|
||||||
|
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||||
|
|
||||||
|
accessTokens := map[*test.User]userDevice{
|
||||||
|
alice: {},
|
||||||
|
}
|
||||||
|
createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
request *http.Request
|
||||||
|
wantStatusCode int
|
||||||
|
validateFunc func(t *testing.T, respBody *bytes.Buffer) // used when updating rules, otherwise wantStatusCode should be enough
|
||||||
|
queryAttr map[string]string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "can not get rules without trailing slash",
|
||||||
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can get default rules",
|
||||||
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusOK,
|
||||||
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
||||||
|
assert.Equal(t, defaultRules, respBody.Bytes())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can get rules by scope",
|
||||||
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusOK,
|
||||||
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
||||||
|
assert.Equal(t, gjson.GetBytes(defaultRules, "global").Raw, respBody.String())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not get invalid rules by scope",
|
||||||
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/doesnotexist/", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not get rules for invalid scope and kind",
|
||||||
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/doesnotexist/invalid/", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not get rules for invalid kind",
|
||||||
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/invalid/", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can get rules by scope and kind",
|
||||||
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusOK,
|
||||||
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
||||||
|
assert.Equal(t, gjson.GetBytes(defaultRules, "global.override").Raw, respBody.String())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can get rules by scope and content kind",
|
||||||
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusOK,
|
||||||
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
||||||
|
assert.Equal(t, gjson.GetBytes(defaultRules, "global.content").Raw, respBody.String())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not get rules by scope and room kind",
|
||||||
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/room/", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not get rules by scope and sender kind",
|
||||||
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/sender/", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can get rules by scope and underride kind",
|
||||||
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/underride/", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusOK,
|
||||||
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
||||||
|
assert.Equal(t, gjson.GetBytes(defaultRules, "global.underride").Raw, respBody.String())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not get rules by scope, kind and ID for invalid scope",
|
||||||
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/doesnotexist/doesnotexist/.m.rule.master", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not get rules by scope, kind and ID for invalid kind",
|
||||||
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/doesnotexist/.m.rule.master", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can get rules by scope, kind and ID",
|
||||||
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not get rules by scope, kind and ID for invalid ID",
|
||||||
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.doesnotexist", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusNotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not get status for invalid attribute",
|
||||||
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/invalid", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not get status for invalid kind",
|
||||||
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/invalid/.m.rule.master/enabled", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not get enabled status for invalid scope",
|
||||||
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/invalid/override/.m.rule.master/enabled", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not get enabled status for invalid rule",
|
||||||
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/doesnotexist/enabled", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusNotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can get enabled rules by scope, kind and ID",
|
||||||
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusOK,
|
||||||
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
||||||
|
assert.False(t, gjson.GetBytes(respBody.Bytes(), "enabled").Bool(), "expected master rule to be disabled")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can get actions scope, kind and ID",
|
||||||
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/actions", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusOK,
|
||||||
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
||||||
|
actions := gjson.GetBytes(respBody.Bytes(), "actions").Array()
|
||||||
|
// only a basic check
|
||||||
|
assert.Equal(t, 1, len(actions))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not set enabled status with invalid JSON",
|
||||||
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not set attribute for invalid attribute",
|
||||||
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/doesnotexist", strings.NewReader("{}")),
|
||||||
|
wantStatusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not set attribute for invalid scope",
|
||||||
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/invalid/override/.m.rule.master/enabled", strings.NewReader("{}")),
|
||||||
|
wantStatusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not set attribute for invalid kind",
|
||||||
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/invalid/.m.rule.master/enabled", strings.NewReader("{}")),
|
||||||
|
wantStatusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not set attribute for invalid rule",
|
||||||
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/invalid/enabled", strings.NewReader("{}")),
|
||||||
|
wantStatusCode: http.StatusNotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can set enabled status with valid JSON",
|
||||||
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader(`{"enabled":true}`)),
|
||||||
|
wantStatusCode: http.StatusOK,
|
||||||
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader(""))
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
||||||
|
routers.Client.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
assert.True(t, gjson.GetBytes(rec.Body.Bytes(), "enabled").Bool(), "expected master rule to be enabled: %s", rec.Body.String())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can set actions with valid JSON",
|
||||||
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/actions", strings.NewReader(`{"actions":["dont_notify","notify"]}`)),
|
||||||
|
wantStatusCode: http.StatusOK,
|
||||||
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/actions", strings.NewReader(""))
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
||||||
|
routers.Client.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
assert.Equal(t, 2, len(gjson.GetBytes(rec.Body.Bytes(), "actions").Array()), "expected 2 actions %s", rec.Body.String())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not create new push rule with invalid JSON",
|
||||||
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not create new push rule with invalid rule content",
|
||||||
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule", strings.NewReader("{}")),
|
||||||
|
wantStatusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not create new push rule with invalid scope",
|
||||||
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/invalid/content/myrule", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
|
||||||
|
wantStatusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can create new push rule with valid rule content",
|
||||||
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
|
||||||
|
wantStatusCode: http.StatusOK,
|
||||||
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/myrule/actions", strings.NewReader(""))
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
||||||
|
routers.Client.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
assert.Equal(t, 1, len(gjson.GetBytes(rec.Body.Bytes(), "actions").Array()), "expected 1 action %s", rec.Body.String())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not create new push starting with a dot",
|
||||||
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/.myrule", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
|
||||||
|
wantStatusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can create new push rule after existing",
|
||||||
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
|
||||||
|
queryAttr: map[string]string{
|
||||||
|
"after": ruleID1,
|
||||||
|
},
|
||||||
|
wantStatusCode: http.StatusOK,
|
||||||
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader(""))
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
||||||
|
routers.Client.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
rules := gjson.ParseBytes(rec.Body.Bytes())
|
||||||
|
for i, rule := range rules.Array() {
|
||||||
|
if rule.Get("rule_id").Str == ruleID1 && i != 0 {
|
||||||
|
t.Fatalf("expected '%s' to be the first, but wasn't", ruleID1)
|
||||||
|
}
|
||||||
|
if rule.Get("rule_id").Str == ruleID2 && i != 1 {
|
||||||
|
t.Fatalf("expected '%s' to be the second, but wasn't", ruleID2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can create new push rule before existing",
|
||||||
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule3", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
|
||||||
|
queryAttr: map[string]string{
|
||||||
|
"before": ruleID1,
|
||||||
|
},
|
||||||
|
wantStatusCode: http.StatusOK,
|
||||||
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader(""))
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
||||||
|
routers.Client.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
rules := gjson.ParseBytes(rec.Body.Bytes())
|
||||||
|
for i, rule := range rules.Array() {
|
||||||
|
if rule.Get("rule_id").Str == ruleID3 && i != 0 {
|
||||||
|
t.Fatalf("expected '%s' to be the first, but wasn't", ruleID3)
|
||||||
|
}
|
||||||
|
if rule.Get("rule_id").Str == ruleID1 && i != 1 {
|
||||||
|
t.Fatalf("expected '%s' to be the second, but wasn't", ruleID1)
|
||||||
|
}
|
||||||
|
if rule.Get("rule_id").Str == ruleID2 && i != 2 {
|
||||||
|
t.Fatalf("expected '%s' to be the third, but wasn't", ruleID1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can modify existing push rule",
|
||||||
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader(`{"actions":["dont_notify"],"pattern":"world"}`)),
|
||||||
|
wantStatusCode: http.StatusOK,
|
||||||
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/myrule2/actions", strings.NewReader(""))
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
||||||
|
routers.Client.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
actions := gjson.GetBytes(rec.Body.Bytes(), "actions").Array()
|
||||||
|
// there should only be one action
|
||||||
|
assert.Equal(t, "dont_notify", actions[0].Str)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can move existing push rule to the front",
|
||||||
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader(`{"actions":["dont_notify"],"pattern":"world"}`)),
|
||||||
|
queryAttr: map[string]string{
|
||||||
|
"before": ruleID3,
|
||||||
|
},
|
||||||
|
wantStatusCode: http.StatusOK,
|
||||||
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader(""))
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
||||||
|
routers.Client.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
rules := gjson.ParseBytes(rec.Body.Bytes())
|
||||||
|
for i, rule := range rules.Array() {
|
||||||
|
if rule.Get("rule_id").Str == ruleID2 && i != 0 {
|
||||||
|
t.Fatalf("expected '%s' to be the first, but wasn't", ruleID2)
|
||||||
|
}
|
||||||
|
if rule.Get("rule_id").Str == ruleID3 && i != 1 {
|
||||||
|
t.Fatalf("expected '%s' to be the second, but wasn't", ruleID3)
|
||||||
|
}
|
||||||
|
if rule.Get("rule_id").Str == ruleID1 && i != 2 {
|
||||||
|
t.Fatalf("expected '%s' to be the third, but wasn't", ruleID1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not delete push rule with invalid scope",
|
||||||
|
request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/invalid/content/myrule2", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not delete push rule with invalid kind",
|
||||||
|
request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/global/invalid/myrule2", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not delete push rule with non-existent rule",
|
||||||
|
request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/global/content/doesnotexist", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusNotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can delete existing push rule",
|
||||||
|
request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader("")),
|
||||||
|
wantStatusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
if tc.queryAttr != nil {
|
||||||
|
params := url.Values{}
|
||||||
|
for k, v := range tc.queryAttr {
|
||||||
|
params.Set(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
tc.request = httptest.NewRequest(tc.request.Method, tc.request.URL.String()+"?"+params.Encode(), tc.request.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
tc.request.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
||||||
|
|
||||||
|
routers.Client.ServeHTTP(rec, tc.request)
|
||||||
|
assert.Equal(t, tc.wantStatusCode, rec.Code, rec.Body.String())
|
||||||
|
if tc.validateFunc != nil {
|
||||||
|
tc.validateFunc(t, rec.Body)
|
||||||
|
}
|
||||||
|
t.Logf("%s", rec.Body.String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ func errorResponse(ctx context.Context, err error, msg string, args ...interface
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
||||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorResponse(ctx, err, "queryPushRulesJSON failed")
|
return errorResponse(ctx, err, "queryPushRulesJSON failed")
|
||||||
}
|
}
|
||||||
|
|
@ -42,7 +42,7 @@ func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userap
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
||||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorResponse(ctx, err, "queryPushRulesJSON failed")
|
return errorResponse(ctx, err, "queryPushRulesJSON failed")
|
||||||
}
|
}
|
||||||
|
|
@ -57,7 +57,7 @@ func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Devi
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
||||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorResponse(ctx, err, "queryPushRules failed")
|
return errorResponse(ctx, err, "queryPushRules failed")
|
||||||
}
|
}
|
||||||
|
|
@ -66,7 +66,8 @@ func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi
|
||||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
|
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
|
||||||
}
|
}
|
||||||
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
|
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
|
||||||
if rulesPtr == nil {
|
// Even if rulesPtr is not nil, there may not be any rules for this kind
|
||||||
|
if rulesPtr == nil || (rulesPtr != nil && len(*rulesPtr) == 0) {
|
||||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
|
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
|
||||||
}
|
}
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
|
@ -76,7 +77,7 @@ func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
||||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorResponse(ctx, err, "queryPushRules failed")
|
return errorResponse(ctx, err, "queryPushRules failed")
|
||||||
}
|
}
|
||||||
|
|
@ -101,7 +102,10 @@ func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device
|
||||||
func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, beforeRuleID string, body io.Reader, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, beforeRuleID string, body io.Reader, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
||||||
var newRule pushrules.Rule
|
var newRule pushrules.Rule
|
||||||
if err := json.NewDecoder(body).Decode(&newRule); err != nil {
|
if err := json.NewDecoder(body).Decode(&newRule); err != nil {
|
||||||
return errorResponse(ctx, err, "JSON Decode failed")
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.BadJSON(err.Error()),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
newRule.RuleID = ruleID
|
newRule.RuleID = ruleID
|
||||||
|
|
||||||
|
|
@ -110,7 +114,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID,
|
||||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue(errs[0].Error()), "rule sanity check failed: %v", errs)
|
return errorResponse(ctx, jsonerror.InvalidArgumentValue(errs[0].Error()), "rule sanity check failed: %v", errs)
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorResponse(ctx, err, "queryPushRules failed")
|
return errorResponse(ctx, err, "queryPushRules failed")
|
||||||
}
|
}
|
||||||
|
|
@ -120,6 +124,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID,
|
||||||
}
|
}
|
||||||
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
|
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
|
||||||
if rulesPtr == nil {
|
if rulesPtr == nil {
|
||||||
|
// while this should be impossible (ValidateRule would already return an error), better keep it around
|
||||||
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
|
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
|
||||||
}
|
}
|
||||||
i := pushRuleIndexByID(*rulesPtr, ruleID)
|
i := pushRuleIndexByID(*rulesPtr, ruleID)
|
||||||
|
|
@ -144,7 +149,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add new rule.
|
// Add new rule.
|
||||||
i, err := findPushRuleInsertionIndex(*rulesPtr, afterRuleID, beforeRuleID)
|
i, err = findPushRuleInsertionIndex(*rulesPtr, afterRuleID, beforeRuleID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorResponse(ctx, err, "findPushRuleInsertionIndex failed")
|
return errorResponse(ctx, err, "findPushRuleInsertionIndex failed")
|
||||||
}
|
}
|
||||||
|
|
@ -153,7 +158,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID,
|
||||||
util.GetLogger(ctx).WithField("after", afterRuleID).WithField("before", beforeRuleID).Infof("Added new push rule at %d", i)
|
util.GetLogger(ctx).WithField("after", afterRuleID).WithField("before", beforeRuleID).Infof("Added new push rule at %d", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
|
if err = userAPI.PerformPushRulesPut(ctx, device.UserID, ruleSets); err != nil {
|
||||||
return errorResponse(ctx, err, "putPushRules failed")
|
return errorResponse(ctx, err, "putPushRules failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -161,7 +166,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID,
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
||||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorResponse(ctx, err, "queryPushRules failed")
|
return errorResponse(ctx, err, "queryPushRules failed")
|
||||||
}
|
}
|
||||||
|
|
@ -180,7 +185,7 @@ func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, dev
|
||||||
|
|
||||||
*rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...)
|
*rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...)
|
||||||
|
|
||||||
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
|
if err = userAPI.PerformPushRulesPut(ctx, device.UserID, ruleSets); err != nil {
|
||||||
return errorResponse(ctx, err, "putPushRules failed")
|
return errorResponse(ctx, err, "putPushRules failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -192,7 +197,7 @@ func GetPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorResponse(ctx, err, "pushRuleAttrGetter failed")
|
return errorResponse(ctx, err, "pushRuleAttrGetter failed")
|
||||||
}
|
}
|
||||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorResponse(ctx, err, "queryPushRules failed")
|
return errorResponse(ctx, err, "queryPushRules failed")
|
||||||
}
|
}
|
||||||
|
|
@ -238,7 +243,7 @@ func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri
|
||||||
return errorResponse(ctx, err, "pushRuleAttrSetter failed")
|
return errorResponse(ctx, err, "pushRuleAttrSetter failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
|
ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorResponse(ctx, err, "queryPushRules failed")
|
return errorResponse(ctx, err, "queryPushRules failed")
|
||||||
}
|
}
|
||||||
|
|
@ -258,7 +263,7 @@ func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri
|
||||||
if !reflect.DeepEqual(attrGet((*rulesPtr)[i]), attrGet(&newPartialRule)) {
|
if !reflect.DeepEqual(attrGet((*rulesPtr)[i]), attrGet(&newPartialRule)) {
|
||||||
attrSet((*rulesPtr)[i], &newPartialRule)
|
attrSet((*rulesPtr)[i], &newPartialRule)
|
||||||
|
|
||||||
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
|
if err = userAPI.PerformPushRulesPut(ctx, device.UserID, ruleSets); err != nil {
|
||||||
return errorResponse(ctx, err, "putPushRules failed")
|
return errorResponse(ctx, err, "putPushRules failed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -266,28 +271,6 @@ func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
|
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
|
||||||
}
|
}
|
||||||
|
|
||||||
func queryPushRules(ctx context.Context, userID string, userAPI userapi.ClientUserAPI) (*pushrules.AccountRuleSets, error) {
|
|
||||||
var res userapi.QueryPushRulesResponse
|
|
||||||
if err := userAPI.QueryPushRules(ctx, &userapi.QueryPushRulesRequest{UserID: userID}, &res); err != nil {
|
|
||||||
util.GetLogger(ctx).WithError(err).Error("userAPI.QueryPushRules failed")
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return res.RuleSets, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func putPushRules(ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets, userAPI userapi.ClientUserAPI) error {
|
|
||||||
req := userapi.PerformPushRulesPutRequest{
|
|
||||||
UserID: userID,
|
|
||||||
RuleSets: ruleSets,
|
|
||||||
}
|
|
||||||
var res struct{}
|
|
||||||
if err := userAPI.PerformPushRulesPut(ctx, &req, &res); err != nil {
|
|
||||||
util.GetLogger(ctx).WithError(err).Error("userAPI.PerformPushRulesPut failed")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func pushRuleSetByScope(ruleSets *pushrules.AccountRuleSets, scope pushrules.Scope) *pushrules.RuleSet {
|
func pushRuleSetByScope(ruleSets *pushrules.AccountRuleSets, scope pushrules.Scope) *pushrules.RuleSet {
|
||||||
switch scope {
|
switch scope {
|
||||||
case pushrules.GlobalScope:
|
case pushrules.GlobalScope:
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ GEM
|
||||||
execjs
|
execjs
|
||||||
coffee-script-source (1.11.1)
|
coffee-script-source (1.11.1)
|
||||||
colorator (1.1.0)
|
colorator (1.1.0)
|
||||||
commonmarker (0.23.7)
|
commonmarker (0.23.9)
|
||||||
concurrent-ruby (1.2.0)
|
concurrent-ruby (1.2.0)
|
||||||
dnsruby (1.61.9)
|
dnsruby (1.61.9)
|
||||||
simpleidn (~> 0.1)
|
simpleidn (~> 0.1)
|
||||||
|
|
@ -231,9 +231,9 @@ GEM
|
||||||
jekyll-seo-tag (~> 2.1)
|
jekyll-seo-tag (~> 2.1)
|
||||||
minitest (5.17.0)
|
minitest (5.17.0)
|
||||||
multipart-post (2.1.1)
|
multipart-post (2.1.1)
|
||||||
nokogiri (1.13.10-arm64-darwin)
|
nokogiri (1.14.3-arm64-darwin)
|
||||||
racc (~> 1.4)
|
racc (~> 1.4)
|
||||||
nokogiri (1.13.10-x86_64-linux)
|
nokogiri (1.14.3-x86_64-linux)
|
||||||
racc (~> 1.4)
|
racc (~> 1.4)
|
||||||
octokit (4.22.0)
|
octokit (4.22.0)
|
||||||
faraday (>= 0.9)
|
faraday (>= 0.9)
|
||||||
|
|
@ -241,7 +241,7 @@ GEM
|
||||||
pathutil (0.16.2)
|
pathutil (0.16.2)
|
||||||
forwardable-extended (~> 2.6)
|
forwardable-extended (~> 2.6)
|
||||||
public_suffix (4.0.7)
|
public_suffix (4.0.7)
|
||||||
racc (1.6.1)
|
racc (1.6.2)
|
||||||
rb-fsevent (0.11.1)
|
rb-fsevent (0.11.1)
|
||||||
rb-inotify (0.10.1)
|
rb-inotify (0.10.1)
|
||||||
ffi (~> 1.0)
|
ffi (~> 1.0)
|
||||||
|
|
|
||||||
|
|
@ -256,10 +256,10 @@ func (r *FederationInternalAPI) performJoinUsingServer(
|
||||||
// waste the effort.
|
// waste the effort.
|
||||||
// TODO: Can we expand Check here to return a list of missing auth
|
// TODO: Can we expand Check here to return a list of missing auth
|
||||||
// events rather than failing one at a time?
|
// events rather than failing one at a time?
|
||||||
var respState *fclient.RespState
|
var respState gomatrixserverlib.StateResponse
|
||||||
respState, err = respSendJoin.Check(
|
respState, err = gomatrixserverlib.CheckSendJoinResponse(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
respMakeJoin.RoomVersion,
|
respMakeJoin.RoomVersion, &respSendJoin,
|
||||||
r.keyRing,
|
r.keyRing,
|
||||||
event,
|
event,
|
||||||
federatedAuthProvider(ctx, r.federation, r.keyRing, origin, serverName),
|
federatedAuthProvider(ctx, r.federation, r.keyRing, origin, serverName),
|
||||||
|
|
@ -274,7 +274,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
|
||||||
// joining a room, waiting for 200 OK then changing device keys and have those keys not be sent
|
// joining a room, waiting for 200 OK then changing device keys and have those keys not be sent
|
||||||
// to other servers (this was a cause of a flakey sytest "Local device key changes get to remote servers")
|
// to other servers (this was a cause of a flakey sytest "Local device key changes get to remote servers")
|
||||||
// The events are trusted now as we performed auth checks above.
|
// The events are trusted now as we performed auth checks above.
|
||||||
joinedHosts, err := consumers.JoinedHostsFromEvents(respState.StateEvents.TrustedEvents(respMakeJoin.RoomVersion, false))
|
joinedHosts, err := consumers.JoinedHostsFromEvents(respState.GetStateEvents().TrustedEvents(respMakeJoin.RoomVersion, false))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("JoinedHostsFromEvents: failed to get joined hosts: %s", err)
|
return fmt.Errorf("JoinedHostsFromEvents: failed to get joined hosts: %s", err)
|
||||||
}
|
}
|
||||||
|
|
@ -469,11 +469,12 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
|
||||||
|
|
||||||
// we have the peek state now so let's process regardless of whether upstream gives up
|
// we have the peek state now so let's process regardless of whether upstream gives up
|
||||||
ctx = context.Background()
|
ctx = context.Background()
|
||||||
respState := respPeek.ToRespState()
|
|
||||||
|
|
||||||
// authenticate the state returned (check its auth events etc)
|
// authenticate the state returned (check its auth events etc)
|
||||||
// the equivalent of CheckSendJoinResponse()
|
// the equivalent of CheckSendJoinResponse()
|
||||||
authEvents, _, err := respState.Check(ctx, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName))
|
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(
|
||||||
|
ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error checking state returned from peeking: %w", err)
|
return fmt.Errorf("error checking state returned from peeking: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -497,7 +498,11 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
|
||||||
if err = roomserverAPI.SendEventWithState(
|
if err = roomserverAPI.SendEventWithState(
|
||||||
ctx, r.rsAPI, r.cfg.Matrix.ServerName,
|
ctx, r.rsAPI, r.cfg.Matrix.ServerName,
|
||||||
roomserverAPI.KindNew,
|
roomserverAPI.KindNew,
|
||||||
&respState,
|
// use the authorized state from CheckStateResponse
|
||||||
|
&fclient.RespState{
|
||||||
|
StateEvents: gomatrixserverlib.NewEventJSONsFromEvents(stateEvents),
|
||||||
|
AuthEvents: gomatrixserverlib.NewEventJSONsFromEvents(authEvents),
|
||||||
|
},
|
||||||
respPeek.LatestEvent.Headered(respPeek.RoomVersion),
|
respPeek.LatestEvent.Headered(respPeek.RoomVersion),
|
||||||
serverName,
|
serverName,
|
||||||
nil,
|
nil,
|
||||||
|
|
|
||||||
2
go.mod
2
go.mod
|
|
@ -22,7 +22,7 @@ require (
|
||||||
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
|
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
|
||||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
|
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
|
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230405171344-5f597d85ba4f
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20230414110520-bcb28309fbcf
|
||||||
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
|
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
|
||||||
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
|
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
|
||||||
github.com/mattn/go-sqlite3 v1.14.16
|
github.com/mattn/go-sqlite3 v1.14.16
|
||||||
|
|
|
||||||
6
go.sum
6
go.sum
|
|
@ -325,6 +325,12 @@ github.com/matrix-org/gomatrixserverlib v0.0.0-20230320105331-4dd7ff2f0e3a h1:F6
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230320105331-4dd7ff2f0e3a/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20230320105331-4dd7ff2f0e3a/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230405171344-5f597d85ba4f h1:D7IgZA2DxBroqCTxo2uXEmjj8eCI1OzqqKRE4SAgmBU=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20230405171344-5f597d85ba4f h1:D7IgZA2DxBroqCTxo2uXEmjj8eCI1OzqqKRE4SAgmBU=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230405171344-5f597d85ba4f/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20230405171344-5f597d85ba4f/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU=
|
||||||
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20230406115722-88763372cbd9 h1:bnIeD+u46GkPzfBcH9Su4sUBRZ9jqHcKER8hmaNnGD8=
|
||||||
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20230406115722-88763372cbd9/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU=
|
||||||
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20230406130834-e85c3bfeed05 h1:s4JL2PPTgB6pR1ouDwq6KpqmV5hUTAP+HB1ajlemHWA=
|
||||||
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20230406130834-e85c3bfeed05/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU=
|
||||||
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20230414110520-bcb28309fbcf h1:bfjLsvf1M7IFJuwmFJAsgUi4XaaHs07e6tub0tb3Fpw=
|
||||||
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20230414110520-bcb28309fbcf/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU=
|
||||||
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
|
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
|
||||||
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
|
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
|
||||||
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
|
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,10 @@ import (
|
||||||
func ValidateRule(kind Kind, rule *Rule) []error {
|
func ValidateRule(kind Kind, rule *Rule) []error {
|
||||||
var errs []error
|
var errs []error
|
||||||
|
|
||||||
|
if len(rule.RuleID) > 0 && rule.RuleID[:1] == "." {
|
||||||
|
errs = append(errs, fmt.Errorf("invalid rule ID: rule can not start with a dot"))
|
||||||
|
}
|
||||||
|
|
||||||
if !validRuleIDRE.MatchString(rule.RuleID) {
|
if !validRuleIDRE.MatchString(rule.RuleID) {
|
||||||
errs = append(errs, fmt.Errorf("invalid rule ID: %s", rule.RuleID))
|
errs = append(errs, fmt.Errorf("invalid rule ID: %s", rule.RuleID))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -50,10 +50,10 @@ func SendEvents(
|
||||||
func SendEventWithState(
|
func SendEventWithState(
|
||||||
ctx context.Context, rsAPI InputRoomEventsAPI,
|
ctx context.Context, rsAPI InputRoomEventsAPI,
|
||||||
virtualHost gomatrixserverlib.ServerName, kind Kind,
|
virtualHost gomatrixserverlib.ServerName, kind Kind,
|
||||||
state *fclient.RespState, event *gomatrixserverlib.HeaderedEvent,
|
state gomatrixserverlib.StateResponse, event *gomatrixserverlib.HeaderedEvent,
|
||||||
origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool,
|
origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool,
|
||||||
) error {
|
) error {
|
||||||
outliers := state.Events(event.RoomVersion)
|
outliers := gomatrixserverlib.LineariseStateResponse(event.RoomVersion, state)
|
||||||
ires := make([]InputRoomEvent, 0, len(outliers))
|
ires := make([]InputRoomEvent, 0, len(outliers))
|
||||||
for _, outlier := range outliers {
|
for _, outlier := range outliers {
|
||||||
if haveEventIDs[outlier.EventID()] {
|
if haveEventIDs[outlier.EventID()] {
|
||||||
|
|
@ -66,7 +66,7 @@ func SendEventWithState(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
stateEvents := state.StateEvents.UntrustedEvents(event.RoomVersion)
|
stateEvents := state.GetStateEvents().UntrustedEvents(event.RoomVersion)
|
||||||
stateEventIDs := make([]string, len(stateEvents))
|
stateEventIDs := make([]string, len(stateEvents))
|
||||||
for i := range stateEvents {
|
for i := range stateEvents {
|
||||||
stateEventIDs[i] = stateEvents[i].EventID()
|
stateEventIDs[i] = stateEvents[i].EventID()
|
||||||
|
|
|
||||||
|
|
@ -641,31 +641,19 @@ func (t *missingStateReq) lookupMissingStateViaState(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
s := fclient.RespState{
|
|
||||||
|
// Check that the returned state is valid.
|
||||||
|
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{
|
||||||
StateEvents: state.GetStateEvents(),
|
StateEvents: state.GetStateEvents(),
|
||||||
AuthEvents: state.GetAuthEvents(),
|
AuthEvents: state.GetAuthEvents(),
|
||||||
}
|
}, roomVersion, t.keys, nil)
|
||||||
// Check that the returned state is valid.
|
|
||||||
authEvents, stateEvents, err := s.Check(ctx, roomVersion, t.keys, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
parsedState := &parsedRespState{
|
return &parsedRespState{
|
||||||
AuthEvents: authEvents,
|
AuthEvents: authEvents,
|
||||||
StateEvents: stateEvents,
|
StateEvents: stateEvents,
|
||||||
}
|
}, nil
|
||||||
// Cache the results of this state lookup and deduplicate anything we already
|
|
||||||
// have in the cache, freeing up memory.
|
|
||||||
// We load these as trusted as we called state.Check before which loaded them as untrusted.
|
|
||||||
for i, evJSON := range s.AuthEvents {
|
|
||||||
ev, _ := gomatrixserverlib.NewEventFromTrustedJSON(evJSON, false, roomVersion)
|
|
||||||
parsedState.AuthEvents[i] = t.cacheAndReturn(ev)
|
|
||||||
}
|
|
||||||
for i, evJSON := range s.StateEvents {
|
|
||||||
ev, _ := gomatrixserverlib.NewEventFromTrustedJSON(evJSON, false, roomVersion)
|
|
||||||
parsedState.StateEvents[i] = t.cacheAndReturn(ev)
|
|
||||||
}
|
|
||||||
return parsedState, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
|
func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
|
||||||
|
|
|
||||||
|
|
@ -654,11 +654,11 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo
|
||||||
messageEvents = append(messageEvents, ev)
|
messageEvents = append(messageEvents, ev)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
respState := fclient.RespState{
|
respState := &fclient.RespState{
|
||||||
AuthEvents: res.AuthChain,
|
AuthEvents: res.AuthChain,
|
||||||
StateEvents: stateEvents,
|
StateEvents: stateEvents,
|
||||||
}
|
}
|
||||||
eventsInOrder := respState.Events(rc.roomVersion)
|
eventsInOrder := gomatrixserverlib.LineariseStateResponse(rc.roomVersion, respState)
|
||||||
// everything gets sent as an outlier because auth chain events may be disjoint from the DAG
|
// everything gets sent as an outlier because auth chain events may be disjoint from the DAG
|
||||||
// as may the threaded events.
|
// as may the threaded events.
|
||||||
var ires []roomserver.InputRoomEvent
|
var ires []roomserver.InputRoomEvent
|
||||||
|
|
|
||||||
|
|
@ -90,7 +90,7 @@ type ClientUserAPI interface {
|
||||||
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
||||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||||
QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error
|
QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error
|
||||||
QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error
|
QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error)
|
||||||
QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error
|
QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error
|
||||||
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
|
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
|
||||||
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
|
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
|
||||||
|
|
@ -99,7 +99,7 @@ type ClientUserAPI interface {
|
||||||
PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error
|
PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error
|
||||||
PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error
|
PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error
|
||||||
PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error
|
PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error
|
||||||
PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error
|
PerformPushRulesPut(ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets) error
|
||||||
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
|
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
|
||||||
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
|
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
|
||||||
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
|
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
|
||||||
|
|
@ -555,19 +555,6 @@ const (
|
||||||
HTTPKind PusherKind = "http"
|
HTTPKind PusherKind = "http"
|
||||||
)
|
)
|
||||||
|
|
||||||
type PerformPushRulesPutRequest struct {
|
|
||||||
UserID string `json:"user_id"`
|
|
||||||
RuleSets *pushrules.AccountRuleSets `json:"rule_sets"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type QueryPushRulesRequest struct {
|
|
||||||
UserID string `json:"user_id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type QueryPushRulesResponse struct {
|
|
||||||
RuleSets *pushrules.AccountRuleSets `json:"rule_sets"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type QueryNotificationsRequest struct {
|
type QueryNotificationsRequest struct {
|
||||||
Localpart string `json:"localpart"` // Required.
|
Localpart string `json:"localpart"` // Required.
|
||||||
ServerName gomatrixserverlib.ServerName `json:"server_name"` // Required.
|
ServerName gomatrixserverlib.ServerName `json:"server_name"` // Required.
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ import (
|
||||||
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
|
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
|
||||||
|
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
@ -872,36 +873,28 @@ func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPusher
|
||||||
|
|
||||||
func (a *UserInternalAPI) PerformPushRulesPut(
|
func (a *UserInternalAPI) PerformPushRulesPut(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req *api.PerformPushRulesPutRequest,
|
userID string,
|
||||||
_ *struct{},
|
ruleSets *pushrules.AccountRuleSets,
|
||||||
) error {
|
) error {
|
||||||
bs, err := json.Marshal(&req.RuleSets)
|
bs, err := json.Marshal(ruleSets)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
userReq := api.InputAccountDataRequest{
|
userReq := api.InputAccountDataRequest{
|
||||||
UserID: req.UserID,
|
UserID: userID,
|
||||||
DataType: pushRulesAccountDataType,
|
DataType: pushRulesAccountDataType,
|
||||||
AccountData: json.RawMessage(bs),
|
AccountData: json.RawMessage(bs),
|
||||||
}
|
}
|
||||||
var userRes api.InputAccountDataResponse // empty
|
var userRes api.InputAccountDataResponse // empty
|
||||||
if err := a.InputAccountData(ctx, &userReq, &userRes); err != nil {
|
return a.InputAccountData(ctx, &userReq, &userRes)
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
|
func (a *UserInternalAPI) QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error) {
|
||||||
localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to split user ID %q for push rules", req.UserID)
|
return nil, fmt.Errorf("failed to split user ID %q for push rules", userID)
|
||||||
}
|
}
|
||||||
pushRules, err := a.DB.QueryPushRules(ctx, localpart, domain)
|
return a.DB.QueryPushRules(ctx, localpart, domain)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to query push rules: %w", err)
|
|
||||||
}
|
|
||||||
res.RuleSets = pushRules
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error) {
|
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error) {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue