Merge branch 'main' into patch-1

This commit is contained in:
genofire 2023-05-02 12:58:38 +02:00 committed by GitHub
commit a492c98c54
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
141 changed files with 2641 additions and 1995 deletions

View file

@ -67,6 +67,8 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- name: Install libolm
run: sudo apt-get install libolm-dev libolm3
- name: Install Go - name: Install Go
uses: actions/setup-go@v3 uses: actions/setup-go@v3
with: with:
@ -101,6 +103,8 @@ jobs:
--health-retries 5 --health-retries 5
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- name: Install libolm
run: sudo apt-get install libolm-dev libolm3
- name: Setup go - name: Setup go
uses: actions/setup-go@v3 uses: actions/setup-go@v3
with: with:
@ -232,6 +236,8 @@ jobs:
--health-retries 5 --health-retries 5
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- name: Install libolm
run: sudo apt-get install libolm-dev libolm3
- name: Setup go - name: Setup go
uses: actions/setup-go@v3 uses: actions/setup-go@v3
with: with:

View file

@ -30,6 +30,7 @@ import (
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
@ -104,7 +105,7 @@ func (s *OutputRoomEventConsumer) onMessage(
ctx context.Context, state *appserviceState, msgs []*nats.Msg, ctx context.Context, state *appserviceState, msgs []*nats.Msg,
) bool { ) bool {
log.WithField("appservice", state.ID).Tracef("Appservice worker received %d message(s) from roomserver", len(msgs)) log.WithField("appservice", state.ID).Tracef("Appservice worker received %d message(s) from roomserver", len(msgs))
events := make([]*gomatrixserverlib.HeaderedEvent, 0, len(msgs)) events := make([]*types.HeaderedEvent, 0, len(msgs))
for _, msg := range msgs { for _, msg := range msgs {
// Only handle events we care about // Only handle events we care about
receivedType := api.OutputType(msg.Header.Get(jetstream.RoomEventType)) receivedType := api.OutputType(msg.Header.Get(jetstream.RoomEventType))
@ -174,13 +175,13 @@ func (s *OutputRoomEventConsumer) onMessage(
// endpoint. It will block for the backoff period if necessary. // endpoint. It will block for the backoff period if necessary.
func (s *OutputRoomEventConsumer) sendEvents( func (s *OutputRoomEventConsumer) sendEvents(
ctx context.Context, state *appserviceState, ctx context.Context, state *appserviceState,
events []*gomatrixserverlib.HeaderedEvent, events []*types.HeaderedEvent,
txnID string, txnID string,
) error { ) error {
// Create the transaction body. // Create the transaction body.
transaction, err := json.Marshal( transaction, err := json.Marshal(
ApplicationServiceTransaction{ ApplicationServiceTransaction{
Events: synctypes.HeaderedToClientEvents(events, synctypes.FormatAll), Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll),
}, },
) )
if err != nil { if err != nil {
@ -231,7 +232,7 @@ func (s *appserviceState) backoffAndPause(err error) error {
// event falls within one of a given application service's namespaces. // event falls within one of a given application service's namespaces.
// //
// TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682 // TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682
func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, appservice *config.ApplicationService) bool { func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool {
switch { switch {
case appservice.URL == "": case appservice.URL == "":
return false return false
@ -269,7 +270,7 @@ func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Cont
// appserviceJoinedAtEvent returns a boolean depending on whether a given // appserviceJoinedAtEvent returns a boolean depending on whether a given
// appservice has membership at the time a given event was created. // appservice has membership at the time a given event was created.
func (s *OutputRoomEventConsumer) appserviceJoinedAtEvent(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, appservice *config.ApplicationService) bool { func (s *OutputRoomEventConsumer) appserviceJoinedAtEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool {
// TODO: This is only checking the current room state, not the state at // TODO: This is only checking the current room state, not the state at
// the event in question. Pretty sure this is what Synapse does too, but // the event in question. Pretty sure this is what Synapse does too, but
// until we have a lighter way of checking the state before the event that // until we have a lighter way of checking the state before the event that

View file

@ -142,8 +142,8 @@ func TestPurgeRoom(t *testing.T) {
// this starts the JetStream consumers // this starts the JetStream consumers
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, caching.DisableMetrics) syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, caching.DisableMetrics)
federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, rsAPI, caches, nil, true) fsAPI := federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, rsAPI, caches, nil, true)
rsAPI.SetFederationAPI(nil, nil) rsAPI.SetFederationAPI(fsAPI, nil)
// Create the room // Create the room
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {

View file

@ -13,19 +13,13 @@ import (
"testing" "testing"
"time" "time"
"github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/appservice"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/routing" "github.com/matrix-org/dendrite/clientapi/routing"
"github.com/matrix-org/dendrite/clientapi/threepid" "github.com/matrix-org/dendrite/clientapi/threepid"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -37,6 +31,15 @@ import (
"github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi"
uapi "github.com/matrix-org/dendrite/userapi/api" uapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
"github.com/tidwall/gjson"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
) )
type userDevice struct { type userDevice struct {
@ -1630,3 +1633,502 @@ func TestPushRules(t *testing.T) {
} }
}) })
} }
// Tests the `/keys` endpoints.
// Note that this only tests the happy path.
func TestKeys(t *testing.T) {
alice := test.NewUser(t)
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
cfg.ClientAPI.RateLimiting.Enabled = false
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
natsInstance := jetstream.NATSInstance{}
defer close()
routers := httputil.NewRouters()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
alice: {},
}
createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers)
// Start a TLSServer with our client mux
srv := httptest.NewTLSServer(routers.Client)
defer srv.Close()
cl, err := mautrix.NewClient(srv.URL, id.UserID(alice.ID), accessTokens[alice].accessToken)
if err != nil {
t.Fatal(err)
}
// Set the client so the self-signed certificate is trusted
cl.Client = srv.Client()
cl.DeviceID = id.DeviceID(accessTokens[alice].deviceID)
cs := crypto.NewMemoryStore(nil)
oc := crypto.NewOlmMachine(cl, nil, cs, dummyStore{})
if err = oc.Load(); err != nil {
t.Fatal(err)
}
// tests `/keys/upload`
if err = oc.ShareKeys(ctx, 0); err != nil {
t.Fatal(err)
}
// tests `/keys/device_signing/upload`
_, err = oc.GenerateAndUploadCrossSigningKeys(accessTokens[alice].password, "passphrase")
if err != nil {
t.Fatal(err)
}
// tests `/keys/query`
dev, err := oc.GetOrFetchDevice(ctx, id.UserID(alice.ID), id.DeviceID(accessTokens[alice].deviceID))
if err != nil {
t.Fatal(err)
}
// Validate that the keys returned from the server are what the client has stored
oi := oc.OwnIdentity()
if oi.SigningKey != dev.SigningKey {
t.Fatalf("expected signing key '%s', got '%s'", oi.SigningKey, dev.SigningKey)
}
if oi.IdentityKey != dev.IdentityKey {
t.Fatalf("expected identity '%s', got '%s'", oi.IdentityKey, dev.IdentityKey)
}
// tests `/keys/signatures/upload`
if err = oc.SignOwnMasterKey(); err != nil {
t.Fatal(err)
}
// tests `/keys/claim`
otks := make(map[string]map[string]string)
otks[alice.ID] = map[string]string{
accessTokens[alice].deviceID: string(id.KeyAlgorithmSignedCurve25519),
}
data, err := json.Marshal(claimKeysRequest{OneTimeKeys: otks})
if err != nil {
t.Fatal(err)
}
req, err := http.NewRequest(http.MethodPost, srv.URL+"/_matrix/client/v3/keys/claim", bytes.NewBuffer(data))
if err != nil {
t.Fatal(err)
}
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
resp, err := srv.Client().Do(req)
if err != nil {
t.Fatal(err)
}
respBody, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if !gjson.GetBytes(respBody, "one_time_keys."+alice.ID+"."+string(dev.DeviceID)).Exists() {
t.Fatalf("expected one time keys for alice, but didn't find any: %s", string(respBody))
}
})
}
type claimKeysRequest struct {
// The keys to be claimed. A map from user ID, to a map from device ID to algorithm name.
OneTimeKeys map[string]map[string]string `json:"one_time_keys"`
}
type dummyStore struct{}
func (d dummyStore) IsEncrypted(roomID id.RoomID) bool {
return true
}
func (d dummyStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent {
return &event.EncryptionEventContent{}
}
func (d dummyStore) FindSharedRooms(userID id.UserID) []id.RoomID {
return []id.RoomID{}
}
func TestKeyBackup(t *testing.T) {
alice := test.NewUser(t)
handleResponseCode := func(t *testing.T, rec *httptest.ResponseRecorder, expectedCode int) {
t.Helper()
if rec.Code != expectedCode {
t.Fatalf("expected HTTP %d, but got %d: %s", expectedCode, rec.Code, rec.Body.String())
}
}
testCases := []struct {
name string
request func(t *testing.T) *http.Request
validate func(t *testing.T, rec *httptest.ResponseRecorder)
}{
{
name: "can not create backup with invalid JSON",
request: func(t *testing.T) *http.Request {
reqBody := strings.NewReader(`{"algorithm":"m.megolm_backup.v1"`) // missing closing braces
return httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/room_keys/version", reqBody)
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusBadRequest)
},
},
{
name: "can not create backup with missing auth_data", // as this would result in MarshalJSON errors when querying again
request: func(t *testing.T) *http.Request {
reqBody := strings.NewReader(`{"algorithm":"m.megolm_backup.v1"}`)
return httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/room_keys/version", reqBody)
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusBadRequest)
},
},
{
name: "can create backup",
request: func(t *testing.T) *http.Request {
reqBody := strings.NewReader(`{"algorithm":"m.megolm_backup.v1","auth_data":{"data":"random"}}`)
return httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/room_keys/version", reqBody)
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusOK)
wantVersion := "1"
if gotVersion := gjson.GetBytes(rec.Body.Bytes(), "version").Str; gotVersion != wantVersion {
t.Fatalf("expected version '%s', got '%s'", wantVersion, gotVersion)
}
},
},
{
name: "can not query backup for invalid version",
request: func(t *testing.T) *http.Request {
return httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/room_keys/version/1337", nil)
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusNotFound)
},
},
{
name: "can not query backup for invalid version string",
request: func(t *testing.T) *http.Request {
return httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/room_keys/version/notanumber", nil)
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusNotFound)
},
},
{
name: "can query backup",
request: func(t *testing.T) *http.Request {
return httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/room_keys/version", nil)
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusOK)
wantVersion := "1"
if gotVersion := gjson.GetBytes(rec.Body.Bytes(), "version").Str; gotVersion != wantVersion {
t.Fatalf("expected version '%s', got '%s'", wantVersion, gotVersion)
}
},
},
{
name: "can query backup without returning rooms",
request: func(t *testing.T) *http.Request {
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys", test.WithQueryParams(map[string]string{
"version": "1",
}))
return req
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusOK)
if gotRooms := gjson.GetBytes(rec.Body.Bytes(), "rooms").Map(); len(gotRooms) > 0 {
t.Fatalf("expected no rooms in version, but got %#v", gotRooms)
}
},
},
{
name: "can query backup for invalid room",
request: func(t *testing.T) *http.Request {
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!abc:test", test.WithQueryParams(map[string]string{
"version": "1",
}))
return req
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusOK)
if gotSessions := gjson.GetBytes(rec.Body.Bytes(), "sessions").Map(); len(gotSessions) > 0 {
t.Fatalf("expected no sessions in version, but got %#v", gotSessions)
}
},
},
{
name: "can not query backup for invalid session",
request: func(t *testing.T) *http.Request {
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!abc:test/doesnotexist", test.WithQueryParams(map[string]string{
"version": "1",
}))
return req
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusNotFound)
},
},
{
name: "can not update backup with missing version",
request: func(t *testing.T) *http.Request {
return test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys")
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusBadRequest)
},
},
{
name: "can not update backup with invalid data",
request: func(t *testing.T) *http.Request {
reqBody := test.WithJSONBody(t, "")
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys", reqBody, test.WithQueryParams(map[string]string{
"version": "0",
}))
return req
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusBadRequest)
},
},
{
name: "can not update backup with wrong version",
request: func(t *testing.T) *http.Request {
reqBody := test.WithJSONBody(t, map[string]interface{}{
"rooms": map[string]interface{}{
"!testroom:test": map[string]interface{}{
"sessions": map[string]uapi.KeyBackupSession{},
},
},
})
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys", reqBody, test.WithQueryParams(map[string]string{
"version": "5",
}))
return req
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusForbidden)
},
},
{
name: "can update backup with correct version",
request: func(t *testing.T) *http.Request {
reqBody := test.WithJSONBody(t, map[string]interface{}{
"rooms": map[string]interface{}{
"!testroom:test": map[string]interface{}{
"sessions": map[string]uapi.KeyBackupSession{
"dummySession": {
FirstMessageIndex: 1,
},
},
},
},
})
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys", reqBody, test.WithQueryParams(map[string]string{
"version": "1",
}))
return req
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusOK)
},
},
{
name: "can update backup with correct version for specific room",
request: func(t *testing.T) *http.Request {
reqBody := test.WithJSONBody(t, map[string]interface{}{
"sessions": map[string]uapi.KeyBackupSession{
"dummySession": {
FirstMessageIndex: 1,
IsVerified: true,
SessionData: json.RawMessage("{}"),
},
},
})
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys/!testroom:test", reqBody, test.WithQueryParams(map[string]string{
"version": "1",
}))
return req
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusOK)
t.Logf("%#v", rec.Body.String())
},
},
{
name: "can update backup with correct version for specific room and session",
request: func(t *testing.T) *http.Request {
reqBody := test.WithJSONBody(t, uapi.KeyBackupSession{
FirstMessageIndex: 1,
SessionData: json.RawMessage("{}"),
IsVerified: true,
ForwardedCount: 0,
})
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys/!testroom:test/dummySession", reqBody, test.WithQueryParams(map[string]string{
"version": "1",
}))
return req
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusOK)
},
},
{
name: "can update backup by version",
request: func(t *testing.T) *http.Request {
reqBody := test.WithJSONBody(t, uapi.KeyBackupSession{
FirstMessageIndex: 1,
SessionData: json.RawMessage("{}"),
IsVerified: true,
ForwardedCount: 0,
})
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/version/1", reqBody, test.WithQueryParams(map[string]string{"version": "1"}))
return req
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusOK)
t.Logf("%#v", rec.Body.String())
},
},
{
name: "can not update backup by version for invalid version",
request: func(t *testing.T) *http.Request {
reqBody := test.WithJSONBody(t, uapi.KeyBackupSession{
FirstMessageIndex: 1,
SessionData: json.RawMessage("{}"),
IsVerified: true,
ForwardedCount: 0,
})
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/version/2", reqBody, test.WithQueryParams(map[string]string{"version": "1"}))
return req
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusOK)
},
},
{
name: "can query backup sessions",
request: func(t *testing.T) *http.Request {
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys", test.WithQueryParams(map[string]string{
"version": "1",
}))
return req
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusOK)
if gotRooms := gjson.GetBytes(rec.Body.Bytes(), "rooms").Map(); len(gotRooms) != 1 {
t.Fatalf("expected one room in response, but got %#v", rec.Body.String())
}
},
},
{
name: "can query backup sessions by room",
request: func(t *testing.T) *http.Request {
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!testroom:test", test.WithQueryParams(map[string]string{
"version": "1",
}))
return req
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusOK)
if gotRooms := gjson.GetBytes(rec.Body.Bytes(), "sessions").Map(); len(gotRooms) != 1 {
t.Fatalf("expected one session in response, but got %#v", rec.Body.String())
}
},
},
{
name: "can query backup sessions by room and sessionID",
request: func(t *testing.T) *http.Request {
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!testroom:test/dummySession", test.WithQueryParams(map[string]string{
"version": "1",
}))
return req
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusOK)
if !gjson.GetBytes(rec.Body.Bytes(), "is_verified").Bool() {
t.Fatalf("expected session to be verified, but wasn't: %#v", rec.Body.String())
}
},
},
{
name: "can not delete invalid version backup",
request: func(t *testing.T) *http.Request {
return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/2", nil)
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusNotFound)
},
},
{
name: "can delete version backup",
request: func(t *testing.T) *http.Request {
return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/1", nil)
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusOK)
},
},
{
name: "deleting the same backup version twice doesn't error",
request: func(t *testing.T) *http.Request {
return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/1", nil)
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusOK)
},
},
{
name: "deleting an empty version doesn't work", // make sure we can't delete an empty backup version. Handled at the router level
request: func(t *testing.T) *http.Request {
return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/", nil)
},
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
handleResponseCode(t, rec, http.StatusNotFound)
},
},
}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
cfg.ClientAPI.RateLimiting.Enabled = false
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
natsInstance := jetstream.NATSInstance{}
defer close()
routers := httputil.NewRouters()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
alice: {},
}
createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
rec := httptest.NewRecorder()
req := tc.request(t)
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
routers.Client.ServeHTTP(rec, req)
tc.validate(t, rec)
})
}
})
}

View file

@ -171,6 +171,23 @@ func LeaveServerNoticeError() *MatrixError {
} }
} }
// ErrRoomKeysVersion is an error returned by `PUT /room_keys/keys`
type ErrRoomKeysVersion struct {
MatrixError
CurrentVersion string `json:"current_version"`
}
// WrongBackupVersionError is an error returned by `PUT /room_keys/keys`
func WrongBackupVersionError(currentVersion string) *ErrRoomKeysVersion {
return &ErrRoomKeysVersion{
MatrixError: MatrixError{
ErrCode: "M_WRONG_ROOM_KEYS_VERSION",
Err: "Wrong backup version.",
},
CurrentVersion: currentVersion,
}
}
type IncompatibleRoomVersionError struct { type IncompatibleRoomVersionError struct {
RoomVersion string `json:"room_version"` RoomVersion string `json:"room_version"`
Error string `json:"error"` Error string `json:"error"`

View file

@ -3,12 +3,14 @@ package routing
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -28,88 +30,60 @@ func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAP
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
res := &roomserverAPI.PerformAdminEvacuateRoomResponse{}
if err := rsAPI.PerformAdminEvacuateRoom( affected, err := rsAPI.PerformAdminEvacuateRoom(req.Context(), vars["roomID"])
req.Context(), switch err {
&roomserverAPI.PerformAdminEvacuateRoomRequest{ case nil:
RoomID: vars["roomID"], case eventutil.ErrRoomNoExists:
}, return util.JSONResponse{
res, Code: http.StatusNotFound,
); err != nil { JSON: jsonerror.NotFound(err.Error()),
}
default:
logrus.WithError(err).WithField("roomID", vars["roomID"]).Error("Failed to evacuate room")
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
if err := res.Error; err != nil {
return err.JSONResponse()
}
return util.JSONResponse{ return util.JSONResponse{
Code: 200, Code: 200,
JSON: map[string]interface{}{ JSON: map[string]interface{}{
"affected": res.Affected, "affected": affected,
}, },
} }
} }
func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { func AdminEvacuateUser(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
userID := vars["userID"]
_, domain, err := gomatrixserverlib.SplitID('@', userID) affected, err := rsAPI.PerformAdminEvacuateUser(req.Context(), vars["userID"])
if err != nil { if err != nil {
logrus.WithError(err).WithField("userID", vars["userID"]).Error("Failed to evacuate user")
return util.MessageResponse(http.StatusBadRequest, err.Error()) return util.MessageResponse(http.StatusBadRequest, err.Error())
} }
if !cfg.Matrix.IsLocalServerName(domain) {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument("User ID must belong to this server."),
}
}
res := &roomserverAPI.PerformAdminEvacuateUserResponse{}
if err := rsAPI.PerformAdminEvacuateUser(
req.Context(),
&roomserverAPI.PerformAdminEvacuateUserRequest{
UserID: userID,
},
res,
); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if err := res.Error; err != nil {
return err.JSONResponse()
}
return util.JSONResponse{ return util.JSONResponse{
Code: 200, Code: 200,
JSON: map[string]interface{}{ JSON: map[string]interface{}{
"affected": res.Affected, "affected": affected,
}, },
} }
} }
func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { func AdminPurgeRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
roomID := vars["roomID"]
res := &roomserverAPI.PerformAdminPurgeRoomResponse{} if err = rsAPI.PerformAdminPurgeRoom(context.Background(), vars["roomID"]); err != nil {
if err := rsAPI.PerformAdminPurgeRoom(
context.Background(),
&roomserverAPI.PerformAdminPurgeRoomRequest{
RoomID: roomID,
},
res,
); err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
if err := res.Error; err != nil {
return err.JSONResponse()
}
return util.JSONResponse{ return util.JSONResponse{
Code: 200, Code: 200,
JSON: res, JSON: struct{}{},
} }
} }
@ -238,7 +212,7 @@ func AdminMarkAsStale(req *http.Request, cfg *config.ClientAPI, keyAPI api.Clien
} }
} }
func AdminDownloadState(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { func AdminDownloadState(req *http.Request, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -257,23 +231,22 @@ func AdminDownloadState(req *http.Request, cfg *config.ClientAPI, device *api.De
JSON: jsonerror.MissingArgument("Expecting remote server name."), JSON: jsonerror.MissingArgument("Expecting remote server name."),
} }
} }
res := &roomserverAPI.PerformAdminDownloadStateResponse{} if err = rsAPI.PerformAdminDownloadState(req.Context(), roomID, device.UserID, spec.ServerName(serverName)); err != nil {
if err := rsAPI.PerformAdminDownloadState( if errors.Is(err, eventutil.ErrRoomNoExists) {
req.Context(), return util.JSONResponse{
&roomserverAPI.PerformAdminDownloadStateRequest{ Code: 200,
UserID: device.UserID, JSON: jsonerror.NotFound(eventutil.ErrRoomNoExists.Error()),
RoomID: roomID, }
ServerName: spec.ServerName(serverName), }
}, logrus.WithError(err).WithFields(logrus.Fields{
res, "userID": device.UserID,
); err != nil { "serverName": serverName,
return jsonerror.InternalAPIError(req.Context(), err) "roomID": roomID,
} }).Error("failed to download state")
if err := res.Error; err != nil { return util.ErrorResponse(err)
return err.JSONResponse()
} }
return util.JSONResponse{ return util.JSONResponse{
Code: 200, Code: 200,
JSON: map[string]interface{}{}, JSON: struct{}{},
} }
} }

View file

@ -22,8 +22,10 @@ import (
"strings" "strings"
"time" "time"
"github.com/getsentry/sentry-go"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
roomserverVersion "github.com/matrix-org/dendrite/roomserver/version" roomserverVersion "github.com/matrix-org/dendrite/roomserver/version"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/fclient"
@ -48,7 +50,6 @@ type createRoomRequest struct {
CreationContent json.RawMessage `json:"creation_content"` CreationContent json.RawMessage `json:"creation_content"`
InitialState []fledglingEvent `json:"initial_state"` InitialState []fledglingEvent `json:"initial_state"`
RoomAliasName string `json:"room_alias_name"` RoomAliasName string `json:"room_alias_name"`
GuestCanJoin bool `json:"guest_can_join"`
RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"`
PowerLevelContentOverride json.RawMessage `json:"power_level_content_override"` PowerLevelContentOverride json.RawMessage `json:"power_level_content_override"`
IsDirect bool `json:"is_direct"` IsDirect bool `json:"is_direct"`
@ -253,16 +254,19 @@ func createRoom(
} }
} }
var guestsCanJoin bool
switch r.Preset { switch r.Preset {
case presetPrivateChat: case presetPrivateChat:
joinRuleContent.JoinRule = spec.Invite joinRuleContent.JoinRule = spec.Invite
historyVisibilityContent.HistoryVisibility = historyVisibilityShared historyVisibilityContent.HistoryVisibility = historyVisibilityShared
guestsCanJoin = true
case presetTrustedPrivateChat: case presetTrustedPrivateChat:
joinRuleContent.JoinRule = spec.Invite joinRuleContent.JoinRule = spec.Invite
historyVisibilityContent.HistoryVisibility = historyVisibilityShared historyVisibilityContent.HistoryVisibility = historyVisibilityShared
for _, invitee := range r.Invite { for _, invitee := range r.Invite {
powerLevelContent.Users[invitee] = 100 powerLevelContent.Users[invitee] = 100
} }
guestsCanJoin = true
case presetPublicChat: case presetPublicChat:
joinRuleContent.JoinRule = spec.Public joinRuleContent.JoinRule = spec.Public
historyVisibilityContent.HistoryVisibility = historyVisibilityShared historyVisibilityContent.HistoryVisibility = historyVisibilityShared
@ -317,7 +321,7 @@ func createRoom(
} }
} }
if r.GuestCanJoin { if guestsCanJoin {
guestAccessEvent = &fledglingEvent{ guestAccessEvent = &fledglingEvent{
Type: spec.MRoomGuestAccess, Type: spec.MRoomGuestAccess,
Content: eventutil.GuestAccessContent{ Content: eventutil.GuestAccessContent{
@ -429,7 +433,7 @@ func createRoom(
// TODO: invite events // TODO: invite events
// TODO: 3pid invite events // TODO: 3pid invite events
var builtEvents []*gomatrixserverlib.HeaderedEvent var builtEvents []*types.HeaderedEvent
authEvents := gomatrixserverlib.NewAuthEvents(nil) authEvents := gomatrixserverlib.NewAuthEvents(nil)
for i, e := range eventsToMake { for i, e := range eventsToMake {
depth := i + 1 // depth starts at 1 depth := i + 1 // depth starts at 1
@ -462,7 +466,7 @@ func createRoom(
} }
// Add the event to the list of auth events // Add the event to the list of auth events
builtEvents = append(builtEvents, ev.Headered(roomVersion)) builtEvents = append(builtEvents, &types.HeaderedEvent{Event: ev})
err = authEvents.AddEvent(ev) err = authEvents.AddEvent(ev)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("authEvents.AddEvent failed") util.GetLogger(ctx).WithError(err).Error("authEvents.AddEvent failed")
@ -541,9 +545,10 @@ func createRoom(
} }
// Process the invites. // Process the invites.
var inviteEvent *types.HeaderedEvent
for _, invitee := range r.Invite { for _, invitee := range r.Invite {
// Build the invite event. // Build the invite event.
inviteEvent, err := buildMembershipEvent( inviteEvent, err = buildMembershipEvent(
ctx, invitee, "", profileAPI, device, spec.Invite, ctx, invitee, "", profileAPI, device, spec.Invite,
roomID, r.IsDirect, cfg, evTime, rsAPI, asAPI, roomID, r.IsDirect, cfg, evTime, rsAPI, asAPI,
) )
@ -556,38 +561,44 @@ func createRoom(
fclient.NewInviteV2StrippedState(inviteEvent.Event), fclient.NewInviteV2StrippedState(inviteEvent.Event),
) )
// Send the invite event to the roomserver. // Send the invite event to the roomserver.
var inviteRes roomserverAPI.PerformInviteResponse event := inviteEvent
event := inviteEvent.Headered(roomVersion) err = rsAPI.PerformInvite(ctx, &roomserverAPI.PerformInviteRequest{
if err := rsAPI.PerformInvite(ctx, &roomserverAPI.PerformInviteRequest{
Event: event, Event: event,
InviteRoomState: inviteStrippedState, InviteRoomState: inviteStrippedState,
RoomVersion: event.RoomVersion, RoomVersion: event.Version(),
SendAsServer: string(userDomain), SendAsServer: string(userDomain),
}, &inviteRes); err != nil { })
switch e := err.(type) {
case roomserverAPI.ErrInvalidID:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown(e.Error()),
}
case roomserverAPI.ErrNotAllowed:
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(e.Error()),
}
case nil:
default:
util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") util.GetLogger(ctx).WithError(err).Error("PerformInvite failed")
sentry.CaptureException(err)
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: jsonerror.InternalServerError(), JSON: jsonerror.InternalServerError(),
} }
} }
if inviteRes.Error != nil {
return inviteRes.Error.JSONResponse()
}
} }
} }
if r.Visibility == "public" { if r.Visibility == spec.Public {
// expose this room in the published room list // expose this room in the published room list
var pubRes roomserverAPI.PerformPublishResponse if err = rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{
if err := rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{
RoomID: roomID, RoomID: roomID,
Visibility: "public", Visibility: spec.Public,
}, &pubRes); err != nil { }); err != nil {
return jsonerror.InternalAPIError(ctx, err) util.GetLogger(ctx).WithError(err).Error("failed to publish room")
} return jsonerror.InternalServerError()
if pubRes.Error != nil {
// treat as non-fatal since the room is already made by this point
util.GetLogger(ctx).WithError(pubRes.Error).Error("failed to visibility:public")
} }
} }

View file

@ -304,16 +304,12 @@ func SetVisibility(
return *reqErr return *reqErr
} }
var publishRes roomserverAPI.PerformPublishResponse if err = rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{
if err := rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{
RoomID: roomID, RoomID: roomID,
Visibility: v.Visibility, Visibility: v.Visibility,
}, &publishRes); err != nil { }); err != nil {
return jsonerror.InternalAPIError(req.Context(), err) util.GetLogger(req.Context()).WithError(err).Error("failed to publish room")
} return jsonerror.InternalServerError()
if publishRes.Error != nil {
util.GetLogger(req.Context()).WithError(publishRes.Error).Error("PerformPublish failed")
return publishRes.Error.JSONResponse()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -342,18 +338,14 @@ func SetVisibilityAS(
return *reqErr return *reqErr
} }
} }
var publishRes roomserverAPI.PerformPublishResponse
if err := rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{ if err := rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{
RoomID: roomID, RoomID: roomID,
Visibility: v.Visibility, Visibility: v.Visibility,
NetworkID: networkID, NetworkID: networkID,
AppserviceID: dev.AppserviceID, AppserviceID: dev.AppserviceID,
}, &publishRes); err != nil { }); err != nil {
return jsonerror.InternalAPIError(req.Context(), err) util.GetLogger(req.Context()).WithError(err).Error("failed to publish room")
} return jsonerror.InternalServerError()
if publishRes.Error != nil {
util.GetLogger(req.Context()).WithError(publishRes.Error).Error("PerformPublish failed")
return publishRes.Error.JSONResponse()
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -15,14 +15,18 @@
package routing package routing
import ( import (
"encoding/json"
"errors"
"net/http" "net/http"
"time" "time"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/eventutil"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -41,7 +45,6 @@ func JoinRoomByIDOrAlias(
IsGuest: device.AccountType == api.AccountTypeGuest, IsGuest: device.AccountType == api.AccountTypeGuest,
Content: map[string]interface{}{}, Content: map[string]interface{}{},
} }
joinRes := roomserverAPI.PerformJoinResponse{}
// Check to see if any ?server_name= query parameters were // Check to see if any ?server_name= query parameters were
// given in the request. // given in the request.
@ -81,37 +84,66 @@ func JoinRoomByIDOrAlias(
done := make(chan util.JSONResponse, 1) done := make(chan util.JSONResponse, 1)
go func() { go func() {
defer close(done) defer close(done)
if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil { roomID, _, err := rsAPI.PerformJoin(req.Context(), &joinReq)
done <- jsonerror.InternalAPIError(req.Context(), err) var response util.JSONResponse
} else if joinRes.Error != nil {
if joinRes.Error.Code == roomserverAPI.PerformErrorNotAllowed && device.AccountType == api.AccountTypeGuest { switch e := err.(type) {
done <- util.JSONResponse{ case nil: // success case
Code: http.StatusForbidden, response = util.JSONResponse{
JSON: jsonerror.GuestAccessForbidden(joinRes.Error.Msg),
}
} else {
done <- joinRes.Error.JSONResponse()
}
} else {
done <- util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
// TODO: Put the response struct somewhere internal. // TODO: Put the response struct somewhere internal.
JSON: struct { JSON: struct {
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
}{joinRes.RoomID}, }{roomID},
}
case roomserverAPI.ErrInvalidID:
response = util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown(e.Error()),
}
case roomserverAPI.ErrNotAllowed:
jsonErr := jsonerror.Forbidden(e.Error())
if device.AccountType == api.AccountTypeGuest {
jsonErr = jsonerror.GuestAccessForbidden(e.Error())
}
response = util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonErr,
}
case *gomatrix.HTTPError: // this ensures we proxy responses over federation to the client
response = util.JSONResponse{
Code: e.Code,
JSON: json.RawMessage(e.Message),
}
default:
response = util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.InternalServerError(),
}
if errors.Is(err, eventutil.ErrRoomNoExists) {
response = util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound(e.Error()),
}
} }
} }
done <- response
}() }()
// Wait either for the join to finish, or for us to hit a reasonable // Wait either for the join to finish, or for us to hit a reasonable
// timeout, at which point we'll just return a 200 to placate clients. // timeout, at which point we'll just return a 200 to placate clients.
timer := time.NewTimer(time.Second * 20)
select { select {
case <-time.After(time.Second * 20): case <-timer.C:
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusAccepted, Code: http.StatusAccepted,
JSON: jsonerror.Unknown("The room join will continue in the background."), JSON: jsonerror.Unknown("The room join will continue in the background."),
} }
case result := <-done: case result := <-done:
// Stop and drain the timer
if !timer.Stop() {
<-timer.C
}
return result return result
} }
} }

View file

@ -66,7 +66,6 @@ func TestJoinRoomByIDOrAlias(t *testing.T) {
Preset: presetPublicChat, Preset: presetPublicChat,
RoomAliasName: "alias", RoomAliasName: "alias",
Invite: []string{bob.ID}, Invite: []string{bob.ID},
GuestCanJoin: false,
}, aliceDev, &cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) }, aliceDev, &cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now())
crResp, ok := resp.JSON.(createRoomResponse) crResp, ok := resp.JSON.(createRoomResponse)
if !ok { if !ok {
@ -75,13 +74,12 @@ func TestJoinRoomByIDOrAlias(t *testing.T) {
// create a room with guest access enabled and invite Charlie // create a room with guest access enabled and invite Charlie
resp = createRoom(ctx, createRoomRequest{ resp = createRoom(ctx, createRoomRequest{
Name: "testing", Name: "testing",
IsDirect: true, IsDirect: true,
Topic: "testing", Topic: "testing",
Visibility: "public", Visibility: "public",
Preset: presetPublicChat, Preset: presetPublicChat,
Invite: []string{charlie.ID}, Invite: []string{charlie.ID},
GuestCanJoin: true,
}, aliceDev, &cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) }, aliceDev, &cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now())
crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse) crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse)
if !ok { if !ok {

View file

@ -61,28 +61,26 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }
var performKeyBackupResp userapi.PerformKeyBackupResponse if len(kb.AuthData) == 0 {
if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("missing auth_data"),
}
}
version, err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{
UserID: device.UserID, UserID: device.UserID,
Version: "", Version: "",
AuthData: kb.AuthData, AuthData: kb.AuthData,
Algorithm: kb.Algorithm, Algorithm: kb.Algorithm,
}, &performKeyBackupResp); err != nil { })
return jsonerror.InternalServerError() if err != nil {
} return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %w", err))
if performKeyBackupResp.Error != "" {
if performKeyBackupResp.BadInput {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error),
}
}
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error))
} }
return util.JSONResponse{ return util.JSONResponse{
Code: 200, Code: 200,
JSON: keyBackupVersionCreateResponse{ JSON: keyBackupVersionCreateResponse{
Version: performKeyBackupResp.Version, Version: version,
}, },
} }
} }
@ -90,15 +88,12 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de
// KeyBackupVersion returns the key backup version specified. If `version` is empty, the latest `keyBackupVersionResponse` is returned. // KeyBackupVersion returns the key backup version specified. If `version` is empty, the latest `keyBackupVersionResponse` is returned.
// Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version} // Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version}
func KeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse { func KeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse {
var queryResp userapi.QueryKeyBackupResponse queryResp, err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
if err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
UserID: device.UserID, UserID: device.UserID,
Version: version, Version: version,
}, &queryResp); err != nil { })
return jsonerror.InternalAPIError(req.Context(), err) if err != nil {
} return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", err))
if queryResp.Error != "" {
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error))
} }
if !queryResp.Exists { if !queryResp.Exists {
return util.JSONResponse{ return util.JSONResponse{
@ -126,31 +121,29 @@ func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.ClientUse
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }
var performKeyBackupResp userapi.PerformKeyBackupResponse performKeyBackupResp, err := userAPI.UpdateBackupKeyAuthData(req.Context(), &userapi.PerformKeyBackupRequest{
if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{
UserID: device.UserID, UserID: device.UserID,
Version: version, Version: version,
AuthData: kb.AuthData, AuthData: kb.AuthData,
Algorithm: kb.Algorithm, Algorithm: kb.Algorithm,
}, &performKeyBackupResp); err != nil { })
return jsonerror.InternalServerError() switch e := err.(type) {
} case *jsonerror.ErrRoomKeysVersion:
if performKeyBackupResp.Error != "" { return util.JSONResponse{
if performKeyBackupResp.BadInput { Code: http.StatusForbidden,
return util.JSONResponse{ JSON: e,
Code: 400,
JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error),
}
} }
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) case nil:
default:
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %w", e))
} }
if !performKeyBackupResp.Exists { if !performKeyBackupResp.Exists {
return util.JSONResponse{ return util.JSONResponse{
Code: 404, Code: 404,
JSON: jsonerror.NotFound("backup version not found"), JSON: jsonerror.NotFound("backup version not found"),
} }
} }
// Unclear what the 200 body should be
return util.JSONResponse{ return util.JSONResponse{
Code: 200, Code: 200,
JSON: keyBackupVersionCreateResponse{ JSON: keyBackupVersionCreateResponse{
@ -162,35 +155,19 @@ func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.ClientUse
// Delete a version of key backup. Version must not be empty. If the key backup was previously deleted, will return 200 OK. // Delete a version of key backup. Version must not be empty. If the key backup was previously deleted, will return 200 OK.
// Implements DELETE /_matrix/client/r0/room_keys/version/{version} // Implements DELETE /_matrix/client/r0/room_keys/version/{version}
func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse { func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse {
var performKeyBackupResp userapi.PerformKeyBackupResponse exists, err := userAPI.DeleteKeyBackup(req.Context(), device.UserID, version)
if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ if err != nil {
UserID: device.UserID, return util.ErrorResponse(fmt.Errorf("DeleteKeyBackup: %s", err))
Version: version,
DeleteBackup: true,
}, &performKeyBackupResp); err != nil {
return jsonerror.InternalServerError()
} }
if performKeyBackupResp.Error != "" { if !exists {
if performKeyBackupResp.BadInput {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error),
}
}
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error))
}
if !performKeyBackupResp.Exists {
return util.JSONResponse{ return util.JSONResponse{
Code: 404, Code: 404,
JSON: jsonerror.NotFound("backup version not found"), JSON: jsonerror.NotFound("backup version not found"),
} }
} }
// Unclear what the 200 body should be
return util.JSONResponse{ return util.JSONResponse{
Code: 200, Code: 200,
JSON: keyBackupVersionCreateResponse{ JSON: struct{}{},
Version: performKeyBackupResp.Version,
},
} }
} }
@ -198,22 +175,21 @@ func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de
func UploadBackupKeys( func UploadBackupKeys(
req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string, keys *keyBackupSessionRequest, req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string, keys *keyBackupSessionRequest,
) util.JSONResponse { ) util.JSONResponse {
var performKeyBackupResp userapi.PerformKeyBackupResponse performKeyBackupResp, err := userAPI.UpdateBackupKeyAuthData(req.Context(), &userapi.PerformKeyBackupRequest{
if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{
UserID: device.UserID, UserID: device.UserID,
Version: version, Version: version,
Keys: *keys, Keys: *keys,
}, &performKeyBackupResp); err != nil && performKeyBackupResp.Error == "" { })
return jsonerror.InternalServerError()
} switch e := err.(type) {
if performKeyBackupResp.Error != "" { case *jsonerror.ErrRoomKeysVersion:
if performKeyBackupResp.BadInput { return util.JSONResponse{
return util.JSONResponse{ Code: http.StatusForbidden,
Code: 400, JSON: e,
JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error),
}
} }
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) case nil:
default:
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %w", e))
} }
if !performKeyBackupResp.Exists { if !performKeyBackupResp.Exists {
return util.JSONResponse{ return util.JSONResponse{
@ -234,18 +210,15 @@ func UploadBackupKeys(
func GetBackupKeys( func GetBackupKeys(
req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version, roomID, sessionID string, req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version, roomID, sessionID string,
) util.JSONResponse { ) util.JSONResponse {
var queryResp userapi.QueryKeyBackupResponse queryResp, err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
if err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
UserID: device.UserID, UserID: device.UserID,
Version: version, Version: version,
ReturnKeys: true, ReturnKeys: true,
KeysForRoomID: roomID, KeysForRoomID: roomID,
KeysForSessionID: sessionID, KeysForSessionID: sessionID,
}, &queryResp); err != nil { })
return jsonerror.InternalAPIError(req.Context(), err) if err != nil {
} return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %w", err))
if queryResp.Error != "" {
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error))
} }
if !queryResp.Exists { if !queryResp.Exists {
return util.JSONResponse{ return util.JSONResponse{
@ -267,17 +240,20 @@ func GetBackupKeys(
} }
} else if roomID != "" { } else if roomID != "" {
roomData, ok := queryResp.Keys[roomID] roomData, ok := queryResp.Keys[roomID]
if ok { if !ok {
// wrap response in "sessions" // If no keys are found, then an object with an empty sessions property will be returned
return util.JSONResponse{ roomData = make(map[string]userapi.KeyBackupSession)
Code: 200,
JSON: struct {
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
}{
Sessions: roomData,
},
}
} }
// wrap response in "sessions"
return util.JSONResponse{
Code: 200,
JSON: struct {
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
}{
Sessions: roomData,
},
}
} else { } else {
// response is the same as the upload request // response is the same as the upload request
var resp keyBackupSessionRequest var resp keyBackupSessionRequest

View file

@ -20,6 +20,7 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
@ -31,6 +32,7 @@ import (
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
@ -91,7 +93,7 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic
if err = roomserverAPI.SendEvents( if err = roomserverAPI.SendEvents(
ctx, rsAPI, ctx, rsAPI,
roomserverAPI.KindNew, roomserverAPI.KindNew,
[]*gomatrixserverlib.HeaderedEvent{event}, []*types.HeaderedEvent{event},
device.UserDomain(), device.UserDomain(),
serverName, serverName,
serverName, serverName,
@ -264,22 +266,33 @@ func sendInvite(
return jsonerror.InternalServerError(), err return jsonerror.InternalServerError(), err
} }
var inviteRes api.PerformInviteResponse err = rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{
if err := rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{
Event: event, Event: event,
InviteRoomState: nil, // ask the roomserver to draw up invite room state for us InviteRoomState: nil, // ask the roomserver to draw up invite room state for us
RoomVersion: event.RoomVersion, RoomVersion: event.Version(),
SendAsServer: string(device.UserDomain()), SendAsServer: string(device.UserDomain()),
}, &inviteRes); err != nil { })
switch e := err.(type) {
case roomserverAPI.ErrInvalidID:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown(e.Error()),
}, e
case roomserverAPI.ErrNotAllowed:
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(e.Error()),
}, e
case nil:
default:
util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") util.GetLogger(ctx).WithError(err).Error("PerformInvite failed")
sentry.CaptureException(err)
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: jsonerror.InternalServerError(), JSON: jsonerror.InternalServerError(),
}, err }, err
} }
if inviteRes.Error != nil {
return inviteRes.Error.JSONResponse(), inviteRes.Error
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
@ -294,7 +307,7 @@ func buildMembershipEvent(
membership, roomID string, isDirect bool, membership, roomID string, isDirect bool,
cfg *config.ClientAPI, evTime time.Time, cfg *config.ClientAPI, evTime time.Time,
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI,
) (*gomatrixserverlib.HeaderedEvent, error) { ) (*types.HeaderedEvent, error) {
profile, err := loadProfile(ctx, targetUserID, cfg, profileAPI, asAPI) profile, err := loadProfile(ctx, targetUserID, cfg, profileAPI, asAPI)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -15,13 +15,16 @@
package routing package routing
import ( import (
"encoding/json"
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
func PeekRoomByIDOrAlias( func PeekRoomByIDOrAlias(
@ -41,8 +44,6 @@ func PeekRoomByIDOrAlias(
UserID: device.UserID, UserID: device.UserID,
DeviceID: device.ID, DeviceID: device.ID,
} }
peekRes := roomserverAPI.PerformPeekResponse{}
// Check to see if any ?server_name= query parameters were // Check to see if any ?server_name= query parameters were
// given in the request. // given in the request.
if serverNames, ok := req.URL.Query()["server_name"]; ok { if serverNames, ok := req.URL.Query()["server_name"]; ok {
@ -55,11 +56,27 @@ func PeekRoomByIDOrAlias(
} }
// Ask the roomserver to perform the peek. // Ask the roomserver to perform the peek.
if err := rsAPI.PerformPeek(req.Context(), &peekReq, &peekRes); err != nil { roomID, err := rsAPI.PerformPeek(req.Context(), &peekReq)
return util.ErrorResponse(err) switch e := err.(type) {
} case roomserverAPI.ErrInvalidID:
if peekRes.Error != nil { return util.JSONResponse{
return peekRes.Error.JSONResponse() Code: http.StatusBadRequest,
JSON: jsonerror.Unknown(e.Error()),
}
case roomserverAPI.ErrNotAllowed:
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(e.Error()),
}
case *gomatrix.HTTPError:
return util.JSONResponse{
Code: e.Code,
JSON: json.RawMessage(e.Message),
}
case nil:
default:
logrus.WithError(err).WithField("roomID", roomIDOrAlias).Errorf("Failed to peek room")
return jsonerror.InternalServerError()
} }
// if this user is already joined to the room, we let them peek anyway // if this user is already joined to the room, we let them peek anyway
@ -75,7 +92,7 @@ func PeekRoomByIDOrAlias(
// TODO: Put the response struct somewhere internal. // TODO: Put the response struct somewhere internal.
JSON: struct { JSON: struct {
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
}{peekRes.RoomID}, }{roomID},
} }
} }
@ -85,18 +102,17 @@ func UnpeekRoomByID(
rsAPI roomserverAPI.ClientRoomserverAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
roomID string, roomID string,
) util.JSONResponse { ) util.JSONResponse {
unpeekReq := roomserverAPI.PerformUnpeekRequest{ err := rsAPI.PerformUnpeek(req.Context(), roomID, device.UserID, device.ID)
RoomID: roomID, switch e := err.(type) {
UserID: device.UserID, case roomserverAPI.ErrInvalidID:
DeviceID: device.ID, return util.JSONResponse{
} Code: http.StatusBadRequest,
unpeekRes := roomserverAPI.PerformUnpeekResponse{} JSON: jsonerror.Unknown(e.Error()),
}
if err := rsAPI.PerformUnpeek(req.Context(), &unpeekReq, &unpeekRes); err != nil { case nil:
return jsonerror.InternalAPIError(req.Context(), err) default:
} logrus.WithError(err).WithField("roomID", roomID).Errorf("Failed to un-peek room")
if unpeekRes.Error != nil { return jsonerror.InternalServerError()
return unpeekRes.Error.JSONResponse()
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -29,6 +29,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
@ -334,16 +335,10 @@ func buildMembershipEvents(
roomIDs []string, roomIDs []string,
newProfile authtypes.Profile, userID string, cfg *config.ClientAPI, newProfile authtypes.Profile, userID string, cfg *config.ClientAPI,
evTime time.Time, rsAPI api.ClientRoomserverAPI, evTime time.Time, rsAPI api.ClientRoomserverAPI,
) ([]*gomatrixserverlib.HeaderedEvent, error) { ) ([]*types.HeaderedEvent, error) {
evs := []*gomatrixserverlib.HeaderedEvent{} evs := []*types.HeaderedEvent{}
for _, roomID := range roomIDs { for _, roomID := range roomIDs {
verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID}
verRes := api.QueryRoomVersionForRoomResponse{}
if err := rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil {
return nil, err
}
builder := gomatrixserverlib.EventBuilder{ builder := gomatrixserverlib.EventBuilder{
Sender: userID, Sender: userID,
RoomID: roomID, RoomID: roomID,
@ -372,7 +367,7 @@ func buildMembershipEvents(
return nil, err return nil, err
} }
evs = append(evs, event.Headered(verRes.RoomVersion)) evs = append(evs, event)
} }
return evs, nil return evs, nil

View file

@ -28,6 +28,7 @@ import (
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/transactions" "github.com/matrix-org/dendrite/internal/transactions"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
) )
@ -138,7 +139,7 @@ func SendRedaction(
} }
} }
domain := device.UserDomain() domain := device.UserDomain()
if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, device.UserDomain(), domain, domain, nil, false); err != nil { if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*types.HeaderedEvent{e}, device.UserDomain(), domain, domain, nil, false); err != nil {
util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents") util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }

View file

@ -162,13 +162,13 @@ func Setup(
dendriteAdminRouter.Handle("/admin/evacuateUser/{userID}", dendriteAdminRouter.Handle("/admin/evacuateUser/{userID}",
httputil.MakeAdminAPI("admin_evacuate_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAdminAPI("admin_evacuate_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminEvacuateUser(req, cfg, rsAPI) return AdminEvacuateUser(req, rsAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/purgeRoom/{roomID}", dendriteAdminRouter.Handle("/admin/purgeRoom/{roomID}",
httputil.MakeAdminAPI("admin_purge_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAdminAPI("admin_purge_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminPurgeRoom(req, cfg, device, rsAPI) return AdminPurgeRoom(req, rsAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
@ -180,7 +180,7 @@ func Setup(
dendriteAdminRouter.Handle("/admin/downloadState/{serverName}/{roomID}", dendriteAdminRouter.Handle("/admin/downloadState/{serverName}/{roomID}",
httputil.MakeAdminAPI("admin_download_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAdminAPI("admin_download_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminDownloadState(req, cfg, device, rsAPI) return AdminDownloadState(req, device, rsAPI)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)

View file

@ -34,6 +34,7 @@ import (
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/transactions" "github.com/matrix-org/dendrite/internal/transactions"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
) )
@ -76,9 +77,8 @@ func SendEvent(
rsAPI api.ClientRoomserverAPI, rsAPI api.ClientRoomserverAPI,
txnCache *transactions.Cache, txnCache *transactions.Cache,
) util.JSONResponse { ) util.JSONResponse {
verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} roomVersion, err := rsAPI.QueryRoomVersionForRoom(req.Context(), roomID)
verRes := api.QueryRoomVersionForRoomResponse{} if err != nil {
if err := rsAPI.QueryRoomVersionForRoom(req.Context(), &verReq, &verRes); err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.UnsupportedRoomVersion(err.Error()), JSON: jsonerror.UnsupportedRoomVersion(err.Error()),
@ -184,8 +184,8 @@ func SendEvent(
if err := api.SendEvents( if err := api.SendEvents(
req.Context(), rsAPI, req.Context(), rsAPI,
api.KindNew, api.KindNew,
[]*gomatrixserverlib.HeaderedEvent{ []*types.HeaderedEvent{
e.Headered(verRes.RoomVersion), &types.HeaderedEvent{Event: e},
}, },
device.UserDomain(), device.UserDomain(),
domain, domain,
@ -200,7 +200,7 @@ func SendEvent(
util.GetLogger(req.Context()).WithFields(logrus.Fields{ util.GetLogger(req.Context()).WithFields(logrus.Fields{
"event_id": e.EventID(), "event_id": e.EventID(),
"room_id": roomID, "room_id": roomID,
"room_version": verRes.RoomVersion, "room_version": roomVersion,
}).Info("Sent event to roomserver") }).Info("Sent event to roomserver")
res := util.JSONResponse{ res := util.JSONResponse{
@ -317,7 +317,7 @@ func generateSendEvent(
for i := range queryRes.StateEvents { for i := range queryRes.StateEvents {
stateEvents[i] = queryRes.StateEvents[i].Event stateEvents[i] = queryRes.StateEvents[i].Event
} }
provider := gomatrixserverlib.NewAuthEvents(stateEvents) provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents))
if err = gomatrixserverlib.Allowed(e.Event, &provider); err != nil { if err = gomatrixserverlib.Allowed(e.Event, &provider); err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,

View file

@ -22,12 +22,12 @@ import (
"time" "time"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/tokens" "github.com/matrix-org/gomatrixserverlib/tokens"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/dendrite/roomserver/version"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
@ -157,7 +157,6 @@ func SendServerNotice(
Visibility: "private", Visibility: "private",
Preset: presetPrivateChat, Preset: presetPrivateChat,
CreationContent: cc, CreationContent: cc,
GuestCanJoin: false,
RoomVersion: roomVersion, RoomVersion: roomVersion,
PowerLevelContentOverride: pl, PowerLevelContentOverride: pl,
} }
@ -228,8 +227,8 @@ func SendServerNotice(
if err := api.SendEvents( if err := api.SendEvents(
ctx, rsAPI, ctx, rsAPI,
api.KindNew, api.KindNew,
[]*gomatrixserverlib.HeaderedEvent{ []*types.HeaderedEvent{
e.Headered(roomVersion), &types.HeaderedEvent{Event: e},
}, },
device.UserDomain(), device.UserDomain(),
cfgClient.Matrix.ServerName, cfgClient.Matrix.ServerName,

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/synctypes"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -133,7 +134,7 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
for _, ev := range stateRes.StateEvents { for _, ev := range stateRes.StateEvents {
stateEvents = append( stateEvents = append(
stateEvents, stateEvents,
synctypes.HeaderedToClientEvent(ev, synctypes.FormatAll), synctypes.ToClientEvent(ev, synctypes.FormatAll),
) )
} }
} else { } else {
@ -152,7 +153,7 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
for _, ev := range stateAfterRes.StateEvents { for _, ev := range stateAfterRes.StateEvents {
stateEvents = append( stateEvents = append(
stateEvents, stateEvents,
synctypes.HeaderedToClientEvent(ev, synctypes.FormatAll), synctypes.ToClientEvent(ev, synctypes.FormatAll),
) )
} }
} }
@ -266,7 +267,7 @@ func OnIncomingStateTypeRequest(
"state_at_event": !wantLatestState, "state_at_event": !wantLatestState,
}).Info("Fetching state") }).Info("Fetching state")
var event *gomatrixserverlib.HeaderedEvent var event *types.HeaderedEvent
if wantLatestState { if wantLatestState {
// If we are happy to use the latest state, either because the user is // If we are happy to use the latest state, either because the user is
// still in the room, or because the room is world-readable, then just // still in the room, or because the room is world-readable, then just
@ -311,7 +312,7 @@ func OnIncomingStateTypeRequest(
} }
stateEvent := stateEventInStateResp{ stateEvent := stateEventInStateResp{
ClientEvent: synctypes.HeaderedToClientEvent(event, synctypes.FormatAll), ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll),
} }
var res interface{} var res interface{}

View file

@ -15,11 +15,13 @@
package routing package routing
import ( import (
"errors"
"net/http" "net/http"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/eventutil"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/dendrite/roomserver/version"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
@ -57,38 +59,28 @@ func UpgradeRoom(
} }
} }
upgradeReq := roomserverAPI.PerformRoomUpgradeRequest{ newRoomID, err := rsAPI.PerformRoomUpgrade(req.Context(), roomID, device.UserID, gomatrixserverlib.RoomVersion(r.NewVersion))
UserID: device.UserID, switch e := err.(type) {
RoomID: roomID, case nil:
RoomVersion: gomatrixserverlib.RoomVersion(r.NewVersion), case roomserverAPI.ErrNotAllowed:
} return util.JSONResponse{
upgradeResp := roomserverAPI.PerformRoomUpgradeResponse{} Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(e.Error()),
if err := rsAPI.PerformRoomUpgrade(req.Context(), &upgradeReq, &upgradeResp); err != nil { }
return jsonerror.InternalAPIError(req.Context(), err) default:
} if errors.Is(err, eventutil.ErrRoomNoExists) {
if upgradeResp.Error != nil {
if upgradeResp.Error.Code == roomserverAPI.PerformErrorNoRoom {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,
JSON: jsonerror.NotFound("Room does not exist"), JSON: jsonerror.NotFound("Room does not exist"),
} }
} else if upgradeResp.Error.Code == roomserverAPI.PerformErrorNotAllowed {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(upgradeResp.Error.Msg),
}
} else {
return jsonerror.InternalServerError()
} }
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: upgradeRoomResponse{ JSON: upgradeRoomResponse{
ReplacementRoom: upgradeResp.NewRoomID, ReplacementRoom: newRoomID,
}, },
} }
} }

View file

@ -27,6 +27,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -367,8 +368,8 @@ func emit3PIDInviteEvent(
return api.SendEvents( return api.SendEvents(
ctx, rsAPI, ctx, rsAPI,
api.KindNew, api.KindNew,
[]*gomatrixserverlib.HeaderedEvent{ []*types.HeaderedEvent{
event.Headered(queryRes.RoomVersion), event,
}, },
device.UserDomain(), device.UserDomain(),
cfg.Matrix.ServerName, cfg.Matrix.ServerName,

View file

@ -183,8 +183,8 @@ func main() {
var resolved Events var resolved Events
resolved, err = gomatrixserverlib.ResolveConflicts( resolved, err = gomatrixserverlib.ResolveConflicts(
gomatrixserverlib.RoomVersion(*roomVersion), gomatrixserverlib.RoomVersion(*roomVersion),
events, gomatrixserverlib.ToPDUs(events),
authEvents, gomatrixserverlib.ToPDUs(authEvents),
) )
if err != nil { if err != nil {
panic(err) panic(err)
@ -208,7 +208,7 @@ func main() {
fmt.Println("Returned", count, "state events after filtering") fmt.Println("Returned", count, "state events after filtering")
} }
type Events []*gomatrixserverlib.Event type Events []gomatrixserverlib.PDU
func (e Events) Len() int { func (e Events) Len() int {
return len(e) return len(e)

View file

@ -11,11 +11,13 @@ import (
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/federationapi/types"
rstypes "github.com/matrix-org/dendrite/roomserver/types"
) )
// FederationInternalAPI is used to query information from the federation sender. // FederationInternalAPI is used to query information from the federation sender.
type FederationInternalAPI interface { type FederationInternalAPI interface {
gomatrixserverlib.FederatedStateClient gomatrixserverlib.FederatedStateClient
gomatrixserverlib.FederatedJoinClient
KeyserverFederationAPI KeyserverFederationAPI
gomatrixserverlib.KeyDatabase gomatrixserverlib.KeyDatabase
ClientFederationAPI ClientFederationAPI
@ -188,13 +190,13 @@ type PerformLeaveResponse struct {
} }
type PerformInviteRequest struct { type PerformInviteRequest struct {
RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"`
Event *gomatrixserverlib.HeaderedEvent `json:"event"` Event *rstypes.HeaderedEvent `json:"event"`
InviteRoomState []fclient.InviteV2StrippedState `json:"invite_room_state"` InviteRoomState []fclient.InviteV2StrippedState `json:"invite_room_state"`
} }
type PerformInviteResponse struct { type PerformInviteResponse struct {
Event *gomatrixserverlib.HeaderedEvent `json:"event"` Event *rstypes.HeaderedEvent `json:"event"`
} }
// QueryJoinedHostServerNamesInRoomRequest is a request to QueryJoinedHostServerNames // QueryJoinedHostServerNamesInRoomRequest is a request to QueryJoinedHostServerNames

View file

@ -187,7 +187,12 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew
addsStateEvents = append(addsStateEvents, eventsRes.Events...) addsStateEvents = append(addsStateEvents, eventsRes.Events...)
} }
addsJoinedHosts, err := JoinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(addsStateEvents)) evs := make([]*gomatrixserverlib.Event, len(addsStateEvents))
for i := range evs {
evs[i] = addsStateEvents[i].Event
}
addsJoinedHosts, err := JoinedHostsFromEvents(evs)
if err != nil { if err != nil {
return err return err
} }

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/internal" "github.com/matrix-org/dendrite/federationapi/internal"
rsapi "github.com/matrix-org/dendrite/roomserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/test/testrig"
@ -106,7 +107,9 @@ func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer spec.ServerN
return keys, nil return keys, nil
} }
func (f *fedClient) MakeJoin(ctx context.Context, origin, s spec.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res fclient.RespMakeJoin, err error) { func (f *fedClient) MakeJoin(ctx context.Context, origin, s spec.ServerName, roomID, userID string) (res fclient.RespMakeJoin, err error) {
f.fedClientMutex.Lock()
defer f.fedClientMutex.Unlock()
for _, r := range f.allowJoins { for _, r := range f.allowJoins {
if r.ID == roomID { if r.ID == roomID {
res.RoomVersion = r.Version res.RoomVersion = r.Version
@ -135,10 +138,10 @@ func (f *fedClient) SendJoin(ctx context.Context, origin, s spec.ServerName, eve
defer f.fedClientMutex.Unlock() defer f.fedClientMutex.Unlock()
for _, r := range f.allowJoins { for _, r := range f.allowJoins {
if r.ID == event.RoomID() { if r.ID == event.RoomID() {
r.InsertEvent(f.t, event.Headered(r.Version)) r.InsertEvent(f.t, &types.HeaderedEvent{Event: event})
f.t.Logf("Join event: %v", event.EventID()) f.t.Logf("Join event: %v", event.EventID())
res.StateEvents = gomatrixserverlib.NewEventJSONsFromHeaderedEvents(r.CurrentState()) res.StateEvents = types.NewEventJSONsFromHeaderedEvents(r.CurrentState())
res.AuthEvents = gomatrixserverlib.NewEventJSONsFromHeaderedEvents(r.Events()) res.AuthEvents = types.NewEventJSONsFromHeaderedEvents(r.Events())
} }
} }
return return
@ -325,8 +328,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("failed to parse event: %s", err) t.Errorf("failed to parse event: %s", err)
} }
he := ev.Headered(tc.roomVer) invReq, err := fclient.NewInviteV2Request(ev, nil)
invReq, err := fclient.NewInviteV2Request(he, nil)
if err != nil { if err != nil {
t.Errorf("failed to create invite v2 request: %s", err) t.Errorf("failed to create invite v2 request: %s", err)
continue continue

View file

@ -9,14 +9,40 @@ import (
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
) )
const defaultTimeout = time.Second * 30
// Functions here are "proxying" calls to the gomatrixserverlib federation // Functions here are "proxying" calls to the gomatrixserverlib federation
// client. // client.
func (a *FederationInternalAPI) MakeJoin(
ctx context.Context, origin, s spec.ServerName, roomID, userID string,
) (res gomatrixserverlib.MakeJoinResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.federation.MakeJoin(ctx, origin, s, roomID, userID)
if err != nil {
return &fclient.RespMakeJoin{}, err
}
return &ires, nil
}
func (a *FederationInternalAPI) SendJoin(
ctx context.Context, origin, s spec.ServerName, event *gomatrixserverlib.Event,
) (res gomatrixserverlib.SendJoinResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.federation.SendJoin(ctx, origin, s, event)
if err != nil {
return &fclient.RespSendJoin{}, err
}
return &ires, nil
}
func (a *FederationInternalAPI) GetEventAuth( func (a *FederationInternalAPI) GetEventAuth(
ctx context.Context, origin, s spec.ServerName, ctx context.Context, origin, s spec.ServerName,
roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string,
) (res fclient.RespEventAuth, err error) { ) (res fclient.RespEventAuth, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetEventAuth(ctx, origin, s, roomVersion, roomID, eventID) return a.federation.GetEventAuth(ctx, origin, s, roomVersion, roomID, eventID)
@ -30,7 +56,7 @@ func (a *FederationInternalAPI) GetEventAuth(
func (a *FederationInternalAPI) GetUserDevices( func (a *FederationInternalAPI) GetUserDevices(
ctx context.Context, origin, s spec.ServerName, userID string, ctx context.Context, origin, s spec.ServerName, userID string,
) (fclient.RespUserDevices, error) { ) (fclient.RespUserDevices, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetUserDevices(ctx, origin, s, userID) return a.federation.GetUserDevices(ctx, origin, s, userID)
@ -44,7 +70,7 @@ func (a *FederationInternalAPI) GetUserDevices(
func (a *FederationInternalAPI) ClaimKeys( func (a *FederationInternalAPI) ClaimKeys(
ctx context.Context, origin, s spec.ServerName, oneTimeKeys map[string]map[string]string, ctx context.Context, origin, s spec.ServerName, oneTimeKeys map[string]map[string]string,
) (fclient.RespClaimKeys, error) { ) (fclient.RespClaimKeys, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.ClaimKeys(ctx, origin, s, oneTimeKeys) return a.federation.ClaimKeys(ctx, origin, s, oneTimeKeys)
@ -70,7 +96,7 @@ func (a *FederationInternalAPI) QueryKeys(
func (a *FederationInternalAPI) Backfill( func (a *FederationInternalAPI) Backfill(
ctx context.Context, origin, s spec.ServerName, roomID string, limit int, eventIDs []string, ctx context.Context, origin, s spec.ServerName, roomID string, limit int, eventIDs []string,
) (res gomatrixserverlib.Transaction, err error) { ) (res gomatrixserverlib.Transaction, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.Backfill(ctx, origin, s, roomID, limit, eventIDs) return a.federation.Backfill(ctx, origin, s, roomID, limit, eventIDs)
@ -84,7 +110,7 @@ func (a *FederationInternalAPI) Backfill(
func (a *FederationInternalAPI) LookupState( func (a *FederationInternalAPI) LookupState(
ctx context.Context, origin, s spec.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ctx context.Context, origin, s spec.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.StateResponse, err error) { ) (res gomatrixserverlib.StateResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion) return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion)
@ -99,7 +125,7 @@ func (a *FederationInternalAPI) LookupState(
func (a *FederationInternalAPI) LookupStateIDs( func (a *FederationInternalAPI) LookupStateIDs(
ctx context.Context, origin, s spec.ServerName, roomID, eventID string, ctx context.Context, origin, s spec.ServerName, roomID, eventID string,
) (res gomatrixserverlib.StateIDResponse, err error) { ) (res gomatrixserverlib.StateIDResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupStateIDs(ctx, origin, s, roomID, eventID) return a.federation.LookupStateIDs(ctx, origin, s, roomID, eventID)
@ -114,7 +140,7 @@ func (a *FederationInternalAPI) LookupMissingEvents(
ctx context.Context, origin, s spec.ServerName, roomID string, ctx context.Context, origin, s spec.ServerName, roomID string,
missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion,
) (res fclient.RespMissingEvents, err error) { ) (res fclient.RespMissingEvents, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupMissingEvents(ctx, origin, s, roomID, missing, roomVersion) return a.federation.LookupMissingEvents(ctx, origin, s, roomID, missing, roomVersion)
@ -128,7 +154,7 @@ func (a *FederationInternalAPI) LookupMissingEvents(
func (a *FederationInternalAPI) GetEvent( func (a *FederationInternalAPI) GetEvent(
ctx context.Context, origin, s spec.ServerName, eventID string, ctx context.Context, origin, s spec.ServerName, eventID string,
) (res gomatrixserverlib.Transaction, err error) { ) (res gomatrixserverlib.Transaction, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetEvent(ctx, origin, s, eventID) return a.federation.GetEvent(ctx, origin, s, eventID)

View file

@ -18,6 +18,7 @@ import (
"github.com/matrix-org/dendrite/federationapi/consumers" "github.com/matrix-org/dendrite/federationapi/consumers"
"github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/statistics"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/dendrite/roomserver/version"
) )
@ -73,12 +74,6 @@ func (r *FederationInternalAPI) PerformJoin(
r.joins.Store(j, nil) r.joins.Store(j, nil)
defer r.joins.Delete(j) defer r.joins.Delete(j)
// Look up the supported room versions.
var supportedVersions []gomatrixserverlib.RoomVersion
for version := range version.SupportedRoomVersions() {
supportedVersions = append(supportedVersions, version)
}
// Deduplicate the server names we were provided but keep the ordering // Deduplicate the server names we were provided but keep the ordering
// as this encodes useful information about which servers are most likely // as this encodes useful information about which servers are most likely
// to respond. // to respond.
@ -103,7 +98,6 @@ func (r *FederationInternalAPI) PerformJoin(
request.UserID, request.UserID,
request.Content, request.Content,
serverName, serverName,
supportedVersions,
request.Unsigned, request.Unsigned,
); err != nil { ); err != nil {
logrus.WithError(err).WithFields(logrus.Fields{ logrus.WithError(err).WithFields(logrus.Fields{
@ -146,128 +140,45 @@ func (r *FederationInternalAPI) performJoinUsingServer(
roomID, userID string, roomID, userID string,
content map[string]interface{}, content map[string]interface{},
serverName spec.ServerName, serverName spec.ServerName,
supportedVersions []gomatrixserverlib.RoomVersion,
unsigned map[string]interface{}, unsigned map[string]interface{},
) error { ) error {
if !r.shouldAttemptDirectFederation(serverName) { if !r.shouldAttemptDirectFederation(serverName) {
return fmt.Errorf("relay servers have no meaningful response for join.") return fmt.Errorf("relay servers have no meaningful response for join.")
} }
_, origin, err := r.cfg.Matrix.SplitLocalID('@', userID) user, err := spec.NewUserID(userID, true)
if err != nil {
return err
}
room, err := spec.NewRoomID(roomID)
if err != nil { if err != nil {
return err return err
} }
// Try to perform a make_join using the information supplied in the joinInput := gomatrixserverlib.PerformJoinInput{
// request. UserID: user,
respMakeJoin, err := r.federation.MakeJoin( RoomID: room,
ctx, ServerName: serverName,
origin, Content: content,
serverName, Unsigned: unsigned,
roomID, PrivateKey: r.cfg.Matrix.PrivateKey,
userID, KeyID: r.cfg.Matrix.KeyID,
supportedVersions, KeyRing: r.keyRing,
) EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName),
if err != nil {
// TODO: Check if the user was not allowed to join the room.
r.statistics.ForServer(serverName).Failure()
return fmt.Errorf("r.federation.MakeJoin: %w", err)
} }
r.statistics.ForServer(serverName).Success(statistics.SendDirect) response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput)
// Set all the fields to be what they should be, this should be a no-op if joinErr != nil {
// but it's possible that the remote server returned us something "odd" if !joinErr.Reachable {
respMakeJoin.JoinEvent.Type = spec.MRoomMember r.statistics.ForServer(joinErr.ServerName).Failure()
respMakeJoin.JoinEvent.Sender = userID } else {
respMakeJoin.JoinEvent.StateKey = &userID r.statistics.ForServer(joinErr.ServerName).Success(statistics.SendDirect)
respMakeJoin.JoinEvent.RoomID = roomID
respMakeJoin.JoinEvent.Redacts = ""
if content == nil {
content = map[string]interface{}{}
}
_ = json.Unmarshal(respMakeJoin.JoinEvent.Content, &content)
content["membership"] = spec.Join
if err = respMakeJoin.JoinEvent.SetContent(content); err != nil {
return fmt.Errorf("respMakeJoin.JoinEvent.SetContent: %w", err)
}
if err = respMakeJoin.JoinEvent.SetUnsigned(struct{}{}); err != nil {
return fmt.Errorf("respMakeJoin.JoinEvent.SetUnsigned: %w", err)
}
// Work out if we support the room version that has been supplied in
// the make_join response.
// "If not provided, the room version is assumed to be either "1" or "2"."
// https://matrix.org/docs/spec/server_server/unstable#get-matrix-federation-v1-make-join-roomid-userid
if respMakeJoin.RoomVersion == "" {
respMakeJoin.RoomVersion = setDefaultRoomVersionFromJoinEvent(respMakeJoin.JoinEvent)
}
verImpl, err := gomatrixserverlib.GetRoomVersion(respMakeJoin.RoomVersion)
if err != nil {
return err
}
// Build the join event.
event, err := respMakeJoin.JoinEvent.Build(
time.Now(),
origin,
r.cfg.Matrix.KeyID,
r.cfg.Matrix.PrivateKey,
respMakeJoin.RoomVersion,
)
if err != nil {
return fmt.Errorf("respMakeJoin.JoinEvent.Build: %w", err)
}
// Try to perform a send_join using the newly built event.
respSendJoin, err := r.federation.SendJoin(
context.Background(),
origin,
serverName,
event,
)
if err != nil {
r.statistics.ForServer(serverName).Failure()
return fmt.Errorf("r.federation.SendJoin: %w", err)
}
r.statistics.ForServer(serverName).Success(statistics.SendDirect)
// If the remote server returned an event in the "event" key of
// the send_join request then we should use that instead. It may
// contain signatures that we don't know about.
if len(respSendJoin.Event) > 0 {
var remoteEvent *gomatrixserverlib.Event
remoteEvent, err = verImpl.NewEventFromUntrustedJSON(respSendJoin.Event)
if err == nil && isWellFormedMembershipEvent(
remoteEvent, roomID, userID,
) {
event = remoteEvent
} }
return joinErr.Err
} }
r.statistics.ForServer(serverName).Success(statistics.SendDirect)
// Sanity-check the join response to ensure that it has a create if response == nil {
// event, that the room version is known, etc. return fmt.Errorf("Received nil response from gomatrixserverlib.PerformJoin")
authEvents := respSendJoin.AuthEvents.UntrustedEvents(respMakeJoin.RoomVersion)
if err = sanityCheckAuthChain(authEvents); err != nil {
return fmt.Errorf("sanityCheckAuthChain: %w", err)
}
// Process the join response in a goroutine. The idea here is
// that we'll try and wait for as long as possible for the work
// to complete, but if the client does give up waiting, we'll
// still continue to process the join anyway so that we don't
// waste the effort.
// TODO: Can we expand Check here to return a list of missing auth
// events rather than failing one at a time?
var respState gomatrixserverlib.StateResponse
respState, err = gomatrixserverlib.CheckSendJoinResponse(
context.Background(),
respMakeJoin.RoomVersion, &respSendJoin,
r.keyRing,
event,
federatedAuthProvider(ctx, r.federation, r.keyRing, origin, serverName),
)
if err != nil {
return fmt.Errorf("respSendJoin.Check: %w", err)
} }
// We need to immediately update our list of joined hosts for this room now as we are technically // We need to immediately update our list of joined hosts for this room now as we are technically
@ -276,60 +187,33 @@ func (r *FederationInternalAPI) performJoinUsingServer(
// joining a room, waiting for 200 OK then changing device keys and have those keys not be sent // joining a room, waiting for 200 OK then changing device keys and have those keys not be sent
// to other servers (this was a cause of a flakey sytest "Local device key changes get to remote servers") // to other servers (this was a cause of a flakey sytest "Local device key changes get to remote servers")
// The events are trusted now as we performed auth checks above. // The events are trusted now as we performed auth checks above.
joinedHosts, err := consumers.JoinedHostsFromEvents(respState.GetStateEvents().TrustedEvents(respMakeJoin.RoomVersion, false)) joinedHosts, err := consumers.JoinedHostsFromEvents(response.StateSnapshot.GetStateEvents().TrustedEvents(response.JoinEvent.Version(), false))
if err != nil { if err != nil {
return fmt.Errorf("JoinedHostsFromEvents: failed to get joined hosts: %s", err) return fmt.Errorf("JoinedHostsFromEvents: failed to get joined hosts: %s", err)
} }
logrus.WithField("room", roomID).Infof("Joined federated room with %d hosts", len(joinedHosts)) logrus.WithField("room", roomID).Infof("Joined federated room with %d hosts", len(joinedHosts))
if _, err = r.db.UpdateRoom(context.Background(), roomID, joinedHosts, nil, true); err != nil { if _, err = r.db.UpdateRoom(context.Background(), roomID, joinedHosts, nil, true); err != nil {
return fmt.Errorf("UpdatedRoom: failed to update room with joined hosts: %s", err) return fmt.Errorf("UpdatedRoom: failed to update room with joined hosts: %s", err)
} }
// If we successfully performed a send_join above then the other // TODO: Can I change this to not take respState but instead just take an opaque list of events?
// server now thinks we're a part of the room. Send the newly
// returned state to the roomserver to update our local view.
if unsigned != nil {
event, err = event.SetUnsigned(unsigned)
if err != nil {
// non-fatal, log and continue
logrus.WithError(err).Errorf("Failed to set unsigned content")
}
}
if err = roomserverAPI.SendEventWithState( if err = roomserverAPI.SendEventWithState(
context.Background(), context.Background(),
r.rsAPI, r.rsAPI,
origin, user.Domain(),
roomserverAPI.KindNew, roomserverAPI.KindNew,
respState, response.StateSnapshot,
event.Headered(respMakeJoin.RoomVersion), &types.HeaderedEvent{Event: response.JoinEvent},
serverName, serverName,
nil, nil,
false, false,
); err != nil { ); err != nil {
return fmt.Errorf("roomserverAPI.SendEventWithState: %w", err) return fmt.Errorf("roomserverAPI.SendEventWithState: %w", err)
} }
return nil return nil
} }
// isWellFormedMembershipEvent returns true if the event looks like a legitimate
// membership event.
func isWellFormedMembershipEvent(event *gomatrixserverlib.Event, roomID, userID string) bool {
if membership, err := event.Membership(); err != nil {
return false
} else if membership != spec.Join {
return false
}
if event.RoomID() != roomID {
return false
}
if !event.StateKeyEquals(userID) {
return false
}
return true
}
// PerformOutboundPeekRequest implements api.FederationInternalAPI // PerformOutboundPeekRequest implements api.FederationInternalAPI
func (r *FederationInternalAPI) PerformOutboundPeek( func (r *FederationInternalAPI) PerformOutboundPeek(
ctx context.Context, ctx context.Context,
@ -475,12 +359,12 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
// authenticate the state returned (check its auth events etc) // authenticate the state returned (check its auth events etc)
// the equivalent of CheckSendJoinResponse() // the equivalent of CheckSendJoinResponse()
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse( authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(
ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName), ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName),
) )
if err != nil { if err != nil {
return fmt.Errorf("error checking state returned from peeking: %w", err) return fmt.Errorf("error checking state returned from peeking: %w", err)
} }
if err = sanityCheckAuthChain(authEvents); err != nil { if err = checkEventsContainCreateEvent(authEvents); err != nil {
return fmt.Errorf("sanityCheckAuthChain: %w", err) return fmt.Errorf("sanityCheckAuthChain: %w", err)
} }
@ -505,7 +389,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
StateEvents: gomatrixserverlib.NewEventJSONsFromEvents(stateEvents), StateEvents: gomatrixserverlib.NewEventJSONsFromEvents(stateEvents),
AuthEvents: gomatrixserverlib.NewEventJSONsFromEvents(authEvents), AuthEvents: gomatrixserverlib.NewEventJSONsFromEvents(authEvents),
}, },
respPeek.LatestEvent.Headered(respPeek.RoomVersion), &types.HeaderedEvent{Event: respPeek.LatestEvent},
serverName, serverName,
nil, nil,
false, false,
@ -652,7 +536,7 @@ func (r *FederationInternalAPI) PerformInvite(
"destination": destination, "destination": destination,
}).Info("Sending invite") }).Info("Sending invite")
inviteReq, err := fclient.NewInviteV2Request(request.Event, request.InviteRoomState) inviteReq, err := fclient.NewInviteV2Request(request.Event.Event, request.InviteRoomState)
if err != nil { if err != nil {
return fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err) return fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err)
} }
@ -670,7 +554,7 @@ func (r *FederationInternalAPI) PerformInvite(
if err != nil { if err != nil {
return fmt.Errorf("r.federation.SendInviteV2 failed to decode event response: %w", err) return fmt.Errorf("r.federation.SendInviteV2 failed to decode event response: %w", err)
} }
response.Event = inviteEvent.Headered(request.RoomVersion) response.Event = &types.HeaderedEvent{Event: inviteEvent}
return nil return nil
} }
@ -719,9 +603,9 @@ func (r *FederationInternalAPI) MarkServersAlive(destinations []spec.ServerName)
} }
} }
func sanityCheckAuthChain(authChain []*gomatrixserverlib.Event) error { func checkEventsContainCreateEvent(events []*gomatrixserverlib.Event) error {
// sanity check we have a create event and it has a known room version // sanity check we have a create event and it has a known room version
for _, ev := range authChain { for _, ev := range events {
if ev.Type() == spec.MRoomCreate && ev.StateKeyEquals("") { if ev.Type() == spec.MRoomCreate && ev.StateKeyEquals("") {
// make sure the room version is known // make sure the room version is known
content := ev.Content() content := ev.Content()
@ -739,52 +623,28 @@ func sanityCheckAuthChain(authChain []*gomatrixserverlib.Event) error {
} }
knownVersions := gomatrixserverlib.RoomVersions() knownVersions := gomatrixserverlib.RoomVersions()
if _, ok := knownVersions[gomatrixserverlib.RoomVersion(verBody.Version)]; !ok { if _, ok := knownVersions[gomatrixserverlib.RoomVersion(verBody.Version)]; !ok {
return fmt.Errorf("auth chain m.room.create event has an unknown room version: %s", verBody.Version) return fmt.Errorf("m.room.create event has an unknown room version: %s", verBody.Version)
} }
return nil return nil
} }
} }
return fmt.Errorf("auth chain response is missing m.room.create event") return fmt.Errorf("response is missing m.room.create event")
} }
func setDefaultRoomVersionFromJoinEvent( // federatedEventProvider is an event provider which fetches events from the server provided
joinEvent gomatrixserverlib.EventBuilder, func federatedEventProvider(
) gomatrixserverlib.RoomVersion {
// if auth events are not event references we know it must be v3+
// we have to do these shenanigans to satisfy sytest, specifically for:
// "Outbound federation rejects m.room.create events with an unknown room version"
hasEventRefs := true
authEvents, ok := joinEvent.AuthEvents.([]interface{})
if ok {
if len(authEvents) > 0 {
_, ok = authEvents[0].(string)
if ok {
// event refs are objects, not strings, so we know we must be dealing with a v3+ room.
hasEventRefs = false
}
}
}
if hasEventRefs {
return gomatrixserverlib.RoomVersionV1
}
return gomatrixserverlib.RoomVersionV4
}
// FederatedAuthProvider is an auth chain provider which fetches events from the server provided
func federatedAuthProvider(
ctx context.Context, federation fclient.FederationClient, ctx context.Context, federation fclient.FederationClient,
keyRing gomatrixserverlib.JSONVerifier, origin, server spec.ServerName, keyRing gomatrixserverlib.JSONVerifier, origin, server spec.ServerName,
) gomatrixserverlib.AuthChainProvider { ) gomatrixserverlib.EventProvider {
// A list of events that we have retried, if they were not included in // A list of events that we have retried, if they were not included in
// the auth events supplied in the send_join. // the auth events supplied in the send_join.
retries := map[string][]*gomatrixserverlib.Event{} retries := map[string][]gomatrixserverlib.PDU{}
// Define a function which we can pass to Check to retrieve missing // Define a function which we can pass to Check to retrieve missing
// auth events inline. This greatly increases our chances of not having // auth events inline. This greatly increases our chances of not having
// to repeat the entire set of checks just for a missing event or two. // to repeat the entire set of checks just for a missing event or two.
return func(roomVersion gomatrixserverlib.RoomVersion, eventIDs []string) ([]*gomatrixserverlib.Event, error) { return func(roomVersion gomatrixserverlib.RoomVersion, eventIDs []string) ([]gomatrixserverlib.PDU, error) {
returning := []*gomatrixserverlib.Event{} returning := []gomatrixserverlib.PDU{}
verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion) verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion)
if err != nil { if err != nil {
return nil, err return nil, err
@ -824,7 +684,7 @@ func federatedAuthProvider(
} }
// Check the signatures of the event. // Check the signatures of the event.
if err := ev.VerifyEventSignatures(ctx, keyRing); err != nil { if err := gomatrixserverlib.VerifyEventSignatures(ctx, ev, keyRing); err != nil {
return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err) return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err)
} }

View file

@ -32,6 +32,7 @@ import (
"github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
) )
@ -71,7 +72,7 @@ type destinationQueue struct {
// Send event adds the event to the pending queue for the destination. // Send event adds the event to the pending queue for the destination.
// If the queue is empty then it starts a background goroutine to // If the queue is empty then it starts a background goroutine to
// start sending events to that destination. // start sending events to that destination.
func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, dbReceipt *receipt.Receipt) { func (oq *destinationQueue) sendEvent(event *types.HeaderedEvent, dbReceipt *receipt.Receipt) {
if event == nil { if event == nil {
logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination) logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination)
return return

View file

@ -33,6 +33,7 @@ import (
"github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
) )
@ -140,7 +141,7 @@ func NewOutgoingQueues(
type queuedPDU struct { type queuedPDU struct {
dbReceipt *receipt.Receipt dbReceipt *receipt.Receipt
pdu *gomatrixserverlib.HeaderedEvent pdu *types.HeaderedEvent
} }
type queuedEDU struct { type queuedEDU struct {
@ -187,7 +188,7 @@ func (oqs *OutgoingQueues) clearQueue(oq *destinationQueue) {
// SendEvent sends an event to the destinations // SendEvent sends an event to the destinations
func (oqs *OutgoingQueues) SendEvent( func (oqs *OutgoingQueues) SendEvent(
ev *gomatrixserverlib.HeaderedEvent, origin spec.ServerName, ev *types.HeaderedEvent, origin spec.ServerName,
destinations []spec.ServerName, destinations []spec.ServerName,
) error { ) error {
if oqs.disabled { if oqs.disabled {

View file

@ -35,6 +35,7 @@ import (
"github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/storage"
rsapi "github.com/matrix-org/dendrite/roomserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
@ -101,14 +102,14 @@ func (f *stubFederationClient) P2PSendTransactionToRelay(ctx context.Context, u
return fclient.EmptyResp{}, result return fclient.EmptyResp{}, result
} }
func mustCreatePDU(t *testing.T) *gomatrixserverlib.HeaderedEvent { func mustCreatePDU(t *testing.T) *types.HeaderedEvent {
t.Helper() t.Helper()
content := `{"type":"m.room.message"}` content := `{"type":"m.room.message"}`
ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV10).NewEventFromTrustedJSON([]byte(content), false) ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV10).NewEventFromTrustedJSON([]byte(content), false)
if err != nil { if err != nil {
t.Fatalf("failed to create event: %v", err) t.Fatalf("failed to create event: %v", err)
} }
return ev.Headered(gomatrixserverlib.RoomVersionV10) return &types.HeaderedEvent{Event: ev}
} }
func mustCreateEDU(t *testing.T) *gomatrixserverlib.EDU { func mustCreateEDU(t *testing.T) *gomatrixserverlib.EDU {

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/fclient"
@ -104,7 +105,7 @@ func Backfill(
// Filter any event that's not from the requested room out. // Filter any event that's not from the requested room out.
evs := make([]*gomatrixserverlib.Event, 0) evs := make([]*gomatrixserverlib.Event, 0)
var ev *gomatrixserverlib.HeaderedEvent var ev *types.HeaderedEvent
for _, ev = range res.Events { for _, ev = range res.Events {
if ev.RoomID() == roomID { if ev.RoomID() == roomID {
evs = append(evs, ev.Event) evs = append(evs, ev.Event)
@ -113,7 +114,7 @@ func Backfill(
eventJSONs := []json.RawMessage{} eventJSONs := []json.RawMessage{}
for _, e := range gomatrixserverlib.ReverseTopologicalOrdering( for _, e := range gomatrixserverlib.ReverseTopologicalOrdering(
evs, gomatrixserverlib.ToPDUs(evs),
gomatrixserverlib.TopologicalOrderByPrevEvents, gomatrixserverlib.TopologicalOrderByPrevEvents,
) { ) {
eventJSONs = append(eventJSONs, e.JSON()) eventJSONs = append(eventJSONs, e.JSON())

View file

@ -18,7 +18,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -72,7 +72,7 @@ func GetEventAuth(
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: fclient.RespEventAuth{ JSON: fclient.RespEventAuth{
AuthEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(response.AuthChainEvents), AuthEvents: types.NewEventJSONsFromHeaderedEvents(response.AuthChainEvents),
}, },
} }
} }

View file

@ -20,8 +20,10 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"github.com/getsentry/sentry-go"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/fclient"
@ -196,25 +198,44 @@ func processInvite(
) )
// Add the invite event to the roomserver. // Add the invite event to the roomserver.
inviteEvent := signedEvent.Headered(roomVer) inviteEvent := &types.HeaderedEvent{Event: &signedEvent}
request := &api.PerformInviteRequest{ request := &api.PerformInviteRequest{
Event: inviteEvent, Event: inviteEvent,
InviteRoomState: strippedState, InviteRoomState: strippedState,
RoomVersion: inviteEvent.RoomVersion, RoomVersion: inviteEvent.Version(),
SendAsServer: string(api.DoNotSendToOtherServers), SendAsServer: string(api.DoNotSendToOtherServers),
TransactionID: nil, TransactionID: nil,
} }
response := &api.PerformInviteResponse{}
if err := rsAPI.PerformInvite(ctx, request, response); err != nil { if err = rsAPI.PerformInvite(ctx, request); err != nil {
util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") util.GetLogger(ctx).WithError(err).Error("PerformInvite failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: jsonerror.InternalServerError(), JSON: jsonerror.InternalServerError(),
} }
} }
if response.Error != nil {
return response.Error.JSONResponse() switch e := err.(type) {
case api.ErrInvalidID:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown(e.Error()),
}
case api.ErrNotAllowed:
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(e.Error()),
}
case nil:
default:
util.GetLogger(ctx).WithError(err).Error("PerformInvite failed")
sentry.CaptureException(err)
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.InternalServerError(),
}
} }
// Return the signed event to the originating server, it should then tell // Return the signed event to the originating server, it should then tell
// the other servers in the room that we have been invited. // the other servers in the room that we have been invited.
if isInviteV2 { if isInviteV2 {

View file

@ -30,6 +30,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
) )
@ -42,9 +43,8 @@ func MakeJoin(
roomID, userID string, roomID, userID string,
remoteVersions []gomatrixserverlib.RoomVersion, remoteVersions []gomatrixserverlib.RoomVersion,
) util.JSONResponse { ) util.JSONResponse {
verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID)
verRes := api.QueryRoomVersionForRoomResponse{} if err != nil {
if err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), &verReq, &verRes); err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: jsonerror.InternalServerError(), JSON: jsonerror.InternalServerError(),
@ -57,7 +57,7 @@ func MakeJoin(
// https://matrix.org/docs/spec/server_server/r0.1.3#get-matrix-federation-v1-make-join-roomid-userid // https://matrix.org/docs/spec/server_server/r0.1.3#get-matrix-federation-v1-make-join-roomid-userid
remoteSupportsVersion := false remoteSupportsVersion := false
for _, v := range remoteVersions { for _, v := range remoteVersions {
if v == verRes.RoomVersion { if v == roomVersion {
remoteSupportsVersion = true remoteSupportsVersion = true
break break
} }
@ -66,7 +66,7 @@ func MakeJoin(
if !remoteSupportsVersion { if !remoteSupportsVersion {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.IncompatibleRoomVersion(verRes.RoomVersion), JSON: jsonerror.IncompatibleRoomVersion(roomVersion),
} }
} }
@ -109,7 +109,7 @@ func MakeJoin(
// Check if the restricted join is allowed. If the room doesn't // Check if the restricted join is allowed. If the room doesn't
// support restricted joins then this is effectively a no-op. // support restricted joins then this is effectively a no-op.
res, authorisedVia, err := checkRestrictedJoin(httpReq, rsAPI, verRes.RoomVersion, roomID, userID) res, authorisedVia, err := checkRestrictedJoin(httpReq, rsAPI, roomVersion, roomID, userID)
if err != nil { if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("checkRestrictedJoin failed") util.GetLogger(httpReq.Context()).WithError(err).Error("checkRestrictedJoin failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -144,7 +144,7 @@ func MakeJoin(
} }
queryRes := api.QueryLatestEventsAndStateResponse{ queryRes := api.QueryLatestEventsAndStateResponse{
RoomVersion: verRes.RoomVersion, RoomVersion: roomVersion,
} }
event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes) event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes)
if err == eventutil.ErrRoomNoExists { if err == eventutil.ErrRoomNoExists {
@ -168,7 +168,7 @@ func MakeJoin(
stateEvents[i] = queryRes.StateEvents[i].Event stateEvents[i] = queryRes.StateEvents[i].Event
} }
provider := gomatrixserverlib.NewAuthEvents(stateEvents) provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents))
if err = gomatrixserverlib.Allowed(event.Event, &provider); err != nil { if err = gomatrixserverlib.Allowed(event.Event, &provider); err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
@ -180,7 +180,7 @@ func MakeJoin(
Code: http.StatusOK, Code: http.StatusOK,
JSON: map[string]interface{}{ JSON: map[string]interface{}{
"event": builder, "event": builder,
"room_version": verRes.RoomVersion, "room_version": roomVersion,
}, },
} }
} }
@ -197,21 +197,20 @@ func SendJoin(
keys gomatrixserverlib.JSONVerifier, keys gomatrixserverlib.JSONVerifier,
roomID, eventID string, roomID, eventID string,
) util.JSONResponse { ) util.JSONResponse {
verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID)
verRes := api.QueryRoomVersionForRoomResponse{} if err != nil {
if err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), &verReq, &verRes); err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryRoomVersionForRoom failed") util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryRoomVersionForRoom failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: jsonerror.InternalServerError(), JSON: jsonerror.InternalServerError(),
} }
} }
verImpl, err := gomatrixserverlib.GetRoomVersion(verRes.RoomVersion) verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: jsonerror.UnsupportedRoomVersion( JSON: jsonerror.UnsupportedRoomVersion(
fmt.Sprintf("QueryRoomVersionForRoom returned unknown room version: %s", verRes.RoomVersion), fmt.Sprintf("QueryRoomVersionForRoom returned unknown room version: %s", roomVersion),
), ),
} }
} }
@ -415,7 +414,7 @@ func SendJoin(
InputRoomEvents: []api.InputRoomEvent{ InputRoomEvents: []api.InputRoomEvent{
{ {
Kind: api.KindNew, Kind: api.KindNew,
Event: signed.Headered(stateAndAuthChainResponse.RoomVersion), Event: &types.HeaderedEvent{Event: &signed},
SendAsServer: string(cfg.Matrix.ServerName), SendAsServer: string(cfg.Matrix.ServerName),
TransactionID: nil, TransactionID: nil,
}, },
@ -445,8 +444,8 @@ func SendJoin(
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: fclient.RespSendJoin{ JSON: fclient.RespSendJoin{
StateEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.StateEvents), StateEvents: types.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.StateEvents),
AuthEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.AuthChainEvents), AuthEvents: types.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.AuthChainEvents),
Origin: cfg.Matrix.ServerName, Origin: cfg.Matrix.ServerName,
Event: signed.JSON(), Event: signed.JSON(),
}, },
@ -521,7 +520,7 @@ func checkRestrictedJoin(
} }
} }
type eventsByDepth []*gomatrixserverlib.HeaderedEvent type eventsByDepth []*types.HeaderedEvent
func (e eventsByDepth) Len() int { func (e eventsByDepth) Len() int {
return len(e) return len(e)

View file

@ -20,6 +20,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/fclient"
@ -101,7 +102,7 @@ func MakeLeave(
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: map[string]interface{}{ JSON: map[string]interface{}{
"room_version": event.RoomVersion, "room_version": event.Version(),
"event": state, "event": state,
}, },
} }
@ -113,7 +114,7 @@ func MakeLeave(
for i := range queryRes.StateEvents { for i := range queryRes.StateEvents {
stateEvents[i] = queryRes.StateEvents[i].Event stateEvents[i] = queryRes.StateEvents[i].Event
} }
provider := gomatrixserverlib.NewAuthEvents(stateEvents) provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents))
if err = gomatrixserverlib.Allowed(event.Event, &provider); err != nil { if err = gomatrixserverlib.Allowed(event.Event, &provider); err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
@ -124,7 +125,7 @@ func MakeLeave(
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: map[string]interface{}{ JSON: map[string]interface{}{
"room_version": event.RoomVersion, "room_version": event.Version(),
"event": builder, "event": builder,
}, },
} }
@ -140,21 +141,20 @@ func SendLeave(
keys gomatrixserverlib.JSONVerifier, keys gomatrixserverlib.JSONVerifier,
roomID, eventID string, roomID, eventID string,
) util.JSONResponse { ) util.JSONResponse {
verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID)
verRes := api.QueryRoomVersionForRoomResponse{} if err != nil {
if err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), &verReq, &verRes); err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.UnsupportedRoomVersion(err.Error()), JSON: jsonerror.UnsupportedRoomVersion(err.Error()),
} }
} }
verImpl, err := gomatrixserverlib.GetRoomVersion(verRes.RoomVersion) verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: jsonerror.UnsupportedRoomVersion( JSON: jsonerror.UnsupportedRoomVersion(
fmt.Sprintf("QueryRoomVersionForRoom returned unknown version: %s", verRes.RoomVersion), fmt.Sprintf("QueryRoomVersionForRoom returned unknown version: %s", roomVersion),
), ),
} }
} }
@ -313,7 +313,7 @@ func SendLeave(
InputRoomEvents: []api.InputRoomEvent{ InputRoomEvents: []api.InputRoomEvent{
{ {
Kind: api.KindNew, Kind: api.KindNew,
Event: event.Headered(verRes.RoomVersion), Event: &types.HeaderedEvent{Event: event},
SendAsServer: string(cfg.Matrix.ServerName), SendAsServer: string(cfg.Matrix.ServerName),
TransactionID: nil, TransactionID: nil,
}, },

View file

@ -18,7 +18,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -69,7 +69,7 @@ func GetMissingEvents(
eventsResponse.Events = filterEvents(eventsResponse.Events, roomID) eventsResponse.Events = filterEvents(eventsResponse.Events, roomID)
resp := fclient.RespMissingEvents{ resp := fclient.RespMissingEvents{
Events: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(eventsResponse.Events), Events: types.NewEventJSONsFromHeaderedEvents(eventsResponse.Events),
} }
return util.JSONResponse{ return util.JSONResponse{
@ -80,8 +80,8 @@ func GetMissingEvents(
// filterEvents returns only those events with matching roomID // filterEvents returns only those events with matching roomID
func filterEvents( func filterEvents(
events []*gomatrixserverlib.HeaderedEvent, roomID string, events []*types.HeaderedEvent, roomID string,
) []*gomatrixserverlib.HeaderedEvent { ) []*types.HeaderedEvent {
ref := events[:0] ref := events[:0]
for _, ev := range events { for _, ev := range events {
if ev.RoomID() == roomID { if ev.RoomID() == roomID {

View file

@ -19,6 +19,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/fclient"
@ -35,10 +36,8 @@ func Peek(
remoteVersions []gomatrixserverlib.RoomVersion, remoteVersions []gomatrixserverlib.RoomVersion,
) util.JSONResponse { ) util.JSONResponse {
// TODO: check if we're just refreshing an existing peek by querying the federationapi // TODO: check if we're just refreshing an existing peek by querying the federationapi
roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID)
verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} if err != nil {
verRes := api.QueryRoomVersionForRoomResponse{}
if err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), &verReq, &verRes); err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: jsonerror.InternalServerError(), JSON: jsonerror.InternalServerError(),
@ -50,7 +49,7 @@ func Peek(
// the peek URL. // the peek URL.
remoteSupportsVersion := false remoteSupportsVersion := false
for _, v := range remoteVersions { for _, v := range remoteVersions {
if v == verRes.RoomVersion { if v == roomVersion {
remoteSupportsVersion = true remoteSupportsVersion = true
break break
} }
@ -59,7 +58,7 @@ func Peek(
if !remoteSupportsVersion { if !remoteSupportsVersion {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.IncompatibleRoomVersion(verRes.RoomVersion), JSON: jsonerror.IncompatibleRoomVersion(roomVersion),
} }
} }
@ -69,7 +68,7 @@ func Peek(
renewalInterval := int64(60 * 60 * 1000 * 1000) renewalInterval := int64(60 * 60 * 1000 * 1000)
var response api.PerformInboundPeekResponse var response api.PerformInboundPeekResponse
err := rsAPI.PerformInboundPeek( err = rsAPI.PerformInboundPeek(
httpReq.Context(), httpReq.Context(),
&api.PerformInboundPeekRequest{ &api.PerformInboundPeekRequest{
RoomID: roomID, RoomID: roomID,
@ -89,10 +88,10 @@ func Peek(
} }
respPeek := fclient.RespPeek{ respPeek := fclient.RespPeek{
StateEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(response.StateEvents), StateEvents: types.NewEventJSONsFromHeaderedEvents(response.StateEvents),
AuthEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(response.AuthChainEvents), AuthEvents: types.NewEventJSONsFromHeaderedEvents(response.AuthChainEvents),
RoomVersion: response.RoomVersion, RoomVersion: response.RoomVersion,
LatestEvent: response.LatestEvent.Unwrap(), LatestEvent: response.LatestEvent.Event,
RenewalInterval: renewalInterval, RenewalInterval: renewalInterval,
} }

View file

@ -19,7 +19,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -42,8 +42,8 @@ func GetState(
} }
return util.JSONResponse{Code: http.StatusOK, JSON: &fclient.RespState{ return util.JSONResponse{Code: http.StatusOK, JSON: &fclient.RespState{
AuthEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(authChain), AuthEvents: types.NewEventJSONsFromHeaderedEvents(authChain),
StateEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(stateEvents), StateEvents: types.NewEventJSONsFromHeaderedEvents(stateEvents),
}} }}
} }
@ -101,7 +101,7 @@ func getState(
rsAPI api.FederationRoomserverAPI, rsAPI api.FederationRoomserverAPI,
roomID string, roomID string,
eventID string, eventID string,
) (stateEvents, authEvents []*gomatrixserverlib.HeaderedEvent, errRes *util.JSONResponse) { ) (stateEvents, authEvents []*types.HeaderedEvent, errRes *util.JSONResponse) {
// If we don't think we belong to this room then don't waste the effort // If we don't think we belong to this room then don't waste the effort
// responding to expensive requests for it. // responding to expensive requests for it.
if err := ErrorIfLocalServerNotInRoom(ctx, rsAPI, roomID); err != nil { if err := ErrorIfLocalServerNotInRoom(ctx, rsAPI, roomID); err != nil {
@ -157,7 +157,7 @@ func getState(
return response.StateEvents, response.AuthChainEvents, nil return response.StateEvents, response.AuthChainEvents, nil
} }
func getIDsFromEvent(events []*gomatrixserverlib.HeaderedEvent) []string { func getIDsFromEvent(events []*types.HeaderedEvent) []string {
IDs := make([]string, len(events)) IDs := make([]string, len(events))
for i := range events { for i := range events {
IDs[i] = events[i].EventID() IDs[i] = events[i].EventID()

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/fclient"
@ -67,11 +68,10 @@ func CreateInvitesFrom3PIDInvites(
return *reqErr return *reqErr
} }
evs := []*gomatrixserverlib.HeaderedEvent{} evs := []*types.HeaderedEvent{}
for _, inv := range body.Invites { for _, inv := range body.Invites {
verReq := api.QueryRoomVersionForRoomRequest{RoomID: inv.RoomID} _, err := rsAPI.QueryRoomVersionForRoom(req.Context(), inv.RoomID)
verRes := api.QueryRoomVersionForRoomResponse{} if err != nil {
if err := rsAPI.QueryRoomVersionForRoom(req.Context(), &verReq, &verRes); err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.UnsupportedRoomVersion(err.Error()), JSON: jsonerror.UnsupportedRoomVersion(err.Error()),
@ -86,7 +86,7 @@ func CreateInvitesFrom3PIDInvites(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
if event != nil { if event != nil {
evs = append(evs, event.Headered(verRes.RoomVersion)) evs = append(evs, &types.HeaderedEvent{Event: event})
} }
} }
@ -162,9 +162,8 @@ func ExchangeThirdPartyInvite(
} }
} }
verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID)
verRes := api.QueryRoomVersionForRoomResponse{} if err != nil {
if err = rsAPI.QueryRoomVersionForRoom(httpReq.Context(), &verReq, &verRes); err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.UnsupportedRoomVersion(err.Error()), JSON: jsonerror.UnsupportedRoomVersion(err.Error()),
@ -185,7 +184,7 @@ func ExchangeThirdPartyInvite(
// Ask the requesting server to sign the newly created event so we know it // Ask the requesting server to sign the newly created event so we know it
// acknowledged it // acknowledged it
inviteReq, err := fclient.NewInviteV2Request(event.Headered(verRes.RoomVersion), nil) inviteReq, err := fclient.NewInviteV2Request(event, nil)
if err != nil { if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request") util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -195,9 +194,9 @@ func ExchangeThirdPartyInvite(
util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
verImpl, err := gomatrixserverlib.GetRoomVersion(verRes.RoomVersion) verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion)
if err != nil { if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Errorf("unknown room version: %s", verRes.RoomVersion) util.GetLogger(httpReq.Context()).WithError(err).Errorf("unknown room version: %s", roomVersion)
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
inviteEvent, err := verImpl.NewEventFromUntrustedJSON(signedEvent.Event) inviteEvent, err := verImpl.NewEventFromUntrustedJSON(signedEvent.Event)
@ -210,8 +209,8 @@ func ExchangeThirdPartyInvite(
if err = api.SendEvents( if err = api.SendEvents(
httpReq.Context(), rsAPI, httpReq.Context(), rsAPI,
api.KindNew, api.KindNew,
[]*gomatrixserverlib.HeaderedEvent{ []*types.HeaderedEvent{
inviteEvent.Headered(verRes.RoomVersion), {Event: inviteEvent},
}, },
request.Destination(), request.Destination(),
request.Origin(), request.Origin(),
@ -239,12 +238,6 @@ func createInviteFrom3PIDInvite(
inv invite, federation fclient.FederationClient, inv invite, federation fclient.FederationClient,
userAPI userapi.FederationUserAPI, userAPI userapi.FederationUserAPI,
) (*gomatrixserverlib.Event, error) { ) (*gomatrixserverlib.Event, error) {
verReq := api.QueryRoomVersionForRoomRequest{RoomID: inv.RoomID}
verRes := api.QueryRoomVersionForRoomResponse{}
if err := rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil {
return nil, err
}
_, server, err := gomatrixserverlib.SplitID('@', inv.MXID) _, server, err := gomatrixserverlib.SplitID('@', inv.MXID)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
"github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/federationapi/types"
rstypes "github.com/matrix-org/dendrite/roomserver/types"
) )
type Database interface { type Database interface {
@ -38,7 +39,7 @@ type Database interface {
StoreJSON(ctx context.Context, js string) (*receipt.Receipt, error) StoreJSON(ctx context.Context, js string) (*receipt.Receipt, error)
GetPendingPDUs(ctx context.Context, serverName spec.ServerName, limit int) (pdus map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error) GetPendingPDUs(ctx context.Context, serverName spec.ServerName, limit int) (pdus map[*receipt.Receipt]*rstypes.HeaderedEvent, err error)
GetPendingEDUs(ctx context.Context, serverName spec.ServerName, limit int) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error) GetPendingEDUs(ctx context.Context, serverName spec.ServerName, limit int) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error)
AssociatePDUWithDestinations(ctx context.Context, destinations map[spec.ServerName]struct{}, dbReceipt *receipt.Receipt) error AssociatePDUWithDestinations(ctx context.Context, destinations map[spec.ServerName]struct{}, dbReceipt *receipt.Receipt) error

View file

@ -22,7 +22,7 @@ import (
"fmt" "fmt"
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
) )
@ -56,7 +56,7 @@ func (d *Database) GetPendingPDUs(
serverName spec.ServerName, serverName spec.ServerName,
limit int, limit int,
) ( ) (
events map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, events map[*receipt.Receipt]*types.HeaderedEvent,
err error, err error,
) { ) {
// Strictly speaking this doesn't need to be using the writer // Strictly speaking this doesn't need to be using the writer
@ -64,7 +64,7 @@ func (d *Database) GetPendingPDUs(
// a guarantee of transactional isolation, it's actually useful // a guarantee of transactional isolation, it's actually useful
// to know in SQLite mode that nothing else is trying to modify // to know in SQLite mode that nothing else is trying to modify
// the database. // the database.
events = make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent) events = make(map[*receipt.Receipt]*types.HeaderedEvent)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
nids, err := d.FederationQueuePDUs.SelectQueuePDUs(ctx, txn, serverName, limit) nids, err := d.FederationQueuePDUs.SelectQueuePDUs(ctx, txn, serverName, limit)
if err != nil { if err != nil {
@ -87,7 +87,7 @@ func (d *Database) GetPendingPDUs(
} }
for nid, blob := range blobs { for nid, blob := range blobs {
var event gomatrixserverlib.HeaderedEvent var event types.HeaderedEvent
if err := json.Unmarshal(blob, &event); err != nil { if err := json.Unmarshal(blob, &event); err != nil {
return fmt.Errorf("json.Unmarshal: %w", err) return fmt.Errorf("json.Unmarshal: %w", err)
} }

10
go.mod
View file

@ -18,11 +18,11 @@ require (
github.com/gorilla/mux v1.8.0 github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.5.0 github.com/gorilla/websocket v1.5.0
github.com/kardianos/minwinsvc v1.0.2 github.com/kardianos/minwinsvc v1.0.2
github.com/lib/pq v1.10.7 github.com/lib/pq v1.10.8
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20230424155704-8daeaebaa0bc github.com/matrix-org/gomatrixserverlib v0.0.0-20230428192809-ff52c27efdce
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
github.com/mattn/go-sqlite3 v1.14.16 github.com/mattn/go-sqlite3 v1.14.16
@ -35,7 +35,7 @@ require (
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.13.0 github.com/prometheus/client_golang v1.13.0
github.com/sirupsen/logrus v1.9.0 github.com/sirupsen/logrus v1.9.0
github.com/stretchr/testify v1.8.1 github.com/stretchr/testify v1.8.2
github.com/tidwall/gjson v1.14.4 github.com/tidwall/gjson v1.14.4
github.com/tidwall/sjson v1.2.5 github.com/tidwall/sjson v1.2.5
github.com/uber/jaeger-client-go v2.30.0+incompatible github.com/uber/jaeger-client-go v2.30.0+incompatible
@ -49,6 +49,7 @@ require (
gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/h2non/bimg.v1 v1.1.9
gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v2 v2.4.0
gotest.tools/v3 v3.4.0 gotest.tools/v3 v3.4.0
maunium.net/go/mautrix v0.15.1
modernc.org/sqlite v1.19.3 modernc.org/sqlite v1.19.3
nhooyr.io/websocket v1.8.7 nhooyr.io/websocket v1.8.7
) )
@ -93,6 +94,7 @@ require (
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
github.com/klauspost/compress v1.16.0 // indirect github.com/klauspost/compress v1.16.0 // indirect
github.com/kr/pretty v0.3.1 // indirect github.com/kr/pretty v0.3.1 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.16 // indirect github.com/mattn/go-isatty v0.0.16 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect
github.com/minio/highwayhash v1.0.2 // indirect github.com/minio/highwayhash v1.0.2 // indirect
@ -117,6 +119,7 @@ require (
github.com/quic-go/qtls-go1-20 v0.1.0 // indirect github.com/quic-go/qtls-go1-20 v0.1.0 // indirect
github.com/quic-go/quic-go v0.32.0 // indirect github.com/quic-go/quic-go v0.32.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20220927061507-ef77025ab5aa // indirect github.com/remyoudompheng/bigfft v0.0.0-20220927061507-ef77025ab5aa // indirect
github.com/rs/zerolog v1.29.1 // indirect
github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect
go.etcd.io/bbolt v1.3.6 // indirect go.etcd.io/bbolt v1.3.6 // indirect
@ -131,6 +134,7 @@ require (
gopkg.in/macaroon.v2 v2.1.0 // indirect gopkg.in/macaroon.v2 v2.1.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
lukechampine.com/uint128 v1.2.0 // indirect lukechampine.com/uint128 v1.2.0 // indirect
maunium.net/go/maulogger/v2 v2.4.1 // indirect
modernc.org/cc/v3 v3.40.0 // indirect modernc.org/cc/v3 v3.40.0 // indirect
modernc.org/ccgo/v3 v3.16.13-0.20221017192402-261537637ce8 // indirect modernc.org/ccgo/v3 v3.16.13-0.20221017192402-261537637ce8 // indirect
modernc.org/libc v1.21.4 // indirect modernc.org/libc v1.21.4 // indirect

27
go.sum
View file

@ -123,6 +123,7 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/codeclysm/extract v2.2.0+incompatible h1:q3wyckoA30bhUSiwdQezMqVhwd8+WGE64/GL//LtUhI= github.com/codeclysm/extract v2.2.0+incompatible h1:q3wyckoA30bhUSiwdQezMqVhwd8+WGE64/GL//LtUhI=
github.com/codeclysm/extract v2.2.0+incompatible/go.mod h1:2nhFMPHiU9At61hz+12bfrlpXSUrOnK+wR+KlGO4Uks= github.com/codeclysm/extract v2.2.0+incompatible/go.mod h1:2nhFMPHiU9At61hz+12bfrlpXSUrOnK+wR+KlGO4Uks=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@ -188,6 +189,7 @@ github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6Wezm
github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo= github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo=
github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
github.com/goccy/go-json v0.9.11 h1:/pAaQDLHEoCq/5FFmSKBswWmK6H0e8g4159Kc/X/nqk= github.com/goccy/go-json v0.9.11 h1:/pAaQDLHEoCq/5FFmSKBswWmK6H0e8g4159Kc/X/nqk=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
@ -313,21 +315,25 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w=
github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= github.com/lib/pq v1.10.8 h1:3fdt97i/cwSU83+E0hZTC/Xpc9mTZxc6UWSCRcSbxiE=
github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.10.8/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e h1:DP5RC0Z3XdyBEW5dKt8YPeN6vZbm6OzVaGVp7f1BQRM= github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e h1:DP5RC0Z3XdyBEW5dKt8YPeN6vZbm6OzVaGVp7f1BQRM=
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e/go.mod h1:NgPCr+UavRGH6n5jmdX8DuqFZ4JiCWIJoZiuhTRLSUg= github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e/go.mod h1:NgPCr+UavRGH6n5jmdX8DuqFZ4JiCWIJoZiuhTRLSUg=
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw2QV3YD/fRrzEDPNGgTlJlvXY0EHHnT87wF3OA= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw2QV3YD/fRrzEDPNGgTlJlvXY0EHHnT87wF3OA=
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230424155704-8daeaebaa0bc h1:F73iHhpTZxWVO6qbyGZxd7Ch44v1gK6xNQZ7QVos/Es= github.com/matrix-org/gomatrixserverlib v0.0.0-20230428192809-ff52c27efdce h1:ZdNs5Qj1Cf42GfwUE01oPIZccSiaPJ/HcZP9qxHte8k=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230424155704-8daeaebaa0bc/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU= github.com/matrix-org/gomatrixserverlib v0.0.0-20230428192809-ff52c27efdce/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66/go.mod h1:iBI1foelCqA09JJgPV0FYz4qA5dUXYOxMi57FxKBdd4= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66/go.mod h1:iBI1foelCqA09JJgPV0FYz4qA5dUXYOxMi57FxKBdd4=
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
@ -430,6 +436,9 @@ github.com/remyoudompheng/bigfft v0.0.0-20220927061507-ef77025ab5aa/go.mod h1:qq
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.29.1 h1:cO+d60CHkknCbvzEWxP0S9K6KqyTjrCNUy1LdQLCGPc=
github.com/rs/zerolog v1.29.1/go.mod h1:Le6ESbR7hc+DP6Lt1THiV8CQSdkkNrd3R0XbEgp3ZBU=
github.com/ryszard/goskiplist v0.0.0-20150312221310-2dfbae5fcf46/go.mod h1:uAQ5PCi+MFsC7HjREoAz1BU+Mq60+05gifQSsHSDG/8= github.com/ryszard/goskiplist v0.0.0-20150312221310-2dfbae5fcf46/go.mod h1:uAQ5PCi+MFsC7HjREoAz1BU+Mq60+05gifQSsHSDG/8=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
@ -453,8 +462,8 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
@ -650,6 +659,8 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@ -848,6 +859,10 @@ honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9
honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
lukechampine.com/uint128 v1.2.0 h1:mBi/5l91vocEN8otkC5bDLhi2KdCticRiwbdB0O+rjI= lukechampine.com/uint128 v1.2.0 h1:mBi/5l91vocEN8otkC5bDLhi2KdCticRiwbdB0O+rjI=
lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk=
maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8=
maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho=
maunium.net/go/mautrix v0.15.1 h1:pmCtMjYRpd83+2UL+KTRFYQo5to0373yulimvLK+1k0=
maunium.net/go/mautrix v0.15.1/go.mod h1:icQIrvz2NldkRLTuzSGzmaeuMUmw+fzO7UVycPeauN8=
modernc.org/cc/v3 v3.40.0 h1:P3g79IUS/93SYhtoeaHW+kRCIrYaxJ27MFPv+7kaTOw= modernc.org/cc/v3 v3.40.0 h1:P3g79IUS/93SYhtoeaHW+kRCIrYaxJ27MFPv+7kaTOw=
modernc.org/cc/v3 v3.40.0/go.mod h1:/bTg4dnWkSXowUO6ssQKnOV0yMVxDYNIsIrzqTFDGH0= modernc.org/cc/v3 v3.40.0/go.mod h1:/bTg4dnWkSXowUO6ssQKnOV0yMVxDYNIsIrzqTFDGH0=
modernc.org/ccgo/v3 v3.16.13-0.20221017192402-261537637ce8 h1:0+dsXf0zeLx9ixj4nilg6jKe5Bg1ilzBwSFq4kJmIUc= modernc.org/ccgo/v3 v3.16.13-0.20221017192402-261537637ce8 h1:0+dsXf0zeLx9ixj4nilg6jKe5Bg1ilzBwSFq4kJmIUc=

View file

@ -1,14 +1,15 @@
package caching package caching
import ( import (
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
// FederationCache contains the subset of functions needed for // FederationCache contains the subset of functions needed for
// a federation event cache. // a federation event cache.
type FederationCache interface { type FederationCache interface {
GetFederationQueuedPDU(eventNID int64) (event *gomatrixserverlib.HeaderedEvent, ok bool) GetFederationQueuedPDU(eventNID int64) (event *types.HeaderedEvent, ok bool)
StoreFederationQueuedPDU(eventNID int64, event *gomatrixserverlib.HeaderedEvent) StoreFederationQueuedPDU(eventNID int64, event *types.HeaderedEvent)
EvictFederationQueuedPDU(eventNID int64) EvictFederationQueuedPDU(eventNID int64)
GetFederationQueuedEDU(eventNID int64) (event *gomatrixserverlib.EDU, ok bool) GetFederationQueuedEDU(eventNID int64) (event *gomatrixserverlib.EDU, ok bool)
@ -16,11 +17,11 @@ type FederationCache interface {
EvictFederationQueuedEDU(eventNID int64) EvictFederationQueuedEDU(eventNID int64)
} }
func (c Caches) GetFederationQueuedPDU(eventNID int64) (*gomatrixserverlib.HeaderedEvent, bool) { func (c Caches) GetFederationQueuedPDU(eventNID int64) (*types.HeaderedEvent, bool) {
return c.FederationPDUs.Get(eventNID) return c.FederationPDUs.Get(eventNID)
} }
func (c Caches) StoreFederationQueuedPDU(eventNID int64, event *gomatrixserverlib.HeaderedEvent) { func (c Caches) StoreFederationQueuedPDU(eventNID int64, event *types.HeaderedEvent) {
c.FederationPDUs.Set(eventNID, event) c.FederationPDUs.Set(eventNID, event)
} }

View file

@ -33,7 +33,7 @@ type Caches struct {
RoomServerStateKeyNIDs Cache[string, types.EventStateKeyNID] // event state key -> eventStateKey NID RoomServerStateKeyNIDs Cache[string, types.EventStateKeyNID] // event state key -> eventStateKey NID
RoomServerEventTypeNIDs Cache[string, types.EventTypeNID] // eventType -> eventType NID RoomServerEventTypeNIDs Cache[string, types.EventTypeNID] // eventType -> eventType NID
RoomServerEventTypes Cache[types.EventTypeNID, string] // eventType NID -> eventType RoomServerEventTypes Cache[types.EventTypeNID, string] // eventType NID -> eventType
FederationPDUs Cache[int64, *gomatrixserverlib.HeaderedEvent] // queue NID -> PDU FederationPDUs Cache[int64, *types.HeaderedEvent] // queue NID -> PDU
FederationEDUs Cache[int64, *gomatrixserverlib.EDU] // queue NID -> EDU FederationEDUs Cache[int64, *gomatrixserverlib.EDU] // queue NID -> EDU
SpaceSummaryRooms Cache[string, fclient.MSC2946SpacesResponse] // room ID -> space response SpaceSummaryRooms Cache[string, fclient.MSC2946SpacesResponse] // room ID -> space response
LazyLoading Cache[lazyLoadingCacheKey, string] // composite key -> event ID LazyLoading Cache[lazyLoadingCacheKey, string] // composite key -> event ID

View file

@ -131,8 +131,8 @@ func NewRistrettoCache(maxCost config.DataUnit, maxAge time.Duration, enableProm
Prefix: eventTypeNIDCache, Prefix: eventTypeNIDCache,
MaxAge: maxAge, MaxAge: maxAge,
}, },
FederationPDUs: &RistrettoCostedCachePartition[int64, *gomatrixserverlib.HeaderedEvent]{ // queue NID -> PDU FederationPDUs: &RistrettoCostedCachePartition[int64, *types.HeaderedEvent]{ // queue NID -> PDU
&RistrettoCachePartition[int64, *gomatrixserverlib.HeaderedEvent]{ &RistrettoCachePartition[int64, *types.HeaderedEvent]{
cache: cache, cache: cache,
Prefix: federationPDUsCache, Prefix: federationPDUsCache,
Mutable: true, Mutable: true,

View file

@ -21,6 +21,7 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
@ -43,7 +44,7 @@ func QueryAndBuildEvent(
builder *gomatrixserverlib.EventBuilder, cfg *config.Global, builder *gomatrixserverlib.EventBuilder, cfg *config.Global,
identity *fclient.SigningIdentity, evTime time.Time, identity *fclient.SigningIdentity, evTime time.Time,
rsAPI api.QueryLatestEventsAndStateAPI, queryRes *api.QueryLatestEventsAndStateResponse, rsAPI api.QueryLatestEventsAndStateAPI, queryRes *api.QueryLatestEventsAndStateResponse,
) (*gomatrixserverlib.HeaderedEvent, error) { ) (*types.HeaderedEvent, error) {
if queryRes == nil { if queryRes == nil {
queryRes = &api.QueryLatestEventsAndStateResponse{} queryRes = &api.QueryLatestEventsAndStateResponse{}
} }
@ -63,7 +64,7 @@ func BuildEvent(
builder *gomatrixserverlib.EventBuilder, cfg *config.Global, builder *gomatrixserverlib.EventBuilder, cfg *config.Global,
identity *fclient.SigningIdentity, evTime time.Time, identity *fclient.SigningIdentity, evTime time.Time,
eventsNeeded *gomatrixserverlib.StateNeeded, queryRes *api.QueryLatestEventsAndStateResponse, eventsNeeded *gomatrixserverlib.StateNeeded, queryRes *api.QueryLatestEventsAndStateResponse,
) (*gomatrixserverlib.HeaderedEvent, error) { ) (*types.HeaderedEvent, error) {
if err := addPrevEventsToEvent(builder, eventsNeeded, queryRes); err != nil { if err := addPrevEventsToEvent(builder, eventsNeeded, queryRes); err != nil {
return nil, err return nil, err
} }
@ -76,7 +77,7 @@ func BuildEvent(
return nil, err return nil, err
} }
return event.Headered(queryRes.RoomVersion), nil return &types.HeaderedEvent{Event: event}, nil
} }
// queryRequiredEventsForBuilder queries the roomserver for auth/prev events needed for this builder. // queryRequiredEventsForBuilder queries the roomserver for auth/prev events needed for this builder.

View file

@ -21,17 +21,17 @@ import (
) )
const ( const (
// KindNewEventPersisted is a hook which is called with *gomatrixserverlib.HeaderedEvent // KindNewEventPersisted is a hook which is called with *types.HeaderedEvent
// It is run when a new event is persisted in the roomserver. // It is run when a new event is persisted in the roomserver.
// Usage: // Usage:
// hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) { ... }) // hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) { ... })
KindNewEventPersisted = "new_event_persisted" KindNewEventPersisted = "new_event_persisted"
// KindNewEventReceived is a hook which is called with *gomatrixserverlib.HeaderedEvent // KindNewEventReceived is a hook which is called with *types.HeaderedEvent
// It is run before a new event is processed by the roomserver. This hook can be used // It is run before a new event is processed by the roomserver. This hook can be used
// to modify the event before it is persisted by adding data to `unsigned`. // to modify the event before it is persisted by adding data to `unsigned`.
// Usage: // Usage:
// hooks.Attach(hooks.KindNewEventReceived, func(headeredEvent interface{}) { // hooks.Attach(hooks.KindNewEventReceived, func(headeredEvent interface{}) {
// ev := headeredEvent.(*gomatrixserverlib.HeaderedEvent) // ev := headeredEvent.(*types.HeaderedEvent)
// _ = ev.SetUnsignedField("key", "val") // _ = ev.SetUnsignedField("key", "val")
// }) // })
KindNewEventReceived = "new_event_received" KindNewEventReceived = "new_event_received"

View file

@ -53,7 +53,7 @@ func NewRuleSetEvaluator(ec EvaluationContext, ruleSet *RuleSet) *RuleSetEvaluat
// MatchEvent returns the first matching rule. Returns nil if there // MatchEvent returns the first matching rule. Returns nil if there
// was no match rule. // was no match rule.
func (rse *RuleSetEvaluator) MatchEvent(event *gomatrixserverlib.Event) (*Rule, error) { func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU) (*Rule, error) {
// TODO: server-default rules have lower priority than user rules, // TODO: server-default rules have lower priority than user rules,
// but they are stored together with the user rules. It's a bit // but they are stored together with the user rules. It's a bit
// unclear what the specification (11.14.1.4 Predefined rules) // unclear what the specification (11.14.1.4 Predefined rules)
@ -83,7 +83,7 @@ func (rse *RuleSetEvaluator) MatchEvent(event *gomatrixserverlib.Event) (*Rule,
return nil, nil return nil, nil
} }
func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) { func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec EvaluationContext) (bool, error) {
if !rule.Enabled { if !rule.Enabled {
return false, nil return false, nil
} }
@ -120,7 +120,7 @@ func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec Evalu
} }
} }
func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) { func conditionMatches(cond *Condition, event gomatrixserverlib.PDU, ec EvaluationContext) (bool, error) {
switch cond.Kind { switch cond.Kind {
case EventMatchCondition: case EventMatchCondition:
if cond.Pattern == nil { if cond.Pattern == nil {
@ -150,7 +150,7 @@ func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec Evalua
} }
} }
func patternMatches(key, pattern string, event *gomatrixserverlib.Event) (bool, error) { func patternMatches(key, pattern string, event gomatrixserverlib.PDU) (bool, error) {
// It doesn't make sense for an empty pattern to match anything. // It doesn't make sense for an empty pattern to match anything.
if pattern == "" { if pattern == "" {
return false, nil return false, nil

View file

@ -29,7 +29,7 @@ func TestRuleSetEvaluatorMatchEvent(t *testing.T) {
Name string Name string
RuleSet RuleSet RuleSet RuleSet
Want *Rule Want *Rule
Event *gomatrixserverlib.Event Event gomatrixserverlib.PDU
}{ }{
{"empty", RuleSet{}, nil, ev}, {"empty", RuleSet{}, nil, ev},
{"defaultCanWin", RuleSet{Override: []*Rule{defaultEnabled}}, defaultEnabled, ev}, {"defaultCanWin", RuleSet{Override: []*Rule{defaultEnabled}}, defaultEnabled, ev},
@ -188,7 +188,7 @@ func TestPatternMatches(t *testing.T) {
} }
} }
func mustEventFromJSON(t *testing.T, json string) *gomatrixserverlib.Event { func mustEventFromJSON(t *testing.T, json string) gomatrixserverlib.PDU {
ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV7).NewEventFromTrustedJSON([]byte(json), false) ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV7).NewEventFromTrustedJSON([]byte(json), false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/federationapi/producers"
"github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/federationapi/types"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
rstypes "github.com/matrix-org/dendrite/roomserver/types"
syncTypes "github.com/matrix-org/dendrite/syncapi/types" syncTypes "github.com/matrix-org/dendrite/syncapi/types"
userAPI "github.com/matrix-org/dendrite/userapi/api" userAPI "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -115,14 +116,13 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut
if v, ok := roomVersions[roomID]; ok { if v, ok := roomVersions[roomID]; ok {
return v return v
} }
verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} roomVersion, err := t.rsAPI.QueryRoomVersionForRoom(ctx, roomID)
verRes := api.QueryRoomVersionForRoomResponse{} if err != nil {
if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to query room version for room", roomID)
util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to query room version for room", verReq.RoomID)
return "" return ""
} }
roomVersions[roomID] = verRes.RoomVersion roomVersions[roomID] = roomVersion
return verRes.RoomVersion return roomVersion
} }
for _, pdu := range t.PDUs { for _, pdu := range t.PDUs {
@ -168,7 +168,7 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut
} }
continue continue
} }
if err = event.VerifyEventSignatures(ctx, t.keys); err != nil { if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys); err != nil {
util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID())
results[event.EventID()] = fclient.PDUResult{ results[event.EventID()] = fclient.PDUResult{
Error: err.Error(), Error: err.Error(),
@ -183,8 +183,8 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut
ctx, ctx,
t.rsAPI, t.rsAPI,
api.KindNew, api.KindNew,
[]*gomatrixserverlib.HeaderedEvent{ []*rstypes.HeaderedEvent{
event.Headered(roomVersion), {Event: event},
}, },
t.Destination, t.Destination,
t.Origin, t.Origin,

View file

@ -31,6 +31,7 @@ import (
"github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/federationapi/producers"
rsAPI "github.com/matrix-org/dendrite/roomserver/api" rsAPI "github.com/matrix-org/dendrite/roomserver/api"
rstypes "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
@ -59,8 +60,8 @@ var (
} }
testEvent = []byte(`{"auth_events":["$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg","$BcEcbZnlFLB5rxSNSZNBn6fO3jU/TKAJ79wfKyCQLiU"],"content":{"body":"Test Message"},"depth":3917,"hashes":{"sha256":"cNAWtlHIegrji0mMA6x1rhpYCccY8W1NsWZqSpJFhjs"},"origin":"localhost","origin_server_ts":0,"prev_events":["$4GDB0bVjkWwS3G4noUZCq5oLWzpBYpwzdMcf7gj24CI"],"room_id":"!roomid:localhost","sender":"@userid:localhost","signatures":{"localhost":{"ed25519:auto":"NKym6Kcy3u9mGUr21Hjfe3h7DfDilDhN5PqztT0QZ4NTZ+8Y7owseLolQVXp+TvNjecvzdDywsXXVvGiuQiWAQ"}},"type":"m.room.message"}`) testEvent = []byte(`{"auth_events":["$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg","$BcEcbZnlFLB5rxSNSZNBn6fO3jU/TKAJ79wfKyCQLiU"],"content":{"body":"Test Message"},"depth":3917,"hashes":{"sha256":"cNAWtlHIegrji0mMA6x1rhpYCccY8W1NsWZqSpJFhjs"},"origin":"localhost","origin_server_ts":0,"prev_events":["$4GDB0bVjkWwS3G4noUZCq5oLWzpBYpwzdMcf7gj24CI"],"room_id":"!roomid:localhost","sender":"@userid:localhost","signatures":{"localhost":{"ed25519:auto":"NKym6Kcy3u9mGUr21Hjfe3h7DfDilDhN5PqztT0QZ4NTZ+8Y7owseLolQVXp+TvNjecvzdDywsXXVvGiuQiWAQ"}},"type":"m.room.message"}`)
testRoomVersion = gomatrixserverlib.RoomVersionV1 testRoomVersion = gomatrixserverlib.RoomVersionV1
testEvents = []*gomatrixserverlib.HeaderedEvent{} testEvents = []*rstypes.HeaderedEvent{}
testStateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) testStateEvents = make(map[gomatrixserverlib.StateKeyTuple]*rstypes.HeaderedEvent)
) )
type FakeRsAPI struct { type FakeRsAPI struct {
@ -72,14 +73,12 @@ type FakeRsAPI struct {
func (r *FakeRsAPI) QueryRoomVersionForRoom( func (r *FakeRsAPI) QueryRoomVersionForRoom(
ctx context.Context, ctx context.Context,
req *rsAPI.QueryRoomVersionForRoomRequest, roomID string,
res *rsAPI.QueryRoomVersionForRoomResponse, ) (gomatrixserverlib.RoomVersion, error) {
) error {
if r.shouldFailQuery { if r.shouldFailQuery {
return fmt.Errorf("Failure") return "", fmt.Errorf("Failure")
} }
res.RoomVersion = gomatrixserverlib.RoomVersionV10 return gomatrixserverlib.RoomVersionV10, nil
return nil
} }
func (r *FakeRsAPI) QueryServerBannedFromRoom( func (r *FakeRsAPI) QueryServerBannedFromRoom(
@ -637,7 +636,7 @@ func init() {
if err != nil { if err != nil {
panic("cannot load test data: " + err.Error()) panic("cannot load test data: " + err.Error())
} }
h := e.Headered(testRoomVersion) h := &rstypes.HeaderedEvent{Event: e}
testEvents = append(testEvents, h) testEvents = append(testEvents, h)
if e.StateKey() != nil { if e.StateKey() != nil {
testStateEvents[gomatrixserverlib.StateKeyTuple{ testStateEvents[gomatrixserverlib.StateKeyTuple{
@ -722,11 +721,9 @@ func (t *testRoomserverAPI) QueryServerJoinedToRoom(
// Asks for the room version for a given room. // Asks for the room version for a given room.
func (t *testRoomserverAPI) QueryRoomVersionForRoom( func (t *testRoomserverAPI) QueryRoomVersionForRoom(
ctx context.Context, ctx context.Context,
request *rsAPI.QueryRoomVersionForRoomRequest, roomID string,
response *rsAPI.QueryRoomVersionForRoomResponse, ) (gomatrixserverlib.RoomVersion, error) {
) error { return testRoomVersion, nil
response.RoomVersion = testRoomVersion
return nil
} }
func (t *testRoomserverAPI) QueryServerBannedFromRoom( func (t *testRoomserverAPI) QueryServerBannedFromRoom(
@ -781,7 +778,7 @@ NextPDU:
} }
} }
func assertInputRoomEvents(t *testing.T, got []rsAPI.InputRoomEvent, want []*gomatrixserverlib.HeaderedEvent) { func assertInputRoomEvents(t *testing.T, got []rsAPI.InputRoomEvent, want []*rstypes.HeaderedEvent) {
for _, g := range got { for _, g := range got {
fmt.Println("GOT ", g.Event.EventID()) fmt.Println("GOT ", g.Event.EventID())
} }
@ -805,7 +802,7 @@ func TestBasicTransaction(t *testing.T) {
} }
txn := mustCreateTransaction(rsAPI, pdus) txn := mustCreateTransaction(rsAPI, pdus)
mustProcessTransaction(t, txn, nil) mustProcessTransaction(t, txn, nil)
assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*rstypes.HeaderedEvent{testEvents[len(testEvents)-1]})
} }
// The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver // The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver
@ -818,5 +815,5 @@ func TestTransactionFailAuthChecks(t *testing.T) {
txn := mustCreateTransaction(rsAPI, pdus) txn := mustCreateTransaction(rsAPI, pdus)
mustProcessTransaction(t, txn, []string{}) mustProcessTransaction(t, txn, []string{})
// expect message to be sent to the roomserver // expect message to be sent to the roomserver
assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*rstypes.HeaderedEvent{testEvents[len(testEvents)-1]})
} }

View file

@ -101,11 +101,11 @@ func (r *RelayInternalAPI) QueryTransactions(
userID spec.UserID, userID spec.UserID,
previousEntry fclient.RelayEntry, previousEntry fclient.RelayEntry,
) (api.QueryRelayTransactionsResponse, error) { ) (api.QueryRelayTransactionsResponse, error) {
logrus.Infof("QueryTransactions for %s", userID.Raw()) logrus.Infof("QueryTransactions for %s", userID.String())
if previousEntry.EntryID > 0 { if previousEntry.EntryID > 0 {
logrus.Infof("Cleaning previous entry (%v) from db for %s", logrus.Infof("Cleaning previous entry (%v) from db for %s",
previousEntry.EntryID, previousEntry.EntryID,
userID.Raw(), userID.String(),
) )
prevReceipt := receipt.NewReceipt(previousEntry.EntryID) prevReceipt := receipt.NewReceipt(previousEntry.EntryID)
err := r.db.CleanTransactions(ctx, userID, []*receipt.Receipt{&prevReceipt}) err := r.db.CleanTransactions(ctx, userID, []*receipt.Receipt{&prevReceipt})
@ -123,12 +123,12 @@ func (r *RelayInternalAPI) QueryTransactions(
response := api.QueryRelayTransactionsResponse{} response := api.QueryRelayTransactionsResponse{}
if transaction != nil && receipt != nil { if transaction != nil && receipt != nil {
logrus.Infof("Obtained transaction (%v) for %s", transaction.TransactionID, userID.Raw()) logrus.Infof("Obtained transaction (%v) for %s", transaction.TransactionID, userID.String())
response.Transaction = *transaction response.Transaction = *transaction
response.EntryID = receipt.GetNID() response.EntryID = receipt.GetNID()
response.EntriesQueued = true response.EntriesQueued = true
} else { } else {
logrus.Infof("No more entries in the queue for %s", userID.Raw()) logrus.Infof("No more entries in the queue for %s", userID.String())
response.EntryID = 0 response.EntryID = 0
response.EntriesQueued = false response.EntriesQueued = false
} }

View file

@ -34,7 +34,7 @@ func GetTransactionFromRelay(
relayAPI api.RelayInternalAPI, relayAPI api.RelayInternalAPI,
userID spec.UserID, userID spec.UserID,
) util.JSONResponse { ) util.JSONResponse {
logrus.Infof("Processing relay_txn for %s", userID.Raw()) logrus.Infof("Processing relay_txn for %s", userID.String())
var previousEntry fclient.RelayEntry var previousEntry fclient.RelayEntry
if err := json.Unmarshal(fedReq.Content(), &previousEntry); err != nil { if err := json.Unmarshal(fedReq.Content(), &previousEntry); err != nil {

View file

@ -35,7 +35,7 @@ func createQuery(
prevEntry fclient.RelayEntry, prevEntry fclient.RelayEntry,
) fclient.FederationRequest { ) fclient.FederationRequest {
var federationPathPrefixV1 = "/_matrix/federation/v1" var federationPathPrefixV1 = "/_matrix/federation/v1"
path := federationPathPrefixV1 + "/relay_txn/" + userID.Raw() path := federationPathPrefixV1 + "/relay_txn/" + userID.String()
request := fclient.NewFederationRequest("GET", userID.Domain(), "relay", path) request := fclient.NewFederationRequest("GET", userID.Domain(), "relay", path)
request.SetContent(prevEntry) request.SetContent(prevEntry)

View file

@ -36,7 +36,7 @@ func SendTransactionToRelay(
txnID gomatrixserverlib.TransactionID, txnID gomatrixserverlib.TransactionID,
userID spec.UserID, userID spec.UserID,
) util.JSONResponse { ) util.JSONResponse {
logrus.Infof("Processing send_relay for %s", userID.Raw()) logrus.Infof("Processing send_relay for %s", userID.String())
var txnEvents fclient.RelayEvents var txnEvents fclient.RelayEvents
if err := json.Unmarshal(fedReq.Content(), &txnEvents); err != nil { if err := json.Unmarshal(fedReq.Content(), &txnEvents); err != nil {

View file

@ -52,7 +52,7 @@ func createFederationRequest(
content interface{}, content interface{},
) fclient.FederationRequest { ) fclient.FederationRequest {
var federationPathPrefixV1 = "/_matrix/federation/v1" var federationPathPrefixV1 = "/_matrix/federation/v1"
path := federationPathPrefixV1 + "/send_relay/" + string(txnID) + "/" + userID.Raw() path := federationPathPrefixV1 + "/send_relay/" + string(txnID) + "/" + userID.String()
request := fclient.NewFederationRequest("PUT", origin, destination, path) request := fclient.NewFederationRequest("PUT", origin, destination, path)
request.SetContent(content) request.SetContent(content)

View file

@ -23,6 +23,7 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -34,7 +35,7 @@ type ServerACLDatabase interface {
// GetStateEvent returns the state event of a given type for a given room with a given state key // GetStateEvent returns the state event of a given type for a given room with a given state key
// If no event could be found, returns nil // If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error // If there was an issue during the retrieval, returns an error
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error)
} }
type ServerACLs struct { type ServerACLs struct {

View file

@ -11,6 +11,25 @@ import (
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
) )
// ErrInvalidID is an error returned if the userID is invalid
type ErrInvalidID struct {
Err error
}
func (e ErrInvalidID) Error() string {
return e.Err.Error()
}
// ErrNotAllowed is an error returned if the user is not allowed
// to execute some action (e.g. invite)
type ErrNotAllowed struct {
Err error
}
func (e ErrNotAllowed) Error() string {
return e.Err.Error()
}
// RoomserverInputAPI is used to write events to the room server. // RoomserverInputAPI is used to write events to the room server.
type RoomserverInternalAPI interface { type RoomserverInternalAPI interface {
SyncRoomserverAPI SyncRoomserverAPI
@ -102,7 +121,7 @@ type SyncRoomserverAPI interface {
) error ) error
// QueryMembershipAtEvent queries the memberships at the given events. // QueryMembershipAtEvent queries the memberships at the given events.
// Returns a map from eventID to a slice of gomatrixserverlib.HeaderedEvent. // Returns a map from eventID to a slice of types.HeaderedEvent.
QueryMembershipAtEvent( QueryMembershipAtEvent(
ctx context.Context, ctx context.Context,
request *QueryMembershipAtEventRequest, request *QueryMembershipAtEventRequest,
@ -143,24 +162,24 @@ type ClientRoomserverAPI interface {
QueryStateAfterEvents(ctx context.Context, req *QueryStateAfterEventsRequest, res *QueryStateAfterEventsResponse) error QueryStateAfterEvents(ctx context.Context, req *QueryStateAfterEventsRequest, res *QueryStateAfterEventsResponse) error
// QueryKnownUsers returns a list of users that we know about from our joined rooms. // QueryKnownUsers returns a list of users that we know about from our joined rooms.
QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error
QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
QueryPublishedRooms(ctx context.Context, req *QueryPublishedRoomsRequest, res *QueryPublishedRoomsResponse) error QueryPublishedRooms(ctx context.Context, req *QueryPublishedRoomsRequest, res *QueryPublishedRoomsResponse) error
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error
GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error
// PerformRoomUpgrade upgrades a room to a newer version // PerformRoomUpgrade upgrades a room to a newer version
PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse) error PerformRoomUpgrade(ctx context.Context, roomID, userID string, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error)
PerformAdminEvacuateRoom(ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse) error PerformAdminEvacuateRoom(ctx context.Context, roomID string) (affected []string, err error)
PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error)
PerformAdminPurgeRoom(ctx context.Context, req *PerformAdminPurgeRoomRequest, res *PerformAdminPurgeRoomResponse) error PerformAdminPurgeRoom(ctx context.Context, roomID string) error
PerformAdminDownloadState(ctx context.Context, req *PerformAdminDownloadStateRequest, res *PerformAdminDownloadStateResponse) error PerformAdminDownloadState(ctx context.Context, roomID, userID string, serverName spec.ServerName) error
PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse) error PerformPeek(ctx context.Context, req *PerformPeekRequest) (roomID string, err error)
PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse) error PerformUnpeek(ctx context.Context, roomID, userID, deviceID string) error
PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error PerformInvite(ctx context.Context, req *PerformInviteRequest) error
PerformJoin(ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse) error PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error)
PerformLeave(ctx context.Context, req *PerformLeaveRequest, res *PerformLeaveResponse) error PerformLeave(ctx context.Context, req *PerformLeaveRequest, res *PerformLeaveResponse) error
PerformPublish(ctx context.Context, req *PerformPublishRequest, res *PerformPublishResponse) error PerformPublish(ctx context.Context, req *PerformPublishRequest) error
// PerformForget forgets a rooms history for a specific user // PerformForget forgets a rooms history for a specific user
PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error
SetRoomAlias(ctx context.Context, req *SetRoomAliasRequest, res *SetRoomAliasResponse) error SetRoomAlias(ctx context.Context, req *SetRoomAliasRequest, res *SetRoomAliasResponse) error
@ -172,8 +191,8 @@ type UserRoomserverAPI interface {
KeyserverRoomserverAPI KeyserverRoomserverAPI
QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error)
PerformJoin(ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse) error PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error)
} }
type FederationRoomserverAPI interface { type FederationRoomserverAPI interface {
@ -183,7 +202,7 @@ type FederationRoomserverAPI interface {
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
// which room to use by querying the first events roomID. // which room to use by querying the first events roomID.
@ -202,7 +221,7 @@ type FederationRoomserverAPI interface {
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
QueryRestrictedJoinAllowed(ctx context.Context, req *QueryRestrictedJoinAllowedRequest, res *QueryRestrictedJoinAllowedResponse) error QueryRestrictedJoinAllowed(ctx context.Context, req *QueryRestrictedJoinAllowedRequest, res *QueryRestrictedJoinAllowedResponse) error
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error PerformInvite(ctx context.Context, req *PerformInviteRequest) error
// Query a given amount (or less) of events prior to a given set of events. // Query a given amount (or less) of events prior to a given set of events.
PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error
} }

View file

@ -18,6 +18,7 @@ package api
import ( import (
"fmt" "fmt"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
) )
@ -67,7 +68,7 @@ type InputRoomEvent struct {
// This controls how the event is processed. // This controls how the event is processed.
Kind Kind `json:"kind"` Kind Kind `json:"kind"`
// The event JSON for the event to add. // The event JSON for the event to add.
Event *gomatrixserverlib.HeaderedEvent `json:"event"` Event *types.HeaderedEvent `json:"event"`
// Which server told us about this event. // Which server told us about this event.
Origin spec.ServerName `json:"origin"` Origin spec.ServerName `json:"origin"`
// Whether the state is supplied as a list of event IDs or whether it // Whether the state is supplied as a list of event IDs or whether it

View file

@ -15,6 +15,7 @@
package api package api
import ( import (
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
) )
@ -107,7 +108,7 @@ const (
// prev_events. // prev_events.
type OutputNewRoomEvent struct { type OutputNewRoomEvent struct {
// The Event. // The Event.
Event *gomatrixserverlib.HeaderedEvent `json:"event"` Event *types.HeaderedEvent `json:"event"`
// Does the event completely rewrite the room state? If so, then AddsStateEventIDs // Does the event completely rewrite the room state? If so, then AddsStateEventIDs
// will contain the entire room state. // will contain the entire room state.
RewritesState bool `json:"rewrites_state,omitempty"` RewritesState bool `json:"rewrites_state,omitempty"`
@ -170,8 +171,8 @@ type OutputNewRoomEvent struct {
HistoryVisibility gomatrixserverlib.HistoryVisibility `json:"history_visibility"` HistoryVisibility gomatrixserverlib.HistoryVisibility `json:"history_visibility"`
} }
func (o *OutputNewRoomEvent) NeededStateEventIDs() ([]*gomatrixserverlib.HeaderedEvent, []string) { func (o *OutputNewRoomEvent) NeededStateEventIDs() ([]*types.HeaderedEvent, []string) {
addsStateEvents := make([]*gomatrixserverlib.HeaderedEvent, 0, 1) addsStateEvents := make([]*types.HeaderedEvent, 0, 1)
missingEventIDs := make([]string, 0, len(o.AddsStateEventIDs)) missingEventIDs := make([]string, 0, len(o.AddsStateEventIDs))
for _, eventID := range o.AddsStateEventIDs { for _, eventID := range o.AddsStateEventIDs {
if eventID != o.Event.EventID() { if eventID != o.Event.EventID() {
@ -194,7 +195,7 @@ func (o *OutputNewRoomEvent) NeededStateEventIDs() ([]*gomatrixserverlib.Headere
// should build their current room state up from OutputNewRoomEvents only. // should build their current room state up from OutputNewRoomEvents only.
type OutputOldRoomEvent struct { type OutputOldRoomEvent struct {
// The Event. // The Event.
Event *gomatrixserverlib.HeaderedEvent `json:"event"` Event *types.HeaderedEvent `json:"event"`
HistoryVisibility gomatrixserverlib.HistoryVisibility `json:"history_visibility"` HistoryVisibility gomatrixserverlib.HistoryVisibility `json:"history_visibility"`
} }
@ -205,7 +206,7 @@ type OutputNewInviteEvent struct {
// The room version of the invited room. // The room version of the invited room.
RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"`
// The "m.room.member" invite event. // The "m.room.member" invite event.
Event *gomatrixserverlib.HeaderedEvent `json:"event"` Event *types.HeaderedEvent `json:"event"`
} }
// An OutputRetireInviteEvent is written whenever an existing invite is no longer // An OutputRetireInviteEvent is written whenever an existing invite is no longer
@ -232,7 +233,7 @@ type OutputRedactedEvent struct {
// The event ID that was redacted // The event ID that was redacted
RedactedEventID string RedactedEventID string
// The value of `unsigned.redacted_because` - the redaction event itself // The value of `unsigned.redacted_because` - the redaction event itself
RedactedBecause *gomatrixserverlib.HeaderedEvent RedactedBecause *types.HeaderedEvent
} }
// An OutputNewPeek is written whenever a user starts peeking into a room // An OutputNewPeek is written whenever a user starts peeking into a room

View file

@ -1,80 +1,11 @@
package api package api
import ( import (
"encoding/json" "github.com/matrix-org/dendrite/roomserver/types"
"fmt"
"net/http"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
)
type PerformErrorCode int
type PerformError struct {
Msg string
RemoteCode int // remote HTTP status code, for PerformErrRemote
Code PerformErrorCode
}
func (p *PerformError) Error() string {
return fmt.Sprintf("%d : %s", p.Code, p.Msg)
}
// JSONResponse maps error codes to suitable HTTP error codes, defaulting to 500.
func (p *PerformError) JSONResponse() util.JSONResponse {
switch p.Code {
case PerformErrorBadRequest:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown(p.Msg),
}
case PerformErrorNoRoom:
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound(p.Msg),
}
case PerformErrorNotAllowed:
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(p.Msg),
}
case PerformErrorNoOperation:
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(p.Msg),
}
case PerformErrRemote:
// if the code is 0 then something bad happened and it isn't
// a remote HTTP error being encapsulated, e.g network error to remote.
if p.RemoteCode == 0 {
return util.ErrorResponse(fmt.Errorf("%s", p.Msg))
}
return util.JSONResponse{
Code: p.RemoteCode,
// TODO: Should we assert this is in fact JSON? E.g gjson parse?
JSON: json.RawMessage(p.Msg),
}
default:
return util.ErrorResponse(p)
}
}
const (
// PerformErrorNotAllowed means the user is not allowed to invite/join/etc this room (e.g join_rule:invite or banned)
PerformErrorNotAllowed PerformErrorCode = 1
// PerformErrorBadRequest means the request was wrong in some way (invalid user ID, wrong server, etc)
PerformErrorBadRequest PerformErrorCode = 2
// PerformErrorNoRoom means that the room being joined doesn't exist.
PerformErrorNoRoom PerformErrorCode = 3
// PerformErrorNoOperation means that the request resulted in nothing happening e.g invite->invite or leave->leave.
PerformErrorNoOperation PerformErrorCode = 4
// PerformErrRemote means that the request failed and the PerformError.Msg is the raw remote JSON error response
PerformErrRemote PerformErrorCode = 5
) )
type PerformJoinRequest struct { type PerformJoinRequest struct {
@ -86,14 +17,6 @@ type PerformJoinRequest struct {
Unsigned map[string]interface{} `json:"unsigned"` Unsigned map[string]interface{} `json:"unsigned"`
} }
type PerformJoinResponse struct {
// The room ID, populated on success.
RoomID string `json:"room_id"`
JoinedVia spec.ServerName
// If non-nil, the join request failed. Contains more information why it failed.
Error *PerformError
}
type PerformLeaveRequest struct { type PerformLeaveRequest struct {
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
UserID string `json:"user_id"` UserID string `json:"user_id"`
@ -105,15 +28,11 @@ type PerformLeaveResponse struct {
} }
type PerformInviteRequest struct { type PerformInviteRequest struct {
RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"`
Event *gomatrixserverlib.HeaderedEvent `json:"event"` Event *types.HeaderedEvent `json:"event"`
InviteRoomState []fclient.InviteV2StrippedState `json:"invite_room_state"` InviteRoomState []fclient.InviteV2StrippedState `json:"invite_room_state"`
SendAsServer string `json:"send_as_server"` SendAsServer string `json:"send_as_server"`
TransactionID *TransactionID `json:"transaction_id"` TransactionID *TransactionID `json:"transaction_id"`
}
type PerformInviteResponse struct {
Error *PerformError
} }
type PerformPeekRequest struct { type PerformPeekRequest struct {
@ -123,24 +42,6 @@ type PerformPeekRequest struct {
ServerNames []spec.ServerName `json:"server_names"` ServerNames []spec.ServerName `json:"server_names"`
} }
type PerformPeekResponse struct {
// The room ID, populated on success.
RoomID string `json:"room_id"`
// If non-nil, the join request failed. Contains more information why it failed.
Error *PerformError
}
type PerformUnpeekRequest struct {
RoomID string `json:"room_id"`
UserID string `json:"user_id"`
DeviceID string `json:"device_id"`
}
type PerformUnpeekResponse struct {
// If non-nil, the join request failed. Contains more information why it failed.
Error *PerformError
}
// PerformBackfillRequest is a request to PerformBackfill. // PerformBackfillRequest is a request to PerformBackfill.
type PerformBackfillRequest struct { type PerformBackfillRequest struct {
// The room to backfill // The room to backfill
@ -168,7 +69,7 @@ func (r *PerformBackfillRequest) PrevEventIDs() []string {
// PerformBackfillResponse is a response to PerformBackfill. // PerformBackfillResponse is a response to PerformBackfill.
type PerformBackfillResponse struct { type PerformBackfillResponse struct {
// Missing events, arbritrary order. // Missing events, arbritrary order.
Events []*gomatrixserverlib.HeaderedEvent `json:"events"` Events []*types.HeaderedEvent `json:"events"`
HistoryVisibility gomatrixserverlib.HistoryVisibility `json:"history_visibility"` HistoryVisibility gomatrixserverlib.HistoryVisibility `json:"history_visibility"`
} }
@ -179,11 +80,6 @@ type PerformPublishRequest struct {
NetworkID string NetworkID string
} }
type PerformPublishResponse struct {
// If non-nil, the publish request failed. Contains more information why it failed.
Error *PerformError
}
type PerformInboundPeekRequest struct { type PerformInboundPeekRequest struct {
UserID string `json:"user_id"` UserID string `json:"user_id"`
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
@ -200,10 +96,10 @@ type PerformInboundPeekResponse struct {
RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"`
// The current state and auth chain events. // The current state and auth chain events.
// The lists will be in an arbitrary order. // The lists will be in an arbitrary order.
StateEvents []*gomatrixserverlib.HeaderedEvent `json:"state_events"` StateEvents []*types.HeaderedEvent `json:"state_events"`
AuthChainEvents []*gomatrixserverlib.HeaderedEvent `json:"auth_chain_events"` AuthChainEvents []*types.HeaderedEvent `json:"auth_chain_events"`
// The event at which this state was captured // The event at which this state was captured
LatestEvent *gomatrixserverlib.HeaderedEvent `json:"latest_event"` LatestEvent *types.HeaderedEvent `json:"latest_event"`
} }
// PerformForgetRequest is a request to PerformForget // PerformForgetRequest is a request to PerformForget
@ -213,50 +109,3 @@ type PerformForgetRequest struct {
} }
type PerformForgetResponse struct{} type PerformForgetResponse struct{}
type PerformRoomUpgradeRequest struct {
RoomID string `json:"room_id"`
UserID string `json:"user_id"`
RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"`
}
type PerformRoomUpgradeResponse struct {
NewRoomID string
Error *PerformError
}
type PerformAdminEvacuateRoomRequest struct {
RoomID string `json:"room_id"`
}
type PerformAdminEvacuateRoomResponse struct {
Affected []string `json:"affected"`
Error *PerformError
}
type PerformAdminEvacuateUserRequest struct {
UserID string `json:"user_id"`
}
type PerformAdminEvacuateUserResponse struct {
Affected []string `json:"affected"`
Error *PerformError
}
type PerformAdminPurgeRoomRequest struct {
RoomID string `json:"room_id"`
}
type PerformAdminPurgeRoomResponse struct {
Error *PerformError `json:"error,omitempty"`
}
type PerformAdminDownloadStateRequest struct {
RoomID string `json:"room_id"`
UserID string `json:"user_id"`
ServerName spec.ServerName `json:"server_name"`
}
type PerformAdminDownloadStateResponse struct {
Error *PerformError `json:"error,omitempty"`
}

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/synctypes"
) )
@ -53,7 +54,7 @@ type QueryLatestEventsAndStateResponse struct {
// This list will be in an arbitrary order. // This list will be in an arbitrary order.
// These are used to set the auth_events when sending an event. // These are used to set the auth_events when sending an event.
// These are used to check whether the event is allowed. // These are used to check whether the event is allowed.
StateEvents []*gomatrixserverlib.HeaderedEvent `json:"state_events"` StateEvents []*types.HeaderedEvent `json:"state_events"`
// The depth of the latest events. // The depth of the latest events.
// This is one greater than the maximum depth of the latest events. // This is one greater than the maximum depth of the latest events.
// This is used to set the depth when sending an event. // This is used to set the depth when sending an event.
@ -83,7 +84,7 @@ type QueryStateAfterEventsResponse struct {
PrevEventsExist bool `json:"prev_events_exist"` PrevEventsExist bool `json:"prev_events_exist"`
// The state events requested. // The state events requested.
// This list will be in an arbitrary order. // This list will be in an arbitrary order.
StateEvents []*gomatrixserverlib.HeaderedEvent `json:"state_events"` StateEvents []*types.HeaderedEvent `json:"state_events"`
} }
// QueryEventsByIDRequest is a request to QueryEventsByID // QueryEventsByIDRequest is a request to QueryEventsByID
@ -104,7 +105,7 @@ type QueryEventsByIDResponse struct {
// fails to read it from the database then it will fail // fails to read it from the database then it will fail
// the entire request. // the entire request.
// This list will be in an arbitrary order. // This list will be in an arbitrary order.
Events []*gomatrixserverlib.HeaderedEvent `json:"events"` Events []*types.HeaderedEvent `json:"events"`
} }
// QueryMembershipForUserRequest is a request to QueryMembership // QueryMembershipForUserRequest is a request to QueryMembership
@ -202,7 +203,7 @@ type QueryMissingEventsRequest struct {
// QueryMissingEventsResponse is a response to QueryMissingEvents // QueryMissingEventsResponse is a response to QueryMissingEvents
type QueryMissingEventsResponse struct { type QueryMissingEventsResponse struct {
// Missing events, arbritrary order. // Missing events, arbritrary order.
Events []*gomatrixserverlib.HeaderedEvent `json:"events"` Events []*types.HeaderedEvent `json:"events"`
} }
// QueryStateAndAuthChainRequest is a request to QueryStateAndAuthChain // QueryStateAndAuthChainRequest is a request to QueryStateAndAuthChain
@ -236,8 +237,8 @@ type QueryStateAndAuthChainResponse struct {
StateKnown bool `json:"state_known"` StateKnown bool `json:"state_known"`
// The state and auth chain events that were requested. // The state and auth chain events that were requested.
// The lists will be in an arbitrary order. // The lists will be in an arbitrary order.
StateEvents []*gomatrixserverlib.HeaderedEvent `json:"state_events"` StateEvents []*types.HeaderedEvent `json:"state_events"`
AuthChainEvents []*gomatrixserverlib.HeaderedEvent `json:"auth_chain_events"` AuthChainEvents []*types.HeaderedEvent `json:"auth_chain_events"`
// True if the queried event was rejected earlier. // True if the queried event was rejected earlier.
IsRejected bool `json:"is_rejected"` IsRejected bool `json:"is_rejected"`
} }
@ -269,7 +270,7 @@ type QueryAuthChainRequest struct {
} }
type QueryAuthChainResponse struct { type QueryAuthChainResponse struct {
AuthChain []*gomatrixserverlib.HeaderedEvent AuthChain []*types.HeaderedEvent
} }
type QuerySharedUsersRequest struct { type QuerySharedUsersRequest struct {
@ -327,7 +328,7 @@ type QueryCurrentStateRequest struct {
} }
type QueryCurrentStateResponse struct { type QueryCurrentStateResponse struct {
StateEvents map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent StateEvents map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent
} }
type QueryKnownUsersRequest struct { type QueryKnownUsersRequest struct {
@ -404,7 +405,7 @@ func (r *QueryBulkStateContentResponse) UnmarshalJSON(data []byte) error {
// MarshalJSON stringifies the StateKeyTuple keys so they can be sent over the wire in HTTP API mode. // MarshalJSON stringifies the StateKeyTuple keys so they can be sent over the wire in HTTP API mode.
func (r *QueryCurrentStateResponse) MarshalJSON() ([]byte, error) { func (r *QueryCurrentStateResponse) MarshalJSON() ([]byte, error) {
se := make(map[string]*gomatrixserverlib.HeaderedEvent, len(r.StateEvents)) se := make(map[string]*types.HeaderedEvent, len(r.StateEvents))
for k, v := range r.StateEvents { for k, v := range r.StateEvents {
// use 0x1F (unit separator) as the delimiter between type/state key, // use 0x1F (unit separator) as the delimiter between type/state key,
se[fmt.Sprintf("%s\x1F%s", k.EventType, k.StateKey)] = v se[fmt.Sprintf("%s\x1F%s", k.EventType, k.StateKey)] = v
@ -413,12 +414,12 @@ func (r *QueryCurrentStateResponse) MarshalJSON() ([]byte, error) {
} }
func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error { func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error {
res := make(map[string]*gomatrixserverlib.HeaderedEvent) res := make(map[string]*types.HeaderedEvent)
err := json.Unmarshal(data, &res) err := json.Unmarshal(data, &res)
if err != nil { if err != nil {
return err return err
} }
r.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent, len(res)) r.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent, len(res))
for k, v := range res { for k, v := range res {
fields := strings.Split(k, "\x1F") fields := strings.Split(k, "\x1F")
r.StateEvents[gomatrixserverlib.StateKeyTuple{ r.StateEvents[gomatrixserverlib.StateKeyTuple{
@ -442,7 +443,7 @@ type QueryMembershipAtEventResponse struct {
// Membership is a map from eventID to membership event. Events that // Membership is a map from eventID to membership event. Events that
// do not have known state will return a nil event, resulting in a "leave" membership // do not have known state will return a nil event, resulting in a "leave" membership
// when calculating history visibility. // when calculating history visibility.
Membership map[string]*gomatrixserverlib.HeaderedEvent `json:"membership"` Membership map[string]*types.HeaderedEvent `json:"membership"`
} }
// QueryLeftUsersRequest is a request to calculate users that we (the server) don't share a // QueryLeftUsersRequest is a request to calculate users that we (the server) don't share a

View file

@ -17,6 +17,7 @@ package api
import ( import (
"context" "context"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
@ -27,7 +28,7 @@ import (
// SendEvents to the roomserver The events are written with KindNew. // SendEvents to the roomserver The events are written with KindNew.
func SendEvents( func SendEvents(
ctx context.Context, rsAPI InputRoomEventsAPI, ctx context.Context, rsAPI InputRoomEventsAPI,
kind Kind, events []*gomatrixserverlib.HeaderedEvent, kind Kind, events []*types.HeaderedEvent,
virtualHost, origin spec.ServerName, virtualHost, origin spec.ServerName,
sendAsServer spec.ServerName, txnID *TransactionID, sendAsServer spec.ServerName, txnID *TransactionID,
async bool, async bool,
@ -51,10 +52,11 @@ func SendEvents(
func SendEventWithState( func SendEventWithState(
ctx context.Context, rsAPI InputRoomEventsAPI, ctx context.Context, rsAPI InputRoomEventsAPI,
virtualHost spec.ServerName, kind Kind, virtualHost spec.ServerName, kind Kind,
state gomatrixserverlib.StateResponse, event *gomatrixserverlib.HeaderedEvent, state gomatrixserverlib.StateResponse, event *types.HeaderedEvent,
origin spec.ServerName, haveEventIDs map[string]bool, async bool, origin spec.ServerName, haveEventIDs map[string]bool, async bool,
) error { ) error {
outliers := gomatrixserverlib.LineariseStateResponse(event.RoomVersion, state) outliersPDU := gomatrixserverlib.LineariseStateResponse(event.Version(), state)
outliers := gomatrixserverlib.TempCastToEvents(outliersPDU)
ires := make([]InputRoomEvent, 0, len(outliers)) ires := make([]InputRoomEvent, 0, len(outliers))
for _, outlier := range outliers { for _, outlier := range outliers {
if haveEventIDs[outlier.EventID()] { if haveEventIDs[outlier.EventID()] {
@ -62,12 +64,12 @@ func SendEventWithState(
} }
ires = append(ires, InputRoomEvent{ ires = append(ires, InputRoomEvent{
Kind: KindOutlier, Kind: KindOutlier,
Event: outlier.Headered(event.RoomVersion), Event: &types.HeaderedEvent{Event: outlier},
Origin: origin, Origin: origin,
}) })
} }
stateEvents := state.GetStateEvents().UntrustedEvents(event.RoomVersion) stateEvents := state.GetStateEvents().UntrustedEvents(event.Version())
stateEventIDs := make([]string, len(stateEvents)) stateEventIDs := make([]string, len(stateEvents))
for i := range stateEvents { for i := range stateEvents {
stateEventIDs[i] = stateEvents[i].EventID() stateEventIDs[i] = stateEvents[i].EventID()
@ -110,7 +112,7 @@ func SendInputRoomEvents(
} }
// GetEvent returns the event or nil, even on errors. // GetEvent returns the event or nil, even on errors.
func GetEvent(ctx context.Context, rsAPI QueryEventsAPI, roomID, eventID string) *gomatrixserverlib.HeaderedEvent { func GetEvent(ctx context.Context, rsAPI QueryEventsAPI, roomID, eventID string) *types.HeaderedEvent {
var res QueryEventsByIDResponse var res QueryEventsByIDResponse
err := rsAPI.QueryEventsByID(ctx, &QueryEventsByIDRequest{ err := rsAPI.QueryEventsByID(ctx, &QueryEventsByIDRequest{
RoomID: roomID, RoomID: roomID,
@ -127,7 +129,7 @@ func GetEvent(ctx context.Context, rsAPI QueryEventsAPI, roomID, eventID string)
} }
// GetStateEvent returns the current state event in the room or nil. // GetStateEvent returns the current state event in the room or nil.
func GetStateEvent(ctx context.Context, rsAPI QueryEventsAPI, roomID string, tuple gomatrixserverlib.StateKeyTuple) *gomatrixserverlib.HeaderedEvent { func GetStateEvent(ctx context.Context, rsAPI QueryEventsAPI, roomID string, tuple gomatrixserverlib.StateKeyTuple) *types.HeaderedEvent {
var res QueryCurrentStateResponse var res QueryCurrentStateResponse
err := rsAPI.QueryCurrentState(ctx, &QueryCurrentStateRequest{ err := rsAPI.QueryCurrentState(ctx, &QueryCurrentStateRequest{
RoomID: roomID, RoomID: roomID,

View file

@ -26,6 +26,10 @@ func IsServerAllowed(
serverCurrentlyInRoom bool, serverCurrentlyInRoom bool,
authEvents []*gomatrixserverlib.Event, authEvents []*gomatrixserverlib.Event,
) bool { ) bool {
// In practice should not happen, but avoids unneeded CPU cycles
if serverName == "" || len(authEvents) == 0 {
return false
}
historyVisibility := HistoryVisibilityForRoom(authEvents) historyVisibility := HistoryVisibilityForRoom(authEvents)
// 1. If the history_visibility was set to world_readable, allow. // 1. If the history_visibility was set to world_readable, allow.

View file

@ -0,0 +1,85 @@
package auth
import (
"testing"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
)
func TestIsServerAllowed(t *testing.T) {
alice := test.NewUser(t)
tests := []struct {
name string
want bool
roomFunc func() *test.Room
serverName spec.ServerName
serverCurrentlyInRoom bool
}{
{
name: "no servername specified",
roomFunc: func() *test.Room { return test.NewRoom(t, alice) },
},
{
name: "no authEvents specified",
serverName: "test",
roomFunc: func() *test.Room { return &test.Room{} },
},
{
name: "default denied",
serverName: "test2",
roomFunc: func() *test.Room { return test.NewRoom(t, alice) },
},
{
name: "world readable room",
serverName: "test",
roomFunc: func() *test.Room {
return test.NewRoom(t, alice, test.RoomHistoryVisibility(gomatrixserverlib.HistoryVisibilityWorldReadable))
},
want: true,
},
{
name: "allowed due to alice being joined",
serverName: "test",
roomFunc: func() *test.Room { return test.NewRoom(t, alice) },
want: true,
},
{
name: "allowed due to 'serverCurrentlyInRoom'",
serverName: "test2",
roomFunc: func() *test.Room { return test.NewRoom(t, alice) },
want: true,
serverCurrentlyInRoom: true,
},
{
name: "allowed due to pending invite",
serverName: "test2",
roomFunc: func() *test.Room {
bob := test.User{ID: "@bob:test2"}
r := test.NewRoom(t, alice, test.RoomHistoryVisibility(gomatrixserverlib.HistoryVisibilityInvited))
r.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{
"membership": spec.Invite,
}, test.WithStateKey(bob.ID))
return r
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.roomFunc == nil {
t.Fatalf("missing roomFunc")
}
var authEvents []*gomatrixserverlib.Event
for _, ev := range tt.roomFunc().Events() {
authEvents = append(authEvents, ev.Event)
}
if got := IsServerAllowed(tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want {
t.Errorf("IsServerAllowed() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
@ -140,7 +141,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
} }
if creatorID != request.UserID { if creatorID != request.UserID {
var plEvent *gomatrixserverlib.HeaderedEvent var plEvent *types.HeaderedEvent
var pls *gomatrixserverlib.PowerLevelContent var pls *gomatrixserverlib.PowerLevelContent
plEvent, err = r.DB.GetStateEvent(ctx, roomID, spec.MRoomPowerLevels, "") plEvent, err = r.DB.GetStateEvent(ctx, roomID, spec.MRoomPowerLevels, "")
@ -212,7 +213,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
return err return err
} }
err = api.SendEvents(ctx, r, api.KindNew, []*gomatrixserverlib.HeaderedEvent{newEvent}, virtualHost, r.ServerName, r.ServerName, nil, false) err = api.SendEvents(ctx, r, api.KindNew, []*types.HeaderedEvent{newEvent}, virtualHost, r.ServerName, r.ServerName, nil, false)
if err != nil { if err != nil {
return err return err
} }

View file

@ -209,11 +209,9 @@ func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalA
func (r *RoomserverInternalAPI) PerformInvite( func (r *RoomserverInternalAPI) PerformInvite(
ctx context.Context, ctx context.Context,
req *api.PerformInviteRequest, req *api.PerformInviteRequest,
res *api.PerformInviteResponse,
) error { ) error {
outputEvents, err := r.Inviter.PerformInvite(ctx, req, res) outputEvents, err := r.Inviter.PerformInvite(ctx, req)
if err != nil { if err != nil {
sentry.CaptureException(err)
return err return err
} }
if len(outputEvents) == 0 { if len(outputEvents) == 0 {

View file

@ -34,7 +34,7 @@ func CheckForSoftFail(
ctx context.Context, ctx context.Context,
db storage.RoomDatabase, db storage.RoomDatabase,
roomInfo *types.RoomInfo, roomInfo *types.RoomInfo,
event *gomatrixserverlib.HeaderedEvent, event *types.HeaderedEvent,
stateEventIDs []string, stateEventIDs []string,
) (bool, error) { ) (bool, error) {
rewritesState := len(stateEventIDs) > 1 rewritesState := len(stateEventIDs) > 1
@ -65,7 +65,9 @@ func CheckForSoftFail(
} }
// Work out which of the state events we actually need. // Work out which of the state events we actually need.
stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()}) stateNeeded := gomatrixserverlib.StateNeededForAuth(
gomatrixserverlib.ToPDUs([]*gomatrixserverlib.Event{event.Event}),
)
// Load the actual auth events from the database. // Load the actual auth events from the database.
authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries) authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries)
@ -87,7 +89,7 @@ func CheckAuthEvents(
ctx context.Context, ctx context.Context,
db storage.RoomDatabase, db storage.RoomDatabase,
roomInfo *types.RoomInfo, roomInfo *types.RoomInfo,
event *gomatrixserverlib.HeaderedEvent, event *types.HeaderedEvent,
authEventIDs []string, authEventIDs []string,
) ([]types.EventNID, error) { ) ([]types.EventNID, error) {
// Grab the numeric IDs for the supplied auth state events from the database. // Grab the numeric IDs for the supplied auth state events from the database.
@ -98,7 +100,7 @@ func CheckAuthEvents(
authStateEntries = types.DeduplicateStateEntries(authStateEntries) authStateEntries = types.DeduplicateStateEntries(authStateEntries)
// Work out which of the state events we actually need. // Work out which of the state events we actually need.
stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()}) stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.PDU{event.Event})
// Load the actual auth events from the database. // Load the actual auth events from the database.
authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries) authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries)
@ -132,31 +134,31 @@ func (ae *authEvents) Valid() bool {
} }
// Create implements gomatrixserverlib.AuthEventProvider // Create implements gomatrixserverlib.AuthEventProvider
func (ae *authEvents) Create() (*gomatrixserverlib.Event, error) { func (ae *authEvents) Create() (gomatrixserverlib.PDU, error) {
return ae.lookupEventWithEmptyStateKey(types.MRoomCreateNID), nil return ae.lookupEventWithEmptyStateKey(types.MRoomCreateNID), nil
} }
// PowerLevels implements gomatrixserverlib.AuthEventProvider // PowerLevels implements gomatrixserverlib.AuthEventProvider
func (ae *authEvents) PowerLevels() (*gomatrixserverlib.Event, error) { func (ae *authEvents) PowerLevels() (gomatrixserverlib.PDU, error) {
return ae.lookupEventWithEmptyStateKey(types.MRoomPowerLevelsNID), nil return ae.lookupEventWithEmptyStateKey(types.MRoomPowerLevelsNID), nil
} }
// JoinRules implements gomatrixserverlib.AuthEventProvider // JoinRules implements gomatrixserverlib.AuthEventProvider
func (ae *authEvents) JoinRules() (*gomatrixserverlib.Event, error) { func (ae *authEvents) JoinRules() (gomatrixserverlib.PDU, error) {
return ae.lookupEventWithEmptyStateKey(types.MRoomJoinRulesNID), nil return ae.lookupEventWithEmptyStateKey(types.MRoomJoinRulesNID), nil
} }
// Memmber implements gomatrixserverlib.AuthEventProvider // Memmber implements gomatrixserverlib.AuthEventProvider
func (ae *authEvents) Member(stateKey string) (*gomatrixserverlib.Event, error) { func (ae *authEvents) Member(stateKey string) (gomatrixserverlib.PDU, error) {
return ae.lookupEvent(types.MRoomMemberNID, stateKey), nil return ae.lookupEvent(types.MRoomMemberNID, stateKey), nil
} }
// ThirdPartyInvite implements gomatrixserverlib.AuthEventProvider // ThirdPartyInvite implements gomatrixserverlib.AuthEventProvider
func (ae *authEvents) ThirdPartyInvite(stateKey string) (*gomatrixserverlib.Event, error) { func (ae *authEvents) ThirdPartyInvite(stateKey string) (gomatrixserverlib.PDU, error) {
return ae.lookupEvent(types.MRoomThirdPartyInviteNID, stateKey), nil return ae.lookupEvent(types.MRoomThirdPartyInviteNID, stateKey), nil
} }
func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) *gomatrixserverlib.Event { func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) gomatrixserverlib.PDU {
eventNID, ok := ae.state.lookup(types.StateKeyTuple{ eventNID, ok := ae.state.lookup(types.StateKeyTuple{
EventTypeNID: typeNID, EventTypeNID: typeNID,
EventStateKeyNID: types.EmptyStateKeyNID, EventStateKeyNID: types.EmptyStateKeyNID,
@ -171,7 +173,7 @@ func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) *
return event.Event return event.Event
} }
func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *gomatrixserverlib.Event { func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) gomatrixserverlib.PDU {
stateKeyNID, ok := ae.stateKeyNIDMap[stateKey] stateKeyNID, ok := ae.stateKeyNIDMap[stateKey]
if !ok { if !ok {
return nil return nil

View file

@ -45,7 +45,7 @@ func UpdateToInviteMembership(
updates = append(updates, api.OutputEvent{ updates = append(updates, api.OutputEvent{
Type: api.OutputTypeNewInviteEvent, Type: api.OutputTypeNewInviteEvent,
NewInviteEvent: &api.OutputNewInviteEvent{ NewInviteEvent: &api.OutputNewInviteEvent{
Event: add.Headered(roomVersion), Event: &types.HeaderedEvent{Event: add.Event},
RoomVersion: roomVersion, RoomVersion: roomVersion,
}, },
}) })
@ -479,7 +479,7 @@ func QueryLatestEventsAndState(
} }
for _, event := range stateEvents { for _, event := range stateEvents {
response.StateEvents = append(response.StateEvents, event.Headered(roomInfo.RoomVersion)) response.StateEvents = append(response.StateEvents, &types.HeaderedEvent{Event: event})
} }
return nil return nil

View file

@ -41,7 +41,7 @@ func TestIsInvitePendingWithoutNID(t *testing.T) {
var authNIDs []types.EventNID var authNIDs []types.EventNID
for _, x := range room.Events() { for _, x := range room.Events() {
roomInfo, err := db.GetOrCreateRoomInfo(context.Background(), x.Unwrap()) roomInfo, err := db.GetOrCreateRoomInfo(context.Background(), x.Event)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, roomInfo) assert.NotNil(t, roomInfo)

View file

@ -102,7 +102,7 @@ func (r *Inputer) processRoomEvent(
// Parse and validate the event JSON // Parse and validate the event JSON
headered := input.Event headered := input.Event
event := headered.Unwrap() event := headered.Event
logger := util.GetLogger(ctx).WithFields(logrus.Fields{ logger := util.GetLogger(ctx).WithFields(logrus.Fields{
"event_id": event.EventID(), "event_id": event.EventID(),
"room_id": event.RoomID(), "room_id": event.RoomID(),
@ -235,7 +235,7 @@ func (r *Inputer) processRoomEvent(
haveEvents: map[string]*gomatrixserverlib.Event{}, haveEvents: map[string]*gomatrixserverlib.Event{},
} }
var stateSnapshot *parsedRespState var stateSnapshot *parsedRespState
if stateSnapshot, err = missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { if stateSnapshot, err = missingState.processEventWithMissingState(ctx, event, headered.Version()); err != nil {
// Something went wrong with retrieving the missing state, so we can't // Something went wrong with retrieving the missing state, so we can't
// really do anything with the event other than reject it at this point. // really do anything with the event other than reject it at this point.
isRejected = true isRejected = true
@ -467,7 +467,7 @@ func (r *Inputer) processRoomEvent(
Type: api.OutputTypeRedactedEvent, Type: api.OutputTypeRedactedEvent,
RedactedEvent: &api.OutputRedactedEvent{ RedactedEvent: &api.OutputRedactedEvent{
RedactedEventID: redactedEventID, RedactedEventID: redactedEventID,
RedactedBecause: redactionEvent.Headered(headered.RoomVersion), RedactedBecause: &types.HeaderedEvent{Event: redactionEvent},
}, },
}, },
}) })
@ -478,7 +478,7 @@ func (r *Inputer) processRoomEvent(
// If guest_access changed and is not can_join, kick all guest users. // If guest_access changed and is not can_join, kick all guest users.
if event.Type() == spec.MRoomGuestAccess && gjson.GetBytes(event.Content(), "guest_access").Str != "can_join" { if event.Type() == spec.MRoomGuestAccess && gjson.GetBytes(event.Content(), "guest_access").Str != "can_join" {
if err = r.kickGuests(ctx, event, roomInfo); err != nil { if err = r.kickGuests(ctx, event, roomInfo); err != nil && err != sql.ErrNoRows {
logrus.WithError(err).Error("failed to kick guest users on m.room.guest_access revocation") logrus.WithError(err).Error("failed to kick guest users on m.room.guest_access revocation")
} }
} }
@ -509,7 +509,7 @@ func (r *Inputer) processStateBefore(
missingPrev bool, missingPrev bool,
) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) { ) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) {
historyVisibility = gomatrixserverlib.HistoryVisibilityShared // Default to shared. historyVisibility = gomatrixserverlib.HistoryVisibilityShared // Default to shared.
event := input.Event.Unwrap() event := input.Event.Event
isCreateEvent := event.Type() == spec.MRoomCreate && event.StateKeyEquals("") isCreateEvent := event.Type() == spec.MRoomCreate && event.StateKeyEquals("")
var stateBeforeEvent []*gomatrixserverlib.Event var stateBeforeEvent []*gomatrixserverlib.Event
switch { switch {
@ -545,7 +545,7 @@ func (r *Inputer) processStateBefore(
// will include the history visibility here even though we don't // will include the history visibility here even though we don't
// actually need it for auth, because we want to send it in the // actually need it for auth, because we want to send it in the
// output events. // output events.
tuplesNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event}).Tuples() tuplesNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.PDU{event}).Tuples()
tuplesNeeded = append(tuplesNeeded, gomatrixserverlib.StateKeyTuple{ tuplesNeeded = append(tuplesNeeded, gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomHistoryVisibility, EventType: spec.MRoomHistoryVisibility,
StateKey: "", StateKey: "",
@ -567,14 +567,20 @@ func (r *Inputer) processStateBefore(
rejectionErr = fmt.Errorf("prev events of %q are not known", event.EventID()) rejectionErr = fmt.Errorf("prev events of %q are not known", event.EventID())
return return
default: default:
stateBeforeEvent = gomatrixserverlib.UnwrapEventHeaders(stateBeforeRes.StateEvents) stateBeforeEvent = make([]*gomatrixserverlib.Event, len(stateBeforeRes.StateEvents))
for i := range stateBeforeRes.StateEvents {
stateBeforeEvent[i] = stateBeforeRes.StateEvents[i].Event
}
} }
} }
// At this point, stateBeforeEvent should be populated either by // At this point, stateBeforeEvent should be populated either by
// the supplied state in the input request, or from the prev events. // the supplied state in the input request, or from the prev events.
// Check whether the event is allowed or not. // Check whether the event is allowed or not.
stateBeforeAuth := gomatrixserverlib.NewAuthEvents(stateBeforeEvent) stateBeforeAuth := gomatrixserverlib.NewAuthEvents(
gomatrixserverlib.ToPDUs(stateBeforeEvent),
)
if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth); rejectionErr != nil { if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth); rejectionErr != nil {
rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr)
return return
} }
// Work out what the history visibility was at the time of the // Work out what the history visibility was at the time of the
@ -604,7 +610,7 @@ func (r *Inputer) fetchAuthEvents(
logger *logrus.Entry, logger *logrus.Entry,
roomInfo *types.RoomInfo, roomInfo *types.RoomInfo,
virtualHost spec.ServerName, virtualHost spec.ServerName,
event *gomatrixserverlib.HeaderedEvent, event *types.HeaderedEvent,
auth *gomatrixserverlib.AuthEvents, auth *gomatrixserverlib.AuthEvents,
known map[string]*types.Event, known map[string]*types.Event,
servers []spec.ServerName, servers []spec.ServerName,
@ -654,7 +660,7 @@ func (r *Inputer) fetchAuthEvents(
// Request the entire auth chain for the event in question. This should // Request the entire auth chain for the event in question. This should
// contain all of the auth events — including ones that we already know — // contain all of the auth events — including ones that we already know —
// so we'll need to filter through those in the next section. // so we'll need to filter through those in the next section.
res, err = r.FSAPI.GetEventAuth(ctx, virtualHost, serverName, event.RoomVersion, event.RoomID(), event.EventID()) res, err = r.FSAPI.GetEventAuth(ctx, virtualHost, serverName, event.Version(), event.RoomID(), event.EventID())
if err != nil { if err != nil {
logger.WithError(err).Warnf("Failed to get event auth from federation for %q: %s", event.EventID(), err) logger.WithError(err).Warnf("Failed to get event auth from federation for %q: %s", event.EventID(), err)
continue continue
@ -671,7 +677,7 @@ func (r *Inputer) fetchAuthEvents(
isRejected := false isRejected := false
nextAuthEvent: nextAuthEvent:
for _, authEvent := range gomatrixserverlib.ReverseTopologicalOrdering( for _, authEvent := range gomatrixserverlib.ReverseTopologicalOrdering(
res.AuthEvents.UntrustedEvents(event.RoomVersion), gomatrixserverlib.ToPDUs(res.AuthEvents.UntrustedEvents(event.Version())),
gomatrixserverlib.TopologicalOrderByAuthEvents, gomatrixserverlib.TopologicalOrderByAuthEvents,
) { ) {
// If we already know about this event from the database then we don't // If we already know about this event from the database then we don't
@ -684,7 +690,7 @@ nextAuthEvent:
// Check the signatures of the event. If this fails then we'll simply // Check the signatures of the event. If this fails then we'll simply
// skip it, because gomatrixserverlib.Allowed() will notice a problem // skip it, because gomatrixserverlib.Allowed() will notice a problem
// if a critical event is missing anyway. // if a critical event is missing anyway.
if err := authEvent.VerifyEventSignatures(ctx, r.FSAPI.KeyRing()); err != nil { if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing()); err != nil {
continue nextAuthEvent continue nextAuthEvent
} }
@ -739,7 +745,7 @@ nextAuthEvent:
// Now we know about this event, it was stored and the signatures were OK. // Now we know about this event, it was stored and the signatures were OK.
known[authEvent.EventID()] = &types.Event{ known[authEvent.EventID()] = &types.Event{
EventNID: eventNID, EventNID: eventNID,
Event: authEvent, Event: authEvent.(*gomatrixserverlib.Event),
} }
} }

View file

@ -393,7 +393,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error)
} }
ore := api.OutputNewRoomEvent{ ore := api.OutputNewRoomEvent{
Event: u.event.Headered(u.roomInfo.RoomVersion), Event: &types.HeaderedEvent{Event: u.event},
RewritesState: u.rewritesState, RewritesState: u.rewritesState,
LastSentEventID: u.lastEventIDSent, LastSentEventID: u.lastEventIDSent,
LatestEventIDs: latestEventIDs, LatestEventIDs: latestEventIDs,

View file

@ -26,7 +26,7 @@ type parsedRespState struct {
StateEvents []*gomatrixserverlib.Event StateEvents []*gomatrixserverlib.Event
} }
func (p *parsedRespState) Events() []*gomatrixserverlib.Event { func (p *parsedRespState) Events() []gomatrixserverlib.PDU {
eventsByID := make(map[string]*gomatrixserverlib.Event, len(p.AuthEvents)+len(p.StateEvents)) eventsByID := make(map[string]*gomatrixserverlib.Event, len(p.AuthEvents)+len(p.StateEvents))
for i, event := range p.AuthEvents { for i, event := range p.AuthEvents {
eventsByID[event.EventID()] = p.AuthEvents[i] eventsByID[event.EventID()] = p.AuthEvents[i]
@ -38,7 +38,8 @@ func (p *parsedRespState) Events() []*gomatrixserverlib.Event {
for _, event := range eventsByID { for _, event := range eventsByID {
allEvents = append(allEvents, event) allEvents = append(allEvents, event)
} }
return gomatrixserverlib.ReverseTopologicalOrdering(allEvents, gomatrixserverlib.TopologicalOrderByAuthEvents) return gomatrixserverlib.ReverseTopologicalOrdering(
gomatrixserverlib.ToPDUs(allEvents), gomatrixserverlib.TopologicalOrderByAuthEvents)
} }
type missingStateReq struct { type missingStateReq struct {
@ -106,7 +107,7 @@ func (t *missingStateReq) processEventWithMissingState(
for _, newEvent := range newEvents { for _, newEvent := range newEvents {
err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{ err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{
Kind: api.KindOld, Kind: api.KindOld,
Event: newEvent.Headered(roomVersion), Event: &types.HeaderedEvent{Event: newEvent},
Origin: t.origin, Origin: t.origin,
SendAsServer: api.DoNotSendToOtherServers, SendAsServer: api.DoNotSendToOtherServers,
}) })
@ -155,7 +156,7 @@ func (t *missingStateReq) processEventWithMissingState(
} }
outlierRoomEvents = append(outlierRoomEvents, api.InputRoomEvent{ outlierRoomEvents = append(outlierRoomEvents, api.InputRoomEvent{
Kind: api.KindOutlier, Kind: api.KindOutlier,
Event: outlier.Headered(roomVersion), Event: &types.HeaderedEvent{Event: outlier.(*gomatrixserverlib.Event)},
Origin: t.origin, Origin: t.origin,
}) })
} }
@ -185,7 +186,7 @@ func (t *missingStateReq) processEventWithMissingState(
err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{ err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{
Kind: api.KindOld, Kind: api.KindOld,
Event: backwardsExtremity.Headered(roomVersion), Event: &types.HeaderedEvent{Event: backwardsExtremity},
Origin: t.origin, Origin: t.origin,
HasState: true, HasState: true,
StateEventIDs: stateIDs, StateEventIDs: stateIDs,
@ -204,7 +205,7 @@ func (t *missingStateReq) processEventWithMissingState(
for _, newEvent := range newEvents { for _, newEvent := range newEvents {
err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{ err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{
Kind: api.KindOld, Kind: api.KindOld,
Event: newEvent.Headered(roomVersion), Event: &types.HeaderedEvent{Event: newEvent},
Origin: t.origin, Origin: t.origin,
SendAsServer: api.DoNotSendToOtherServers, SendAsServer: api.DoNotSendToOtherServers,
}) })
@ -468,7 +469,9 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion
authEventList = append(authEventList, state.AuthEvents...) authEventList = append(authEventList, state.AuthEvents...)
stateEventList = append(stateEventList, state.StateEvents...) stateEventList = append(stateEventList, state.StateEvents...)
} }
resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(roomVersion, stateEventList, authEventList) resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(
roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList),
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -482,7 +485,7 @@ retryAllowedState:
case verifySigError: case verifySigError:
return &parsedRespState{ return &parsedRespState{
AuthEvents: authEventList, AuthEvents: authEventList,
StateEvents: resolvedStateEvents, StateEvents: gomatrixserverlib.TempCastToEvents(resolvedStateEvents),
}, nil }, nil
case nil: case nil:
// do nothing // do nothing
@ -498,7 +501,7 @@ retryAllowedState:
} }
return &parsedRespState{ return &parsedRespState{
AuthEvents: authEventList, AuthEvents: authEventList,
StateEvents: resolvedStateEvents, StateEvents: gomatrixserverlib.TempCastToEvents(resolvedStateEvents),
}, nil }, nil
} }
@ -559,7 +562,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve
// will be added and duplicates will be removed. // will be added and duplicates will be removed.
missingEvents := make([]*gomatrixserverlib.Event, 0, len(missingResp.Events)) missingEvents := make([]*gomatrixserverlib.Event, 0, len(missingResp.Events))
for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) {
if err = ev.VerifyEventSignatures(ctx, t.keys); err != nil { if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys); err != nil {
continue continue
} }
missingEvents = append(missingEvents, t.cacheAndReturn(ev)) missingEvents = append(missingEvents, t.cacheAndReturn(ev))
@ -567,7 +570,9 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve
logger.Debugf("get_missing_events returned %d events (%d passed signature checks)", len(missingResp.Events), len(missingEvents)) logger.Debugf("get_missing_events returned %d events (%d passed signature checks)", len(missingResp.Events), len(missingEvents))
// topologically sort and sanity check that we are making forward progress // topologically sort and sanity check that we are making forward progress
newEvents = gomatrixserverlib.ReverseTopologicalOrdering(missingEvents, gomatrixserverlib.TopologicalOrderByPrevEvents) newEventsPDUs := gomatrixserverlib.ReverseTopologicalOrdering(
gomatrixserverlib.ToPDUs(missingEvents), gomatrixserverlib.TopologicalOrderByPrevEvents)
newEvents = gomatrixserverlib.TempCastToEvents(newEventsPDUs)
shouldHaveSomeEventIDs := e.PrevEventIDs() shouldHaveSomeEventIDs := e.PrevEventIDs()
hasPrevEvent := false hasPrevEvent := false
Event: Event:
@ -882,14 +887,14 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers)) t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers))
return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers)) return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers))
} }
if err := event.VerifyEventSignatures(ctx, t.keys); err != nil { if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys); err != nil {
t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())
return nil, verifySigError{event.EventID(), err} return nil, verifySigError{event.EventID(), err}
} }
return t.cacheAndReturn(event), nil return t.cacheAndReturn(event), nil
} }
func checkAllowedByState(e *gomatrixserverlib.Event, stateEvents []*gomatrixserverlib.Event) error { func checkAllowedByState(e *gomatrixserverlib.Event, stateEvents []gomatrixserverlib.PDU) error {
authUsingState := gomatrixserverlib.NewAuthEvents(nil) authUsingState := gomatrixserverlib.NewAuthEvents(nil)
for i := range stateEvents { for i := range stateEvents {
err := authUsingState.AddEvent(stateEvents[i]) err := authUsingState.AddEvent(stateEvents[i])

View file

@ -10,6 +10,7 @@ import (
"github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/internal/input"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/test/testrig"
@ -44,7 +45,7 @@ func TestSingleTransactionOnInput(t *testing.T) {
} }
in := api.InputRoomEvent{ in := api.InputRoomEvent{
Kind: api.KindOutlier, // don't panic if we generate an output event Kind: api.KindOutlier, // don't panic if we generate an output event
Event: event.Headered(gomatrixserverlib.RoomVersionV6), Event: &types.HeaderedEvent{Event: event},
} }
inputter := &input.Inputer{ inputter := &input.Inputer{

View file

@ -26,8 +26,10 @@ import (
"github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/internal/input"
"github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/internal/query"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -40,61 +42,44 @@ type Admin struct {
Leaver *Leaver Leaver *Leaver
} }
// PerformEvacuateRoom will remove all local users from the given room. // PerformAdminEvacuateRoom will remove all local users from the given room.
func (r *Admin) PerformAdminEvacuateRoom( func (r *Admin) PerformAdminEvacuateRoom(
ctx context.Context, ctx context.Context,
req *api.PerformAdminEvacuateRoomRequest, roomID string,
res *api.PerformAdminEvacuateRoomResponse, ) (affected []string, err error) {
) error { roomInfo, err := r.DB.RoomInfo(ctx, roomID)
roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.RoomInfo: %s", err),
}
return nil
} }
if roomInfo == nil || roomInfo.IsStub() { if roomInfo == nil || roomInfo.IsStub() {
res.Error = &api.PerformError{ return nil, eventutil.ErrRoomNoExists
Code: api.PerformErrorNoRoom,
Msg: fmt.Sprintf("Room %s not found", req.RoomID),
}
return nil
} }
memberNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true) memberNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.GetMembershipEventNIDsForRoom: %s", err),
}
return nil
} }
memberEvents, err := r.DB.Events(ctx, roomInfo, memberNIDs) memberEvents, err := r.DB.Events(ctx, roomInfo, memberNIDs)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.Events: %s", err),
}
return nil
} }
inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents)) inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents))
res.Affected = make([]string, 0, len(memberEvents)) affected = make([]string, 0, len(memberEvents))
latestReq := &api.QueryLatestEventsAndStateRequest{ latestReq := &api.QueryLatestEventsAndStateRequest{
RoomID: req.RoomID, RoomID: roomID,
} }
latestRes := &api.QueryLatestEventsAndStateResponse{} latestRes := &api.QueryLatestEventsAndStateResponse{}
if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil { if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Queryer.QueryLatestEventsAndState: %s", err),
}
return nil
} }
prevEvents := latestRes.LatestEvents prevEvents := latestRes.LatestEvents
var senderDomain spec.ServerName
var eventsNeeded gomatrixserverlib.StateNeeded
var identity *fclient.SigningIdentity
var event *types.HeaderedEvent
for _, memberEvent := range memberEvents { for _, memberEvent := range memberEvents {
if memberEvent.StateKey() == nil { if memberEvent.StateKey() == nil {
continue continue
@ -102,57 +87,41 @@ func (r *Admin) PerformAdminEvacuateRoom(
var memberContent gomatrixserverlib.MemberContent var memberContent gomatrixserverlib.MemberContent
if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil { if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("json.Unmarshal: %s", err),
}
return nil
} }
memberContent.Membership = spec.Leave memberContent.Membership = spec.Leave
stateKey := *memberEvent.StateKey() stateKey := *memberEvent.StateKey()
fledglingEvent := &gomatrixserverlib.EventBuilder{ fledglingEvent := &gomatrixserverlib.EventBuilder{
RoomID: req.RoomID, RoomID: roomID,
Type: spec.MRoomMember, Type: spec.MRoomMember,
StateKey: &stateKey, StateKey: &stateKey,
Sender: stateKey, Sender: stateKey,
PrevEvents: prevEvents, PrevEvents: prevEvents,
} }
_, senderDomain, err := gomatrixserverlib.SplitID('@', fledglingEvent.Sender) _, senderDomain, err = gomatrixserverlib.SplitID('@', fledglingEvent.Sender)
if err != nil { if err != nil {
continue continue
} }
if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil { if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("json.Marshal: %s", err),
}
return nil
} }
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent) eventsNeeded, err = gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("gomatrixserverlib.StateNeededForEventBuilder: %s", err),
}
return nil
} }
identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) identity, err = r.Cfg.Matrix.SigningIdentityFor(senderDomain)
if err != nil { if err != nil {
continue continue
} }
event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, latestRes) event, err = eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, latestRes)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("eventutil.BuildEvent: %s", err),
}
return nil
} }
inputEvents = append(inputEvents, api.InputRoomEvent{ inputEvents = append(inputEvents, api.InputRoomEvent{
@ -161,7 +130,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Origin: senderDomain, Origin: senderDomain,
SendAsServer: string(senderDomain), SendAsServer: string(senderDomain),
}) })
res.Affected = append(res.Affected, stateKey) affected = append(affected, stateKey)
prevEvents = []gomatrixserverlib.EventReference{ prevEvents = []gomatrixserverlib.EventReference{
event.EventReference(), event.EventReference(),
} }
@ -172,108 +141,85 @@ func (r *Admin) PerformAdminEvacuateRoom(
Asynchronous: true, Asynchronous: true,
} }
inputRes := &api.InputRoomEventsResponse{} inputRes := &api.InputRoomEventsResponse{}
return r.Inputer.InputRoomEvents(ctx, inputReq, inputRes) err = r.Inputer.InputRoomEvents(ctx, inputReq, inputRes)
return affected, err
} }
// PerformAdminEvacuateUser will remove the given user from all rooms.
func (r *Admin) PerformAdminEvacuateUser( func (r *Admin) PerformAdminEvacuateUser(
ctx context.Context, ctx context.Context,
req *api.PerformAdminEvacuateUserRequest, userID string,
res *api.PerformAdminEvacuateUserResponse, ) (affected []string, err error) {
) error { _, domain, err := gomatrixserverlib.SplitID('@', userID)
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Malformed user ID: %s", err),
}
return nil
} }
if !r.Cfg.Matrix.IsLocalServerName(domain) { if !r.Cfg.Matrix.IsLocalServerName(domain) {
res.Error = &api.PerformError{ return nil, fmt.Errorf("can only evacuate local users using this endpoint")
Code: api.PerformErrorBadRequest,
Msg: "Can only evacuate local users using this endpoint",
}
return nil
} }
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, spec.Join) roomIDs, err := r.DB.GetRoomsByMembership(ctx, userID, spec.Join)
if err != nil {
return nil, err
}
inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, userID, spec.Invite)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.GetRoomsByMembership: %s", err),
}
return nil
} }
inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, spec.Invite) allRooms := append(roomIDs, inviteRoomIDs...)
if err != nil && err != sql.ErrNoRows { affected = make([]string, 0, len(allRooms))
res.Error = &api.PerformError{ for _, roomID := range allRooms {
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.GetRoomsByMembership: %s", err),
}
return nil
}
for _, roomID := range append(roomIDs, inviteRoomIDs...) {
leaveReq := &api.PerformLeaveRequest{ leaveReq := &api.PerformLeaveRequest{
RoomID: roomID, RoomID: roomID,
UserID: req.UserID, UserID: userID,
} }
leaveRes := &api.PerformLeaveResponse{} leaveRes := &api.PerformLeaveResponse{}
outputEvents, err := r.Leaver.PerformLeave(ctx, leaveReq, leaveRes) outputEvents, err := r.Leaver.PerformLeave(ctx, leaveReq, leaveRes)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Leaver.PerformLeave: %s", err),
}
return nil
} }
res.Affected = append(res.Affected, roomID) affected = append(affected, roomID)
if len(outputEvents) == 0 { if len(outputEvents) == 0 {
continue continue
} }
if err := r.Inputer.OutputProducer.ProduceRoomEvents(roomID, outputEvents); err != nil { if err := r.Inputer.OutputProducer.ProduceRoomEvents(roomID, outputEvents); err != nil {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Inputer.WriteOutputEvents: %s", err),
}
return nil
} }
} }
return nil return affected, nil
} }
// PerformAdminPurgeRoom removes all traces for the given room from the database.
func (r *Admin) PerformAdminPurgeRoom( func (r *Admin) PerformAdminPurgeRoom(
ctx context.Context, ctx context.Context,
req *api.PerformAdminPurgeRoomRequest, roomID string,
res *api.PerformAdminPurgeRoomResponse,
) error { ) error {
// Validate we actually got a room ID and nothing else // Validate we actually got a room ID and nothing else
if _, _, err := gomatrixserverlib.SplitID('!', req.RoomID); err != nil { if _, _, err := gomatrixserverlib.SplitID('!', roomID); err != nil {
res.Error = &api.PerformError{ return err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Malformed room ID: %s", err),
}
return nil
} }
logrus.WithField("room_id", req.RoomID).Warn("Purging room from roomserver") // Evacuate the room before purging it from the database
if err := r.DB.PurgeRoom(ctx, req.RoomID); err != nil { if _, err := r.PerformAdminEvacuateRoom(ctx, roomID); err != nil {
logrus.WithField("room_id", req.RoomID).WithError(err).Warn("Failed to purge room from roomserver") logrus.WithField("room_id", roomID).WithError(err).Warn("Failed to evacuate room before purging")
res.Error = &api.PerformError{ return err
Code: api.PerformErrorBadRequest,
Msg: err.Error(),
}
return nil
} }
logrus.WithField("room_id", req.RoomID).Warn("Room purged from roomserver") logrus.WithField("room_id", roomID).Warn("Purging room from roomserver")
if err := r.DB.PurgeRoom(ctx, roomID); err != nil {
logrus.WithField("room_id", roomID).WithError(err).Warn("Failed to purge room from roomserver")
return err
}
return r.Inputer.OutputProducer.ProduceRoomEvents(req.RoomID, []api.OutputEvent{ logrus.WithField("room_id", roomID).Warn("Room purged from roomserver")
return r.Inputer.OutputProducer.ProduceRoomEvents(roomID, []api.OutputEvent{
{ {
Type: api.OutputTypePurgeRoom, Type: api.OutputTypePurgeRoom,
PurgeRoom: &api.OutputPurgeRoom{ PurgeRoom: &api.OutputPurgeRoom{
RoomID: req.RoomID, RoomID: roomID,
}, },
}, },
}) })
@ -281,42 +227,25 @@ func (r *Admin) PerformAdminPurgeRoom(
func (r *Admin) PerformAdminDownloadState( func (r *Admin) PerformAdminDownloadState(
ctx context.Context, ctx context.Context,
req *api.PerformAdminDownloadStateRequest, roomID, userID string, serverName spec.ServerName,
res *api.PerformAdminDownloadStateResponse,
) error { ) error {
_, senderDomain, err := r.Cfg.Matrix.SplitLocalID('@', req.UserID) _, senderDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Cfg.Matrix.SplitLocalID: %s", err),
}
return nil
} }
roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID) roomInfo, err := r.DB.RoomInfo(ctx, roomID)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.RoomInfo: %s", err),
}
return nil
} }
if roomInfo == nil || roomInfo.IsStub() { if roomInfo == nil || roomInfo.IsStub() {
res.Error = &api.PerformError{ return eventutil.ErrRoomNoExists
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("room %q not found", req.RoomID),
}
return nil
} }
fwdExtremities, _, depth, err := r.DB.LatestEventIDs(ctx, roomInfo.RoomNID) fwdExtremities, _, depth, err := r.DB.LatestEventIDs(ctx, roomInfo.RoomNID)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.LatestEventIDs: %s", err),
}
return nil
} }
authEventMap := map[string]*gomatrixserverlib.Event{} authEventMap := map[string]*gomatrixserverlib.Event{}
@ -324,54 +253,46 @@ func (r *Admin) PerformAdminDownloadState(
for _, fwdExtremity := range fwdExtremities { for _, fwdExtremity := range fwdExtremities {
var state gomatrixserverlib.StateResponse var state gomatrixserverlib.StateResponse
state, err = r.Inputer.FSAPI.LookupState(ctx, r.Inputer.ServerName, req.ServerName, req.RoomID, fwdExtremity.EventID, roomInfo.RoomVersion) state, err = r.Inputer.FSAPI.LookupState(ctx, r.Inputer.ServerName, serverName, roomID, fwdExtremity.EventID, roomInfo.RoomVersion)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity.EventID, err)
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity.EventID, err),
}
return nil
} }
for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) { for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) {
if err = authEvent.VerifyEventSignatures(ctx, r.Inputer.KeyRing); err != nil { if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing); err != nil {
continue continue
} }
authEventMap[authEvent.EventID()] = authEvent authEventMap[authEvent.EventID()] = authEvent
} }
for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) { for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) {
if err = stateEvent.VerifyEventSignatures(ctx, r.Inputer.KeyRing); err != nil { if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing); err != nil {
continue continue
} }
stateEventMap[stateEvent.EventID()] = stateEvent stateEventMap[stateEvent.EventID()] = stateEvent
} }
} }
authEvents := make([]*gomatrixserverlib.HeaderedEvent, 0, len(authEventMap)) authEvents := make([]*types.HeaderedEvent, 0, len(authEventMap))
stateEvents := make([]*gomatrixserverlib.HeaderedEvent, 0, len(stateEventMap)) stateEvents := make([]*types.HeaderedEvent, 0, len(stateEventMap))
stateIDs := make([]string, 0, len(stateEventMap)) stateIDs := make([]string, 0, len(stateEventMap))
for _, authEvent := range authEventMap { for _, authEvent := range authEventMap {
authEvents = append(authEvents, authEvent.Headered(roomInfo.RoomVersion)) authEvents = append(authEvents, &types.HeaderedEvent{Event: authEvent})
} }
for _, stateEvent := range stateEventMap { for _, stateEvent := range stateEventMap {
stateEvents = append(stateEvents, stateEvent.Headered(roomInfo.RoomVersion)) stateEvents = append(stateEvents, &types.HeaderedEvent{Event: stateEvent})
stateIDs = append(stateIDs, stateEvent.EventID()) stateIDs = append(stateIDs, stateEvent.EventID())
} }
builder := &gomatrixserverlib.EventBuilder{ builder := &gomatrixserverlib.EventBuilder{
Type: "org.matrix.dendrite.state_download", Type: "org.matrix.dendrite.state_download",
Sender: req.UserID, Sender: userID,
RoomID: req.RoomID, RoomID: roomID,
Content: spec.RawJSON("{}"), Content: spec.RawJSON("{}"),
} }
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return fmt.Errorf("gomatrixserverlib.StateNeededForEventBuilder: %w", err)
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("gomatrixserverlib.StateNeededForEventBuilder: %s", err),
}
return nil
} }
queryRes := &api.QueryLatestEventsAndStateResponse{ queryRes := &api.QueryLatestEventsAndStateResponse{
@ -389,11 +310,7 @@ func (r *Admin) PerformAdminDownloadState(
ev, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, queryRes) ev, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, queryRes)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return fmt.Errorf("eventutil.BuildEvent: %w", err)
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("eventutil.BuildEvent: %s", err),
}
return nil
} }
inputReq := &api.InputRoomEventsRequest{ inputReq := &api.InputRoomEventsRequest{
@ -417,19 +334,12 @@ func (r *Admin) PerformAdminDownloadState(
SendAsServer: string(r.Cfg.Matrix.ServerName), SendAsServer: string(r.Cfg.Matrix.ServerName),
}) })
if err := r.Inputer.InputRoomEvents(ctx, inputReq, inputRes); err != nil { if err = r.Inputer.InputRoomEvents(ctx, inputReq, inputRes); err != nil {
res.Error = &api.PerformError{ return fmt.Errorf("r.Inputer.InputRoomEvents: %w", err)
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Inputer.InputRoomEvents: %s", err),
}
return nil
} }
if inputRes.ErrMsg != "" { if inputRes.ErrMsg != "" {
res.Error = &api.PerformError{ return inputRes.Err()
Code: api.PerformErrorBadRequest,
Msg: inputRes.ErrMsg,
}
} }
return nil return nil

View file

@ -99,7 +99,7 @@ func (r *Backfiller) PerformBackfill(
if _, ok := redactEventIDs[event.EventID()]; ok { if _, ok := redactEventIDs[event.EventID()]; ok {
event.Redact() event.Redact()
} }
response.Events = append(response.Events, event.Headered(info.RoomVersion)) response.Events = append(response.Events, &types.HeaderedEvent{Event: event})
} }
return err return err
@ -133,7 +133,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
logrus.WithError(err).WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) logrus.WithError(err).WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events))
// persist these new events - auth checks have already been done // persist these new events - auth checks have already been done
roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events) roomNID, backfilledEventMap := persistEvents(ctx, r.DB, gomatrixserverlib.TempCastToEvents(events))
for _, ev := range backfilledEventMap { for _, ev := range backfilledEventMap {
// now add state for these events // now add state for these events
@ -168,7 +168,10 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
// TODO: update backwards extremities, as that should be moved from syncapi to roomserver at some point. // TODO: update backwards extremities, as that should be moved from syncapi to roomserver at some point.
res.Events = events res.Events = make([]*types.HeaderedEvent, len(events))
for i := range events {
res.Events[i] = &types.HeaderedEvent{Event: events[i].(*gomatrixserverlib.Event)}
}
res.HistoryVisibility = requester.historyVisiblity res.HistoryVisibility = requester.historyVisiblity
return nil return nil
} }
@ -186,7 +189,7 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
util.GetLogger(ctx).WithError(err).Warn("cannot query missing events") util.GetLogger(ctx).WithError(err).Warn("cannot query missing events")
return return
} }
missingMap := make(map[string]*gomatrixserverlib.HeaderedEvent) // id -> event missingMap := make(map[string]*types.HeaderedEvent) // id -> event
for _, id := range stateIDs { for _, id := range stateIDs {
if _, ok := nidMap[id]; !ok { if _, ok := nidMap[id]; !ok {
missingMap[id] = nil missingMap[id] = nil
@ -227,15 +230,15 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
logger.WithError(err).Warn("event failed PDU checks") logger.WithError(err).Warn("event failed PDU checks")
continue continue
} }
missingMap[id] = res.Event missingMap[id] = &types.HeaderedEvent{Event: res.Event.(*gomatrixserverlib.Event)}
} }
} }
} }
var newEvents []*gomatrixserverlib.HeaderedEvent var newEvents []*gomatrixserverlib.Event
for _, ev := range missingMap { for _, ev := range missingMap {
if ev != nil { if ev != nil {
newEvents = append(newEvents, ev) newEvents = append(newEvents, ev.Event)
} }
} }
util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents)) util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents))
@ -254,7 +257,7 @@ type backfillRequester struct {
// per-request state // per-request state
servers []spec.ServerName servers []spec.ServerName
eventIDToBeforeStateIDs map[string][]string eventIDToBeforeStateIDs map[string][]string
eventIDMap map[string]*gomatrixserverlib.Event eventIDMap map[string]gomatrixserverlib.PDU
historyVisiblity gomatrixserverlib.HistoryVisibility historyVisiblity gomatrixserverlib.HistoryVisibility
roomInfo types.RoomInfo roomInfo types.RoomInfo
} }
@ -275,15 +278,15 @@ func newBackfillRequester(
virtualHost: virtualHost, virtualHost: virtualHost,
isLocalServerName: isLocalServerName, isLocalServerName: isLocalServerName,
eventIDToBeforeStateIDs: make(map[string][]string), eventIDToBeforeStateIDs: make(map[string][]string),
eventIDMap: make(map[string]*gomatrixserverlib.Event), eventIDMap: make(map[string]gomatrixserverlib.PDU),
bwExtrems: bwExtrems, bwExtrems: bwExtrems,
preferServer: preferServer, preferServer: preferServer,
historyVisiblity: gomatrixserverlib.HistoryVisibilityShared, historyVisiblity: gomatrixserverlib.HistoryVisibilityShared,
} }
} }
func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent *gomatrixserverlib.HeaderedEvent) ([]string, error) { func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent gomatrixserverlib.PDU) ([]string, error) {
b.eventIDMap[targetEvent.EventID()] = targetEvent.Unwrap() b.eventIDMap[targetEvent.EventID()] = targetEvent
if ids, ok := b.eventIDToBeforeStateIDs[targetEvent.EventID()]; ok { if ids, ok := b.eventIDToBeforeStateIDs[targetEvent.EventID()]; ok {
return ids, nil return ids, nil
} }
@ -305,7 +308,7 @@ func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent
if !ok { if !ok {
goto FederationHit goto FederationHit
} }
newStateIDs := b.calculateNewStateIDs(targetEvent.Unwrap(), prevEvent, prevEventStateIDs) newStateIDs := b.calculateNewStateIDs(targetEvent, prevEvent, prevEventStateIDs)
if newStateIDs != nil { if newStateIDs != nil {
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs
return newStateIDs, nil return newStateIDs, nil
@ -334,7 +337,7 @@ FederationHit:
return nil, lastErr return nil, lastErr
} }
func (b *backfillRequester) calculateNewStateIDs(targetEvent, prevEvent *gomatrixserverlib.Event, prevEventStateIDs []string) []string { func (b *backfillRequester) calculateNewStateIDs(targetEvent, prevEvent gomatrixserverlib.PDU, prevEventStateIDs []string) []string {
newStateIDs := prevEventStateIDs[:] newStateIDs := prevEventStateIDs[:]
if prevEvent.StateKey() == nil { if prevEvent.StateKey() == nil {
// state is the same as the previous event // state is the same as the previous event
@ -372,7 +375,7 @@ func (b *backfillRequester) calculateNewStateIDs(targetEvent, prevEvent *gomatri
} }
func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatrixserverlib.RoomVersion,
event *gomatrixserverlib.HeaderedEvent, eventIDs []string) (map[string]*gomatrixserverlib.Event, error) { event gomatrixserverlib.PDU, eventIDs []string) (map[string]gomatrixserverlib.PDU, error) {
// try to fetch the events from the database first // try to fetch the events from the database first
events, err := b.ProvideEvents(roomVer, eventIDs) events, err := b.ProvideEvents(roomVer, eventIDs)
@ -382,7 +385,7 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr
} else { } else {
logrus.Infof("Fetched %d/%d events from the database", len(events), len(eventIDs)) logrus.Infof("Fetched %d/%d events from the database", len(events), len(eventIDs))
if len(events) == len(eventIDs) { if len(events) == len(eventIDs) {
result := make(map[string]*gomatrixserverlib.Event) result := make(map[string]gomatrixserverlib.PDU)
for i := range events { for i := range events {
result[events[i].EventID()] = events[i] result[events[i].EventID()] = events[i]
b.eventIDMap[events[i].EventID()] = events[i] b.eventIDMap[events[i].EventID()] = events[i]
@ -513,7 +516,7 @@ func (b *backfillRequester) Backfill(ctx context.Context, origin, server spec.Se
return tx, err return tx, err
} }
func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, eventIDs []string) ([]*gomatrixserverlib.Event, error) { func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, eventIDs []string) ([]gomatrixserverlib.PDU, error) {
ctx := context.Background() ctx := context.Background()
nidMap, err := b.db.EventNIDs(ctx, eventIDs) nidMap, err := b.db.EventNIDs(ctx, eventIDs)
if err != nil { if err != nil {
@ -535,7 +538,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events") logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events")
return nil, err return nil, err
} }
events := make([]*gomatrixserverlib.Event, len(eventsWithNids)) events := make([]gomatrixserverlib.PDU, len(eventsWithNids))
for i := range eventsWithNids { for i := range eventsWithNids {
events[i] = eventsWithNids[i].Event events[i] = eventsWithNids[i].Event
} }
@ -587,7 +590,7 @@ func joinEventsFromHistoryVisibility(
return evs, visibility, err return evs, visibility, err
} }
func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) { func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixserverlib.Event) (types.RoomNID, map[string]types.Event) {
var roomNID types.RoomNID var roomNID types.RoomNID
var eventNID types.EventNID var eventNID types.EventNID
backfilledEventMap := make(map[string]types.Event) backfilledEventMap := make(map[string]types.Event)
@ -604,7 +607,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs
i++ i++
} }
roomInfo, err := db.GetOrCreateRoomInfo(ctx, ev.Unwrap()) roomInfo, err := db.GetOrCreateRoomInfo(ctx, ev)
if err != nil { if err != nil {
logrus.WithError(err).Error("failed to get or create roomNID") logrus.WithError(err).Error("failed to get or create roomNID")
continue continue
@ -623,7 +626,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs
continue continue
} }
eventNID, _, err = db.StoreEvent(ctx, ev.Unwrap(), roomInfo, eventTypeNID, eventStateKeyNID, authNids, false) eventNID, _, err = db.StoreEvent(ctx, ev, roomInfo, eventTypeNID, eventStateKeyNID, authNids, false)
if err != nil { if err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event") logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event")
continue continue
@ -631,7 +634,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs
resolver := state.NewStateResolution(db, roomInfo) resolver := state.NewStateResolution(db, roomInfo)
_, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev.Unwrap(), &resolver) _, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev, &resolver)
if err != nil { if err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event") logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event")
continue continue
@ -640,12 +643,12 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs
// It's also possible for this event to be a redaction which results in another event being // It's also possible for this event to be a redaction which results in another event being
// redacted, which we don't care about since we aren't returning it in this backfill. // redacted, which we don't care about since we aren't returning it in this backfill.
if redactedEvent != nil && redactedEvent.EventID() == ev.EventID() { if redactedEvent != nil && redactedEvent.EventID() == ev.EventID() {
ev = redactedEvent.Headered(ev.RoomVersion) ev = redactedEvent
events[j] = ev events[j] = ev
} }
backfilledEventMap[ev.EventID()] = types.Event{ backfilledEventMap[ev.EventID()] = types.Event{
EventNID: eventNID, EventNID: eventNID,
Event: ev.Unwrap(), Event: ev,
} }
} }
return roomNID, backfilledEventMap return roomNID, backfilledEventMap

View file

@ -68,7 +68,7 @@ func (r *InboundPeeker) PerformInboundPeek(
if err != nil { if err != nil {
return err return err
} }
var sortedLatestEvents []*gomatrixserverlib.Event var sortedLatestEvents []gomatrixserverlib.PDU
for _, ev := range latestEvents { for _, ev := range latestEvents {
sortedLatestEvents = append(sortedLatestEvents, ev.Event) sortedLatestEvents = append(sortedLatestEvents, ev.Event)
} }
@ -76,7 +76,7 @@ func (r *InboundPeeker) PerformInboundPeek(
sortedLatestEvents, sortedLatestEvents,
gomatrixserverlib.TopologicalOrderByPrevEvents, gomatrixserverlib.TopologicalOrderByPrevEvents,
) )
response.LatestEvent = sortedLatestEvents[0].Headered(info.RoomVersion) response.LatestEvent = &types.HeaderedEvent{Event: sortedLatestEvents[0].(*gomatrixserverlib.Event)}
// XXX: do we actually need to do a state resolution here? // XXX: do we actually need to do a state resolution here?
roomState := state.NewStateResolution(r.DB, info) roomState := state.NewStateResolution(r.DB, info)
@ -106,11 +106,11 @@ func (r *InboundPeeker) PerformInboundPeek(
} }
for _, event := range stateEvents { for _, event := range stateEvents {
response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion)) response.StateEvents = append(response.StateEvents, &types.HeaderedEvent{Event: event})
} }
for _, event := range authEvents { for _, event := range authEvents {
response.AuthChainEvents = append(response.AuthChainEvents, event.Headered(info.RoomVersion)) response.AuthChainEvents = append(response.AuthChainEvents, &types.HeaderedEvent{Event: event})
} }
err = r.Inputer.OutputProducer.ProduceRoomEvents(request.RoomID, []api.OutputEvent{ err = r.Inputer.OutputProducer.ProduceRoomEvents(request.RoomID, []api.OutputEvent{

View file

@ -45,7 +45,6 @@ type Inviter struct {
func (r *Inviter) PerformInvite( func (r *Inviter) PerformInvite(
ctx context.Context, ctx context.Context,
req *api.PerformInviteRequest, req *api.PerformInviteRequest,
res *api.PerformInviteResponse,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
var outputUpdates []api.OutputEvent var outputUpdates []api.OutputEvent
event := req.Event event := req.Event
@ -66,20 +65,12 @@ func (r *Inviter) PerformInvite(
_, domain, err := gomatrixserverlib.SplitID('@', targetUserID) _, domain, err := gomatrixserverlib.SplitID('@', targetUserID)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", targetUserID)}
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("The user ID %q is invalid!", targetUserID),
}
return nil, nil
} }
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(domain) isTargetLocal := r.Cfg.Matrix.IsLocalServerName(domain)
isOriginLocal := r.Cfg.Matrix.IsLocalServerName(senderDomain) isOriginLocal := r.Cfg.Matrix.IsLocalServerName(senderDomain)
if !isOriginLocal && !isTargetLocal { if !isOriginLocal && !isTargetLocal {
res.Error = &api.PerformError{ return nil, api.ErrInvalidID{Err: fmt.Errorf("the invite must be either from or to a local user")}
Code: api.PerformErrorBadRequest,
Msg: "The invite must be either from or to a local user",
}
return nil, nil
} }
logger := util.GetLogger(ctx).WithFields(map[string]interface{}{ logger := util.GetLogger(ctx).WithFields(map[string]interface{}{
@ -119,8 +110,8 @@ func (r *Inviter) PerformInvite(
} }
outputUpdates, err = helpers.UpdateToInviteMembership(updater, &types.Event{ outputUpdates, err = helpers.UpdateToInviteMembership(updater, &types.Event{
EventNID: 0, EventNID: 0,
Event: event.Unwrap(), Event: event.Event,
}, outputUpdates, req.Event.RoomVersion) }, outputUpdates, req.Event.Version())
if err != nil { if err != nil {
return nil, fmt.Errorf("updateToInviteMembership: %w", err) return nil, fmt.Errorf("updateToInviteMembership: %w", err)
} }
@ -175,12 +166,8 @@ func (r *Inviter) PerformInvite(
// For now we will implement option 2. Since in the abesence of a retry // For now we will implement option 2. Since in the abesence of a retry
// mechanism it will be equivalent to option 1, and we don't have a // mechanism it will be equivalent to option 1, and we don't have a
// signalling mechanism to implement option 3. // signalling mechanism to implement option 3.
res.Error = &api.PerformError{
Code: api.PerformErrorNotAllowed,
Msg: "User is already joined to room",
}
logger.Debugf("user already joined") logger.Debugf("user already joined")
return nil, nil return nil, api.ErrNotAllowed{Err: fmt.Errorf("user is already joined to room")}
} }
// If the invite originated remotely then we can't send an // If the invite originated remotely then we can't send an
@ -201,11 +188,7 @@ func (r *Inviter) PerformInvite(
logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error( logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error(
"processInviteEvent.checkAuthEvents failed for event", "processInviteEvent.checkAuthEvents failed for event",
) )
res.Error = &api.PerformError{ return nil, api.ErrNotAllowed{Err: err}
Msg: err.Error(),
Code: api.PerformErrorNotAllowed,
}
return nil, nil
} }
// If the invite originated from us and the target isn't local then we // If the invite originated from us and the target isn't local then we
@ -220,12 +203,8 @@ func (r *Inviter) PerformInvite(
} }
fsRes := &federationAPI.PerformInviteResponse{} fsRes := &federationAPI.PerformInviteResponse{}
if err = r.FSAPI.PerformInvite(ctx, fsReq, fsRes); err != nil { if err = r.FSAPI.PerformInvite(ctx, fsReq, fsRes); err != nil {
res.Error = &api.PerformError{
Msg: err.Error(),
Code: api.PerformErrorNotAllowed,
}
logger.WithError(err).WithField("event_id", event.EventID()).Error("r.FSAPI.PerformInvite failed") logger.WithError(err).WithField("event_id", event.EventID()).Error("r.FSAPI.PerformInvite failed")
return nil, nil return nil, api.ErrNotAllowed{Err: err}
} }
event = fsRes.Event event = fsRes.Event
logger.Debugf("Federated PerformInvite success with event ID %s", event.EventID()) logger.Debugf("Federated PerformInvite success with event ID %s", event.EventID())
@ -251,11 +230,8 @@ func (r *Inviter) PerformInvite(
return nil, fmt.Errorf("r.Inputer.InputRoomEvents: %w", err) return nil, fmt.Errorf("r.Inputer.InputRoomEvents: %w", err)
} }
if err = inputRes.Err(); err != nil { if err = inputRes.Err(); err != nil {
res.Error = &api.PerformError{
Msg: fmt.Sprintf("r.InputRoomEvents: %s", err.Error()),
Code: api.PerformErrorNotAllowed,
}
logger.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed") logger.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed")
return nil, api.ErrNotAllowed{Err: err}
} }
// Don't notify the sync api of this event in the same way as a federated invite so the invitee // Don't notify the sync api of this event in the same way as a federated invite so the invitee
@ -300,7 +276,7 @@ func buildInviteStrippedState(
inviteState := []fclient.InviteV2StrippedState{ inviteState := []fclient.InviteV2StrippedState{
fclient.NewInviteV2StrippedState(input.Event.Event), fclient.NewInviteV2StrippedState(input.Event.Event),
} }
stateEvents = append(stateEvents, types.Event{Event: input.Event.Unwrap()}) stateEvents = append(stateEvents, types.Event{Event: input.Event.Event})
for _, event := range stateEvents { for _, event := range stateEvents {
inviteState = append(inviteState, fclient.NewInviteV2StrippedState(event.Event)) inviteState = append(inviteState, fclient.NewInviteV2StrippedState(event.Event))
} }

View file

@ -36,6 +36,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/internal/input"
"github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/internal/query"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
) )
@ -53,32 +54,22 @@ type Joiner struct {
func (r *Joiner) PerformJoin( func (r *Joiner) PerformJoin(
ctx context.Context, ctx context.Context,
req *rsAPI.PerformJoinRequest, req *rsAPI.PerformJoinRequest,
res *rsAPI.PerformJoinResponse, ) (roomID string, joinedVia spec.ServerName, err error) {
) error {
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
"room_id": req.RoomIDOrAlias, "room_id": req.RoomIDOrAlias,
"user_id": req.UserID, "user_id": req.UserID,
"servers": req.ServerNames, "servers": req.ServerNames,
}) })
logger.Info("User requested to room join") logger.Info("User requested to room join")
roomID, joinedVia, err := r.performJoin(context.Background(), req) roomID, joinedVia, err = r.performJoin(context.Background(), req)
if err != nil { if err != nil {
logger.WithError(err).Error("Failed to join room") logger.WithError(err).Error("Failed to join room")
sentry.CaptureException(err) sentry.CaptureException(err)
perr, ok := err.(*rsAPI.PerformError) return "", "", err
if ok {
res.Error = perr
} else {
res.Error = &rsAPI.PerformError{
Msg: err.Error(),
}
}
return nil
} }
logger.Info("User joined room successfully") logger.Info("User joined room successfully")
res.RoomID = roomID
res.JoinedVia = joinedVia return roomID, joinedVia, nil
return nil
} }
func (r *Joiner) performJoin( func (r *Joiner) performJoin(
@ -87,16 +78,10 @@ func (r *Joiner) performJoin(
) (string, spec.ServerName, error) { ) (string, spec.ServerName, error) {
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID) _, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil { if err != nil {
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("supplied user ID %q in incorrect format", req.UserID)}
Code: rsAPI.PerformErrorBadRequest,
Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID),
}
} }
if !r.Cfg.Matrix.IsLocalServerName(domain) { if !r.Cfg.Matrix.IsLocalServerName(domain) {
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user %q does not belong to this homeserver", req.UserID)}
Code: rsAPI.PerformErrorBadRequest,
Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID),
}
} }
if strings.HasPrefix(req.RoomIDOrAlias, "!") { if strings.HasPrefix(req.RoomIDOrAlias, "!") {
return r.performJoinRoomByID(ctx, req) return r.performJoinRoomByID(ctx, req)
@ -104,10 +89,7 @@ func (r *Joiner) performJoin(
if strings.HasPrefix(req.RoomIDOrAlias, "#") { if strings.HasPrefix(req.RoomIDOrAlias, "#") {
return r.performJoinRoomByAlias(ctx, req) return r.performJoinRoomByAlias(ctx, req)
} }
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("room ID or alias %q is invalid", req.RoomIDOrAlias)}
Code: rsAPI.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID or alias %q is invalid", req.RoomIDOrAlias),
}
} }
func (r *Joiner) performJoinRoomByAlias( func (r *Joiner) performJoinRoomByAlias(
@ -182,10 +164,7 @@ func (r *Joiner) performJoinRoomByID(
// Get the domain part of the room ID. // Get the domain part of the room ID.
_, domain, err := gomatrixserverlib.SplitID('!', req.RoomIDOrAlias) _, domain, err := gomatrixserverlib.SplitID('!', req.RoomIDOrAlias)
if err != nil { if err != nil {
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", req.RoomIDOrAlias, err)}
Code: rsAPI.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID %q is invalid: %s", req.RoomIDOrAlias, err),
}
} }
// If the server name in the room ID isn't ours then it's a // If the server name in the room ID isn't ours then it's a
@ -199,10 +178,7 @@ func (r *Joiner) performJoinRoomByID(
userID := req.UserID userID := req.UserID
_, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID) _, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID)
if err != nil { if err != nil {
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", userID, err)}
Code: rsAPI.PerformErrorBadRequest,
Msg: fmt.Sprintf("User ID %q is invalid: %s", userID, err),
}
} }
eb := gomatrixserverlib.EventBuilder{ eb := gomatrixserverlib.EventBuilder{
Type: spec.MRoomMember, Type: spec.MRoomMember,
@ -273,7 +249,7 @@ func (r *Joiner) performJoinRoomByID(
// If a guest is trying to join a room, check that the room has a m.room.guest_access event // If a guest is trying to join a room, check that the room has a m.room.guest_access event
if req.IsGuest { if req.IsGuest {
var guestAccessEvent *gomatrixserverlib.HeaderedEvent var guestAccessEvent *types.HeaderedEvent
guestAccess := "forbidden" guestAccess := "forbidden"
guestAccessEvent, err = r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, spec.MRoomGuestAccess, "") guestAccessEvent, err = r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, spec.MRoomGuestAccess, "")
if (err != nil && !errors.Is(err, sql.ErrNoRows)) || guestAccessEvent == nil { if (err != nil && !errors.Is(err, sql.ErrNoRows)) || guestAccessEvent == nil {
@ -286,10 +262,7 @@ func (r *Joiner) performJoinRoomByID(
// Servers MUST only allow guest users to join rooms if the m.room.guest_access state event // Servers MUST only allow guest users to join rooms if the m.room.guest_access state event
// is present on the room and has the guest_access value can_join. // is present on the room and has the guest_access value can_join.
if guestAccess != "can_join" { if guestAccess != "can_join" {
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrNotAllowed{Err: fmt.Errorf("guest access is forbidden")}
Code: rsAPI.PerformErrorNotAllowed,
Msg: "Guest access is forbidden",
}
} }
} }
@ -334,23 +307,17 @@ func (r *Joiner) performJoinRoomByID(
InputRoomEvents: []rsAPI.InputRoomEvent{ InputRoomEvents: []rsAPI.InputRoomEvent{
{ {
Kind: rsAPI.KindNew, Kind: rsAPI.KindNew,
Event: event.Headered(buildRes.RoomVersion), Event: event,
SendAsServer: string(userDomain), SendAsServer: string(userDomain),
}, },
}, },
} }
inputRes := rsAPI.InputRoomEventsResponse{} inputRes := rsAPI.InputRoomEventsResponse{}
if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil {
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrNotAllowed{Err: err}
Code: rsAPI.PerformErrorNoOperation,
Msg: fmt.Sprintf("InputRoomEvents failed: %s", err),
}
} }
if err = inputRes.Err(); err != nil { if err = inputRes.Err(); err != nil {
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrNotAllowed{Err: err}
Code: rsAPI.PerformErrorNotAllowed,
Msg: fmt.Sprintf("InputRoomEvents auth failed: %s", err),
}
} }
} }
@ -363,10 +330,7 @@ func (r *Joiner) performJoinRoomByID(
// Otherwise we'll try a federated join as normal, since it's quite // Otherwise we'll try a federated join as normal, since it's quite
// possible that the room still exists on other servers. // possible that the room still exists on other servers.
if len(req.ServerNames) == 0 { if len(req.ServerNames) == 0 {
return "", "", &rsAPI.PerformError{ return "", "", eventutil.ErrRoomNoExists
Code: rsAPI.PerformErrorNoRoom,
Msg: fmt.Sprintf("room ID %q does not exist", req.RoomIDOrAlias),
}
} }
} }
@ -401,11 +365,7 @@ func (r *Joiner) performFederatedJoinRoomByID(
fedRes := fsAPI.PerformJoinResponse{} fedRes := fsAPI.PerformJoinResponse{}
r.FSAPI.PerformJoin(ctx, &fedReq, &fedRes) r.FSAPI.PerformJoin(ctx, &fedReq, &fedRes)
if fedRes.LastError != nil { if fedRes.LastError != nil {
return "", &rsAPI.PerformError{ return "", fedRes.LastError
Code: rsAPI.PerformErrRemote,
Msg: fedRes.LastError.Message,
RemoteCode: fedRes.LastError.Code,
}
} }
return fedRes.JoinedVia, nil return fedRes.JoinedVia, nil
} }
@ -429,10 +389,7 @@ func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin(
return "", nil return "", nil
} }
if !res.Allowed { if !res.Allowed {
return "", &rsAPI.PerformError{ return "", rsAPI.ErrNotAllowed{Err: fmt.Errorf("the join to room %s was not allowed", joinReq.RoomIDOrAlias)}
Code: rsAPI.PerformErrorNotAllowed,
Msg: fmt.Sprintf("The join to room %s was not allowed.", joinReq.RoomIDOrAlias),
}
} }
return res.AuthorisedVia, nil return res.AuthorisedVia, nil
} }

View file

@ -196,7 +196,7 @@ func (r *Leaver) performLeaveRoomByID(
InputRoomEvents: []api.InputRoomEvent{ InputRoomEvents: []api.InputRoomEvent{
{ {
Kind: api.KindNew, Kind: api.KindNew,
Event: event.Headered(buildRes.RoomVersion), Event: event,
Origin: senderDomain, Origin: senderDomain,
SendAsServer: string(senderDomain), SendAsServer: string(senderDomain),
}, },

View file

@ -44,21 +44,8 @@ type Peeker struct {
func (r *Peeker) PerformPeek( func (r *Peeker) PerformPeek(
ctx context.Context, ctx context.Context,
req *api.PerformPeekRequest, req *api.PerformPeekRequest,
res *api.PerformPeekResponse, ) (roomID string, err error) {
) error { return r.performPeek(ctx, req)
roomID, err := r.performPeek(ctx, req)
if err != nil {
perr, ok := err.(*api.PerformError)
if ok {
res.Error = perr
} else {
res.Error = &api.PerformError{
Msg: err.Error(),
}
}
}
res.RoomID = roomID
return nil
} }
func (r *Peeker) performPeek( func (r *Peeker) performPeek(
@ -68,16 +55,10 @@ func (r *Peeker) performPeek(
// FIXME: there's way too much duplication with performJoin // FIXME: there's way too much duplication with performJoin
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID) _, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil { if err != nil {
return "", &api.PerformError{ return "", api.ErrInvalidID{Err: fmt.Errorf("supplied user ID %q in incorrect format", req.UserID)}
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID),
}
} }
if !r.Cfg.Matrix.IsLocalServerName(domain) { if !r.Cfg.Matrix.IsLocalServerName(domain) {
return "", &api.PerformError{ return "", api.ErrInvalidID{Err: fmt.Errorf("user %q does not belong to this homeserver", req.UserID)}
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID),
}
} }
if strings.HasPrefix(req.RoomIDOrAlias, "!") { if strings.HasPrefix(req.RoomIDOrAlias, "!") {
return r.performPeekRoomByID(ctx, req) return r.performPeekRoomByID(ctx, req)
@ -85,10 +66,7 @@ func (r *Peeker) performPeek(
if strings.HasPrefix(req.RoomIDOrAlias, "#") { if strings.HasPrefix(req.RoomIDOrAlias, "#") {
return r.performPeekRoomByAlias(ctx, req) return r.performPeekRoomByAlias(ctx, req)
} }
return "", &api.PerformError{ return "", api.ErrInvalidID{Err: fmt.Errorf("room ID or alias %q is invalid", req.RoomIDOrAlias)}
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID or alias %q is invalid", req.RoomIDOrAlias),
}
} }
func (r *Peeker) performPeekRoomByAlias( func (r *Peeker) performPeekRoomByAlias(
@ -98,7 +76,7 @@ func (r *Peeker) performPeekRoomByAlias(
// Get the domain part of the room alias. // Get the domain part of the room alias.
_, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias) _, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias)
if err != nil { if err != nil {
return "", fmt.Errorf("alias %q is not in the correct format", req.RoomIDOrAlias) return "", api.ErrInvalidID{Err: fmt.Errorf("alias %q is not in the correct format", req.RoomIDOrAlias)}
} }
req.ServerNames = append(req.ServerNames, domain) req.ServerNames = append(req.ServerNames, domain)
@ -147,10 +125,7 @@ func (r *Peeker) performPeekRoomByID(
// Get the domain part of the room ID. // Get the domain part of the room ID.
_, domain, err := gomatrixserverlib.SplitID('!', roomID) _, domain, err := gomatrixserverlib.SplitID('!', roomID)
if err != nil { if err != nil {
return "", &api.PerformError{ return "", api.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", roomID, err)}
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID %q is invalid: %s", roomID, err),
}
} }
// handle federated peeks // handle federated peeks
@ -169,11 +144,7 @@ func (r *Peeker) performPeekRoomByID(
fedRes := fsAPI.PerformOutboundPeekResponse{} fedRes := fsAPI.PerformOutboundPeekResponse{}
_ = r.FSAPI.PerformOutboundPeek(ctx, &fedReq, &fedRes) _ = r.FSAPI.PerformOutboundPeek(ctx, &fedReq, &fedRes)
if fedRes.LastError != nil { if fedRes.LastError != nil {
return "", &api.PerformError{ return "", fedRes.LastError
Code: api.PerformErrRemote,
Msg: fedRes.LastError.Message,
RemoteCode: fedRes.LastError.Code,
}
} }
} }
@ -194,17 +165,11 @@ func (r *Peeker) performPeekRoomByID(
} }
if !worldReadable { if !worldReadable {
return "", &api.PerformError{ return "", api.ErrNotAllowed{Err: fmt.Errorf("room is not world-readable")}
Code: api.PerformErrorNotAllowed,
Msg: "Room is not world-readable",
}
} }
if ev, _ := r.DB.GetStateEvent(ctx, roomID, "m.room.encryption", ""); ev != nil { if ev, _ := r.DB.GetStateEvent(ctx, roomID, "m.room.encryption", ""); ev != nil {
return "", &api.PerformError{ return "", api.ErrNotAllowed{Err: fmt.Errorf("Cannot peek into an encrypted room")}
Code: api.PerformErrorNotAllowed,
Msg: "Cannot peek into an encrypted room",
}
} }
// TODO: handle federated peeks // TODO: handle federated peeks

View file

@ -25,16 +25,10 @@ type Publisher struct {
DB storage.Database DB storage.Database
} }
// PerformPublish publishes or unpublishes a room from the room directory. Returns a database error, if any.
func (r *Publisher) PerformPublish( func (r *Publisher) PerformPublish(
ctx context.Context, ctx context.Context,
req *api.PerformPublishRequest, req *api.PerformPublishRequest,
res *api.PerformPublishResponse,
) error { ) error {
err := r.DB.PublishRoom(ctx, req.RoomID, req.AppserviceID, req.NetworkID, req.Visibility == "public") return r.DB.PublishRoom(ctx, req.RoomID, req.AppserviceID, req.NetworkID, req.Visibility == "public")
if err != nil {
res.Error = &api.PerformError{
Msg: err.Error(),
}
}
return nil
} }

View file

@ -34,84 +34,48 @@ type Unpeeker struct {
Inputer *input.Inputer Inputer *input.Inputer
} }
// PerformPeek handles peeking into matrix rooms, including over federation by talking to the federationapi. // PerformUnpeek handles un-peeking matrix rooms, including over federation by talking to the federationapi.
func (r *Unpeeker) PerformUnpeek( func (r *Unpeeker) PerformUnpeek(
ctx context.Context, ctx context.Context,
req *api.PerformUnpeekRequest, roomID, userID, deviceID string,
res *api.PerformUnpeekResponse,
) error {
if err := r.performUnpeek(ctx, req); err != nil {
perr, ok := err.(*api.PerformError)
if ok {
res.Error = perr
} else {
res.Error = &api.PerformError{
Msg: err.Error(),
}
}
}
return nil
}
func (r *Unpeeker) performUnpeek(
ctx context.Context,
req *api.PerformUnpeekRequest,
) error { ) error {
// FIXME: there's way too much duplication with performJoin // FIXME: there's way too much duplication with performJoin
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID) _, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return &api.PerformError{ return api.ErrInvalidID{Err: fmt.Errorf("supplied user ID %q in incorrect format", userID)}
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID),
}
} }
if !r.Cfg.Matrix.IsLocalServerName(domain) { if !r.Cfg.Matrix.IsLocalServerName(domain) {
return &api.PerformError{ return api.ErrInvalidID{Err: fmt.Errorf("user %q does not belong to this homeserver", userID)}
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID),
}
} }
if strings.HasPrefix(req.RoomID, "!") { if strings.HasPrefix(roomID, "!") {
return r.performUnpeekRoomByID(ctx, req) return r.performUnpeekRoomByID(ctx, roomID, userID, deviceID)
}
return &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID %q is invalid", req.RoomID),
} }
return api.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid", roomID)}
} }
func (r *Unpeeker) performUnpeekRoomByID( func (r *Unpeeker) performUnpeekRoomByID(
_ context.Context, _ context.Context,
req *api.PerformUnpeekRequest, roomID, userID, deviceID string,
) (err error) { ) (err error) {
// Get the domain part of the room ID. // Get the domain part of the room ID.
_, _, err = gomatrixserverlib.SplitID('!', req.RoomID) _, _, err = gomatrixserverlib.SplitID('!', roomID)
if err != nil { if err != nil {
return &api.PerformError{ return api.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", roomID, err)}
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID %q is invalid: %s", req.RoomID, err),
}
} }
// TODO: handle federated peeks // TODO: handle federated peeks
err = r.Inputer.OutputProducer.ProduceRoomEvents(req.RoomID, []api.OutputEvent{
{
Type: api.OutputTypeRetirePeek,
RetirePeek: &api.OutputRetirePeek{
RoomID: req.RoomID,
UserID: req.UserID,
DeviceID: req.DeviceID,
},
},
})
if err != nil {
return
}
// By this point, if req.RoomIDOrAlias contained an alias, then // By this point, if req.RoomIDOrAlias contained an alias, then
// it will have been overwritten with a room ID by performPeekRoomByAlias. // it will have been overwritten with a room ID by performPeekRoomByAlias.
// We should now include this in the response so that the CS API can // We should now include this in the response so that the CS API can
// return the right room ID. // return the right room ID.
return nil return r.Inputer.OutputProducer.ProduceRoomEvents(roomID, []api.OutputEvent{
{
Type: api.OutputTypeRetirePeek,
RetirePeek: &api.OutputRetirePeek{
RoomID: roomID,
UserID: userID,
DeviceID: deviceID,
},
},
})
} }

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
@ -44,46 +45,29 @@ type fledglingEvent struct {
// PerformRoomUpgrade upgrades a room from one version to another // PerformRoomUpgrade upgrades a room from one version to another
func (r *Upgrader) PerformRoomUpgrade( func (r *Upgrader) PerformRoomUpgrade(
ctx context.Context, ctx context.Context,
req *api.PerformRoomUpgradeRequest, roomID, userID string, roomVersion gomatrixserverlib.RoomVersion,
res *api.PerformRoomUpgradeResponse, ) (newRoomID string, err error) {
) error { return r.performRoomUpgrade(ctx, roomID, userID, roomVersion)
res.NewRoomID, res.Error = r.performRoomUpgrade(ctx, req)
if res.Error != nil {
res.NewRoomID = ""
logrus.WithContext(ctx).WithError(res.Error).Error("Room upgrade failed")
}
return nil
} }
func (r *Upgrader) performRoomUpgrade( func (r *Upgrader) performRoomUpgrade(
ctx context.Context, ctx context.Context,
req *api.PerformRoomUpgradeRequest, roomID, userID string, roomVersion gomatrixserverlib.RoomVersion,
) (string, *api.PerformError) { ) (string, error) {
roomID := req.RoomID
userID := req.UserID
_, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID) _, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID)
if err != nil { if err != nil {
return "", &api.PerformError{ return "", api.ErrNotAllowed{Err: fmt.Errorf("error validating the user ID")}
Code: api.PerformErrorNotAllowed,
Msg: "Error validating the user ID",
}
} }
evTime := time.Now() evTime := time.Now()
// Return an immediate error if the room does not exist // Return an immediate error if the room does not exist
if err := r.validateRoomExists(ctx, roomID); err != nil { if err := r.validateRoomExists(ctx, roomID); err != nil {
return "", &api.PerformError{ return "", err
Code: api.PerformErrorNoRoom,
Msg: "Error validating that the room exists",
}
} }
// 1. Check if the user is authorized to actually perform the upgrade (can send m.room.tombstone) // 1. Check if the user is authorized to actually perform the upgrade (can send m.room.tombstone)
if !r.userIsAuthorized(ctx, userID, roomID) { if !r.userIsAuthorized(ctx, userID, roomID) {
return "", &api.PerformError{ return "", api.ErrNotAllowed{Err: fmt.Errorf("You don't have permission to upgrade the room, power level too low.")}
Code: api.PerformErrorNotAllowed,
Msg: "You don't have permission to upgrade the room, power level too low.",
}
} }
// TODO (#267): Check room ID doesn't clash with an existing one, and we // TODO (#267): Check room ID doesn't clash with an existing one, and we
@ -96,9 +80,7 @@ func (r *Upgrader) performRoomUpgrade(
} }
oldRoomRes := &api.QueryLatestEventsAndStateResponse{} oldRoomRes := &api.QueryLatestEventsAndStateResponse{}
if err := r.URSAPI.QueryLatestEventsAndState(ctx, oldRoomReq, oldRoomRes); err != nil { if err := r.URSAPI.QueryLatestEventsAndState(ctx, oldRoomReq, oldRoomRes); err != nil {
return "", &api.PerformError{ return "", fmt.Errorf("Failed to get latest state: %s", err)
Msg: fmt.Sprintf("Failed to get latest state: %s", err),
}
} }
// Make the tombstone event // Make the tombstone event
@ -109,13 +91,13 @@ func (r *Upgrader) performRoomUpgrade(
// Generate the initial events we need to send into the new room. This includes copied state events and bans // Generate the initial events we need to send into the new room. This includes copied state events and bans
// as well as the power level events needed to set up the room // as well as the power level events needed to set up the room
eventsToMake, pErr := r.generateInitialEvents(ctx, oldRoomRes, userID, roomID, string(req.RoomVersion), tombstoneEvent) eventsToMake, pErr := r.generateInitialEvents(ctx, oldRoomRes, userID, roomID, roomVersion, tombstoneEvent)
if pErr != nil { if pErr != nil {
return "", pErr return "", pErr
} }
// Send the setup events to the new room // Send the setup events to the new room
if pErr = r.sendInitialEvents(ctx, evTime, userID, userDomain, newRoomID, string(req.RoomVersion), eventsToMake); pErr != nil { if pErr = r.sendInitialEvents(ctx, evTime, userID, userDomain, newRoomID, roomVersion, eventsToMake); pErr != nil {
return "", pErr return "", pErr
} }
@ -147,22 +129,15 @@ func (r *Upgrader) performRoomUpgrade(
return newRoomID, nil return newRoomID, nil
} }
func (r *Upgrader) getRoomPowerLevels(ctx context.Context, roomID string) (*gomatrixserverlib.PowerLevelContent, *api.PerformError) { func (r *Upgrader) getRoomPowerLevels(ctx context.Context, roomID string) (*gomatrixserverlib.PowerLevelContent, error) {
oldPowerLevelsEvent := api.GetStateEvent(ctx, r.URSAPI, roomID, gomatrixserverlib.StateKeyTuple{ oldPowerLevelsEvent := api.GetStateEvent(ctx, r.URSAPI, roomID, gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomPowerLevels, EventType: spec.MRoomPowerLevels,
StateKey: "", StateKey: "",
}) })
powerLevelContent, err := oldPowerLevelsEvent.PowerLevels() return oldPowerLevelsEvent.PowerLevels()
if err != nil {
util.GetLogger(ctx).WithError(err).Error()
return nil, &api.PerformError{
Msg: "Power level event was invalid or malformed",
}
}
return powerLevelContent, nil
} }
func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, roomID string) *api.PerformError { func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, roomID string) error {
restrictedPowerLevelContent, pErr := r.getRoomPowerLevels(ctx, roomID) restrictedPowerLevelContent, pErr := r.getRoomPowerLevels(ctx, roomID)
if pErr != nil { if pErr != nil {
return pErr return pErr
@ -184,54 +159,46 @@ func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.T
StateKey: "", StateKey: "",
Content: restrictedPowerLevelContent, Content: restrictedPowerLevelContent,
}) })
if resErr != nil {
if resErr.Code == api.PerformErrorNotAllowed { switch resErr.(type) {
util.GetLogger(ctx).WithField(logrus.ErrorKey, resErr).Warn("UpgradeRoom: Could not restrict power levels in old room") case api.ErrNotAllowed:
} else { util.GetLogger(ctx).WithField(logrus.ErrorKey, resErr).Warn("UpgradeRoom: Could not restrict power levels in old room")
return resErr case nil:
} return r.sendHeaderedEvent(ctx, userDomain, restrictedPowerLevelsHeadered, api.DoNotSendToOtherServers)
} else { default:
if resErr = r.sendHeaderedEvent(ctx, userDomain, restrictedPowerLevelsHeadered, api.DoNotSendToOtherServers); resErr != nil { return resErr
return resErr
}
} }
return nil return nil
} }
func moveLocalAliases(ctx context.Context, func moveLocalAliases(ctx context.Context,
roomID, newRoomID, userID string, roomID, newRoomID, userID string,
URSAPI api.RoomserverInternalAPI) *api.PerformError { URSAPI api.RoomserverInternalAPI,
var err error ) (err error) {
aliasReq := api.GetAliasesForRoomIDRequest{RoomID: roomID} aliasReq := api.GetAliasesForRoomIDRequest{RoomID: roomID}
aliasRes := api.GetAliasesForRoomIDResponse{} aliasRes := api.GetAliasesForRoomIDResponse{}
if err = URSAPI.GetAliasesForRoomID(ctx, &aliasReq, &aliasRes); err != nil { if err = URSAPI.GetAliasesForRoomID(ctx, &aliasReq, &aliasRes); err != nil {
return &api.PerformError{ return fmt.Errorf("Failed to get old room aliases: %w", err)
Msg: fmt.Sprintf("Failed to get old room aliases: %s", err),
}
} }
for _, alias := range aliasRes.Aliases { for _, alias := range aliasRes.Aliases {
removeAliasReq := api.RemoveRoomAliasRequest{UserID: userID, Alias: alias} removeAliasReq := api.RemoveRoomAliasRequest{UserID: userID, Alias: alias}
removeAliasRes := api.RemoveRoomAliasResponse{} removeAliasRes := api.RemoveRoomAliasResponse{}
if err = URSAPI.RemoveRoomAlias(ctx, &removeAliasReq, &removeAliasRes); err != nil { if err = URSAPI.RemoveRoomAlias(ctx, &removeAliasReq, &removeAliasRes); err != nil {
return &api.PerformError{ return fmt.Errorf("Failed to remove old room alias: %w", err)
Msg: fmt.Sprintf("Failed to remove old room alias: %s", err),
}
} }
setAliasReq := api.SetRoomAliasRequest{UserID: userID, Alias: alias, RoomID: newRoomID} setAliasReq := api.SetRoomAliasRequest{UserID: userID, Alias: alias, RoomID: newRoomID}
setAliasRes := api.SetRoomAliasResponse{} setAliasRes := api.SetRoomAliasResponse{}
if err = URSAPI.SetRoomAlias(ctx, &setAliasReq, &setAliasRes); err != nil { if err = URSAPI.SetRoomAlias(ctx, &setAliasReq, &setAliasRes); err != nil {
return &api.PerformError{ return fmt.Errorf("Failed to set new room alias: %w", err)
Msg: fmt.Sprintf("Failed to set new room alias: %s", err),
}
} }
} }
return nil return nil
} }
func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, userID string, userDomain spec.ServerName, roomID string) *api.PerformError { func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, userID string, userDomain spec.ServerName, roomID string) error {
for _, event := range oldRoom.StateEvents { for _, event := range oldRoom.StateEvents {
if event.Type() != spec.MRoomCanonicalAlias || !event.StateKeyEquals("") { if event.Type() != spec.MRoomCanonicalAlias || !event.StateKeyEquals("") {
continue continue
@ -241,9 +208,7 @@ func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api
AltAliases []string `json:"alt_aliases"` AltAliases []string `json:"alt_aliases"`
} }
if err := json.Unmarshal(event.Content(), &aliasContent); err != nil { if err := json.Unmarshal(event.Content(), &aliasContent); err != nil {
return &api.PerformError{ return fmt.Errorf("failed to unmarshal canonical aliases: %w", err)
Msg: fmt.Sprintf("Failed to unmarshal canonical aliases: %s", err),
}
} }
if aliasContent.Alias == "" && len(aliasContent.AltAliases) == 0 { if aliasContent.Alias == "" && len(aliasContent.AltAliases) == 0 {
// There are no canonical aliases to clear, therefore do nothing. // There are no canonical aliases to clear, therefore do nothing.
@ -255,30 +220,25 @@ func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api
Type: spec.MRoomCanonicalAlias, Type: spec.MRoomCanonicalAlias,
Content: map[string]interface{}{}, Content: map[string]interface{}{},
}) })
if resErr != nil { switch resErr.(type) {
if resErr.Code == api.PerformErrorNotAllowed { case api.ErrNotAllowed:
util.GetLogger(ctx).WithField(logrus.ErrorKey, resErr).Warn("UpgradeRoom: Could not set empty canonical alias event in old room") util.GetLogger(ctx).WithField(logrus.ErrorKey, resErr).Warn("UpgradeRoom: Could not set empty canonical alias event in old room")
} else { case nil:
return resErr return r.sendHeaderedEvent(ctx, userDomain, emptyCanonicalAliasEvent, api.DoNotSendToOtherServers)
} default:
} else { return resErr
if resErr = r.sendHeaderedEvent(ctx, userDomain, emptyCanonicalAliasEvent, api.DoNotSendToOtherServers); resErr != nil {
return resErr
}
} }
return nil return nil
} }
func (r *Upgrader) publishIfOldRoomWasPublic(ctx context.Context, roomID, newRoomID string) *api.PerformError { func (r *Upgrader) publishIfOldRoomWasPublic(ctx context.Context, roomID, newRoomID string) error {
// check if the old room was published // check if the old room was published
var pubQueryRes api.QueryPublishedRoomsResponse var pubQueryRes api.QueryPublishedRoomsResponse
err := r.URSAPI.QueryPublishedRooms(ctx, &api.QueryPublishedRoomsRequest{ err := r.URSAPI.QueryPublishedRooms(ctx, &api.QueryPublishedRoomsRequest{
RoomID: roomID, RoomID: roomID,
}, &pubQueryRes) }, &pubQueryRes)
if err != nil { if err != nil {
return &api.PerformError{ return err
Msg: "QueryPublishedRooms failed",
}
} }
// if the old room is published (was public), publish the new room // if the old room is published (was public), publish the new room
@ -294,38 +254,27 @@ func publishNewRoomAndUnpublishOldRoom(
oldRoomID, newRoomID string, oldRoomID, newRoomID string,
) { ) {
// expose this room in the published room list // expose this room in the published room list
var pubNewRoomRes api.PerformPublishResponse
if err := URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{ if err := URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{
RoomID: newRoomID, RoomID: newRoomID,
Visibility: "public", Visibility: spec.Public,
}, &pubNewRoomRes); err != nil { }); err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to reach internal API")
} else if pubNewRoomRes.Error != nil {
// treat as non-fatal since the room is already made by this point // treat as non-fatal since the room is already made by this point
util.GetLogger(ctx).WithError(pubNewRoomRes.Error).Error("failed to visibility:public") util.GetLogger(ctx).WithError(err).Error("failed to publish room")
} }
var unpubOldRoomRes api.PerformPublishResponse
// remove the old room from the published room list // remove the old room from the published room list
if err := URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{ if err := URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{
RoomID: oldRoomID, RoomID: oldRoomID,
Visibility: "private", Visibility: "private",
}, &unpubOldRoomRes); err != nil { }); err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to reach internal API")
} else if unpubOldRoomRes.Error != nil {
// treat as non-fatal since the room is already made by this point // treat as non-fatal since the room is already made by this point
util.GetLogger(ctx).WithError(unpubOldRoomRes.Error).Error("failed to visibility:private") util.GetLogger(ctx).WithError(err).Error("failed to un-publish room")
} }
} }
func (r *Upgrader) validateRoomExists(ctx context.Context, roomID string) error { func (r *Upgrader) validateRoomExists(ctx context.Context, roomID string) error {
verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} if _, err := r.URSAPI.QueryRoomVersionForRoom(ctx, roomID); err != nil {
verRes := api.QueryRoomVersionForRoomResponse{} return eventutil.ErrRoomNoExists
if err := r.URSAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil {
return &api.PerformError{
Code: api.PerformErrorNoRoom,
Msg: "Room does not exist",
}
} }
return nil return nil
} }
@ -349,15 +298,15 @@ func (r *Upgrader) userIsAuthorized(ctx context.Context, userID, roomID string,
} }
// nolint:gocyclo // nolint:gocyclo
func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, userID, roomID, newVersion string, tombstoneEvent *gomatrixserverlib.HeaderedEvent) ([]fledglingEvent, *api.PerformError) { func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, userID, roomID string, newVersion gomatrixserverlib.RoomVersion, tombstoneEvent *types.HeaderedEvent) ([]fledglingEvent, error) {
state := make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent, len(oldRoom.StateEvents)) state := make(map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent, len(oldRoom.StateEvents))
for _, event := range oldRoom.StateEvents { for _, event := range oldRoom.StateEvents {
if event.StateKey() == nil { if event.StateKey() == nil {
// This shouldn't ever happen, but better to be safe than sorry. // This shouldn't ever happen, but better to be safe than sorry.
continue continue
} }
if event.Type() == spec.MRoomMember && !event.StateKeyEquals(userID) { if event.Type() == spec.MRoomMember && !event.StateKeyEquals(userID) {
// With the exception of bans and invites which we do want to copy, we // With the exception of bans which we do want to copy, we
// should ignore membership events that aren't our own, as event auth will // should ignore membership events that aren't our own, as event auth will
// prevent us from being able to create membership events on behalf of other // prevent us from being able to create membership events on behalf of other
// users anyway unless they are invites or bans. // users anyway unless they are invites or bans.
@ -367,11 +316,15 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
} }
switch membership { switch membership {
case spec.Ban: case spec.Ban:
case spec.Invite:
default: default:
continue continue
} }
} }
// skip events that rely on a specific user being present
sKey := *event.StateKey()
if event.Type() != spec.MRoomMember && len(sKey) > 0 && sKey[:1] == "@" {
continue
}
state[gomatrixserverlib.StateKeyTuple{EventType: event.Type(), StateKey: *event.StateKey()}] = event state[gomatrixserverlib.StateKeyTuple{EventType: event.Type(), StateKey: *event.StateKey()}] = event
} }
@ -388,9 +341,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
// old room state. Check that they are there. // old room state. Check that they are there.
for tuple := range override { for tuple := range override {
if _, ok := state[tuple]; !ok { if _, ok := state[tuple]; !ok {
return nil, &api.PerformError{ return nil, fmt.Errorf("essential event of type %q state key %q is missing", tuple.EventType, tuple.StateKey)
Msg: fmt.Sprintf("Essential event of type %q state key %q is missing", tuple.EventType, tuple.StateKey),
}
} }
} }
@ -437,9 +388,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
powerLevelContent, err := oldPowerLevelsEvent.PowerLevels() powerLevelContent, err := oldPowerLevelsEvent.PowerLevels()
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error() util.GetLogger(ctx).WithError(err).Error()
return nil, &api.PerformError{ return nil, fmt.Errorf("Power level event content was invalid")
Msg: "Power level event content was invalid",
}
} }
tempPowerLevelsEvent, powerLevelsOverridden := createTemporaryPowerLevels(powerLevelContent, userID) tempPowerLevelsEvent, powerLevelsOverridden := createTemporaryPowerLevels(powerLevelContent, userID)
@ -503,9 +452,9 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
return eventsToMake, nil return eventsToMake, nil
} }
func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, newRoomID, newVersion string, eventsToMake []fledglingEvent) *api.PerformError { func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, newRoomID string, newVersion gomatrixserverlib.RoomVersion, eventsToMake []fledglingEvent) error {
var err error var err error
var builtEvents []*gomatrixserverlib.HeaderedEvent var builtEvents []*types.HeaderedEvent
authEvents := gomatrixserverlib.NewAuthEvents(nil) authEvents := gomatrixserverlib.NewAuthEvents(nil)
for i, e := range eventsToMake { for i, e := range eventsToMake {
depth := i + 1 // depth starts at 1 depth := i + 1 // depth starts at 1
@ -519,34 +468,27 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
} }
err = builder.SetContent(e.Content) err = builder.SetContent(e.Content)
if err != nil { if err != nil {
return &api.PerformError{ return fmt.Errorf("failed to set content of new %q event: %w", builder.Type, err)
Msg: fmt.Sprintf("Failed to set content of new %q event: %s", builder.Type, err),
}
} }
if i > 0 { if i > 0 {
builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()} builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()}
} }
var event *gomatrixserverlib.Event var event *gomatrixserverlib.Event
event, err = builder.AddAuthEventsAndBuild(userDomain, &authEvents, evTime, gomatrixserverlib.RoomVersion(newVersion), r.Cfg.Matrix.KeyID, r.Cfg.Matrix.PrivateKey) event, err = builder.AddAuthEventsAndBuild(userDomain, &authEvents, evTime, newVersion, r.Cfg.Matrix.KeyID, r.Cfg.Matrix.PrivateKey)
if err != nil { if err != nil {
return &api.PerformError{ return fmt.Errorf("failed to build new %q event: %w", builder.Type, err)
Msg: fmt.Sprintf("Failed to build new %q event: %s", builder.Type, err),
}
} }
if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil {
return &api.PerformError{ return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err)
Msg: fmt.Sprintf("Failed to auth new %q event: %s", builder.Type, err),
}
} }
// Add the event to the list of auth events // Add the event to the list of auth events
builtEvents = append(builtEvents, event.Headered(gomatrixserverlib.RoomVersion(newVersion))) builtEvents = append(builtEvents, &types.HeaderedEvent{Event: event})
err = authEvents.AddEvent(event) err = authEvents.AddEvent(event)
if err != nil { if err != nil {
return &api.PerformError{ return fmt.Errorf("failed to add new %q event to auth set: %w", builder.Type, err)
Msg: fmt.Sprintf("Failed to add new %q event to auth set: %s", builder.Type, err),
}
} }
} }
@ -560,9 +502,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
}) })
} }
if err = api.SendInputRoomEvents(ctx, r.URSAPI, userDomain, inputs, false); err != nil { if err = api.SendInputRoomEvents(ctx, r.URSAPI, userDomain, inputs, false); err != nil {
return &api.PerformError{ return fmt.Errorf("failed to send new room %q to roomserver: %w", newRoomID, err)
Msg: fmt.Sprintf("Failed to send new room %q to roomserver: %s", newRoomID, err),
}
} }
return nil return nil
} }
@ -571,7 +511,7 @@ func (r *Upgrader) makeTombstoneEvent(
ctx context.Context, ctx context.Context,
evTime time.Time, evTime time.Time,
userID, roomID, newRoomID string, userID, roomID, newRoomID string,
) (*gomatrixserverlib.HeaderedEvent, *api.PerformError) { ) (*types.HeaderedEvent, error) {
content := map[string]interface{}{ content := map[string]interface{}{
"body": "This room has been replaced", "body": "This room has been replaced",
"replacement_room": newRoomID, "replacement_room": newRoomID,
@ -583,7 +523,7 @@ func (r *Upgrader) makeTombstoneEvent(
return r.makeHeaderedEvent(ctx, evTime, userID, roomID, event) return r.makeHeaderedEvent(ctx, evTime, userID, roomID, event)
} }
func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, userID, roomID string, event fledglingEvent) (*gomatrixserverlib.HeaderedEvent, *api.PerformError) { func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, userID, roomID string, event fledglingEvent) (*types.HeaderedEvent, error) {
builder := gomatrixserverlib.EventBuilder{ builder := gomatrixserverlib.EventBuilder{
Sender: userID, Sender: userID,
RoomID: roomID, RoomID: roomID,
@ -592,59 +532,36 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user
} }
err := builder.SetContent(event.Content) err := builder.SetContent(event.Content)
if err != nil { if err != nil {
return nil, &api.PerformError{ return nil, fmt.Errorf("failed to set new %q event content: %w", builder.Type, err)
Msg: fmt.Sprintf("Failed to set new %q event content: %s", builder.Type, err),
}
} }
// Get the sender domain. // Get the sender domain.
_, senderDomain, serr := r.Cfg.Matrix.SplitLocalID('@', builder.Sender) _, senderDomain, serr := r.Cfg.Matrix.SplitLocalID('@', builder.Sender)
if serr != nil { if serr != nil {
return nil, &api.PerformError{ return nil, fmt.Errorf("Failed to split user ID %q: %w", builder.Sender, err)
Msg: fmt.Sprintf("Failed to split user ID %q: %s", builder.Sender, err),
}
} }
identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain)
if err != nil { if err != nil {
return nil, &api.PerformError{ return nil, fmt.Errorf("failed to get signing identity for %q: %w", senderDomain, err)
Msg: fmt.Sprintf("Failed to get signing identity for %q: %s", senderDomain, err),
}
} }
var queryRes api.QueryLatestEventsAndStateResponse var queryRes api.QueryLatestEventsAndStateResponse
headeredEvent, err := eventutil.QueryAndBuildEvent(ctx, &builder, r.Cfg.Matrix, identity, evTime, r.URSAPI, &queryRes) headeredEvent, err := eventutil.QueryAndBuildEvent(ctx, &builder, r.Cfg.Matrix, identity, evTime, r.URSAPI, &queryRes)
if err == eventutil.ErrRoomNoExists { if err == eventutil.ErrRoomNoExists {
return nil, &api.PerformError{ return nil, err
Code: api.PerformErrorNoRoom,
Msg: "Room does not exist",
}
} else if e, ok := err.(gomatrixserverlib.BadJSONError); ok { } else if e, ok := err.(gomatrixserverlib.BadJSONError); ok {
return nil, &api.PerformError{ return nil, e
Msg: e.Error(),
}
} else if e, ok := err.(gomatrixserverlib.EventValidationError); ok { } else if e, ok := err.(gomatrixserverlib.EventValidationError); ok {
if e.Code == gomatrixserverlib.EventValidationTooLarge { return nil, e
return nil, &api.PerformError{
Msg: e.Error(),
}
}
return nil, &api.PerformError{
Msg: e.Error(),
}
} else if err != nil { } else if err != nil {
return nil, &api.PerformError{ return nil, fmt.Errorf("failed to build new %q event: %w", builder.Type, err)
Msg: fmt.Sprintf("Failed to build new %q event: %s", builder.Type, err),
}
} }
// check to see if this user can perform this operation // check to see if this user can perform this operation
stateEvents := make([]*gomatrixserverlib.Event, len(queryRes.StateEvents)) stateEvents := make([]*gomatrixserverlib.Event, len(queryRes.StateEvents))
for i := range queryRes.StateEvents { for i := range queryRes.StateEvents {
stateEvents[i] = queryRes.StateEvents[i].Event stateEvents[i] = queryRes.StateEvents[i].Event
} }
provider := gomatrixserverlib.NewAuthEvents(stateEvents) provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents))
if err = gomatrixserverlib.Allowed(headeredEvent.Event, &provider); err != nil { if err = gomatrixserverlib.Allowed(headeredEvent.Event, &provider); err != nil {
return nil, &api.PerformError{ return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", builder.Type, err)} // TODO: Is this error string comprehensible to the client?
Code: api.PerformErrorNotAllowed,
Msg: fmt.Sprintf("Failed to auth new %q event: %s", builder.Type, err), // TODO: Is this error string comprehensible to the client?
}
} }
return headeredEvent, nil return headeredEvent, nil
@ -690,9 +607,9 @@ func createTemporaryPowerLevels(powerLevelContent *gomatrixserverlib.PowerLevelC
func (r *Upgrader) sendHeaderedEvent( func (r *Upgrader) sendHeaderedEvent(
ctx context.Context, ctx context.Context,
serverName spec.ServerName, serverName spec.ServerName,
headeredEvent *gomatrixserverlib.HeaderedEvent, headeredEvent *types.HeaderedEvent,
sendAsServer string, sendAsServer string,
) *api.PerformError { ) error {
var inputs []api.InputRoomEvent var inputs []api.InputRoomEvent
inputs = append(inputs, api.InputRoomEvent{ inputs = append(inputs, api.InputRoomEvent{
Kind: api.KindNew, Kind: api.KindNew,
@ -700,11 +617,5 @@ func (r *Upgrader) sendHeaderedEvent(
Origin: serverName, Origin: serverName,
SendAsServer: sendAsServer, SendAsServer: sendAsServer,
}) })
if err := api.SendInputRoomEvents(ctx, r.URSAPI, serverName, inputs, false); err != nil { return api.SendInputRoomEvents(ctx, r.URSAPI, serverName, inputs, false)
return &api.PerformError{
Msg: fmt.Sprintf("Failed to send new %q event to roomserver: %s", headeredEvent.Type(), err),
}
}
return nil
} }

View file

@ -121,14 +121,17 @@ func (r *Queryer) QueryStateAfterEvents(
return fmt.Errorf("getAuthChain: %w", err) return fmt.Errorf("getAuthChain: %w", err)
} }
stateEvents, err = gomatrixserverlib.ResolveConflicts(info.RoomVersion, stateEvents, authEvents) stateEventsPDU, err := gomatrixserverlib.ResolveConflicts(
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents),
)
if err != nil { if err != nil {
return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err) return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err)
} }
stateEvents = gomatrixserverlib.TempCastToEvents(stateEventsPDU)
} }
for _, event := range stateEvents { for _, event := range stateEvents {
response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion)) response.StateEvents = append(response.StateEvents, &types.HeaderedEvent{Event: event})
} }
return nil return nil
@ -173,7 +176,7 @@ func (r *Queryer) QueryEventsByID(
} }
for _, event := range events { for _, event := range events {
response.Events = append(response.Events, event.Headered(roomInfo.RoomVersion)) response.Events = append(response.Events, &types.HeaderedEvent{Event: event.Event})
} }
return nil return nil
@ -231,7 +234,7 @@ func (r *Queryer) QueryMembershipAtEvent(
request *api.QueryMembershipAtEventRequest, request *api.QueryMembershipAtEventRequest,
response *api.QueryMembershipAtEventResponse, response *api.QueryMembershipAtEventResponse,
) error { ) error {
response.Membership = make(map[string]*gomatrixserverlib.HeaderedEvent) response.Membership = make(map[string]*types.HeaderedEvent)
info, err := r.DB.RoomInfo(ctx, request.RoomID) info, err := r.DB.RoomInfo(ctx, request.RoomID)
if err != nil { if err != nil {
@ -259,7 +262,7 @@ func (r *Queryer) QueryMembershipAtEvent(
return err return err
} }
response.Membership = make(map[string]*gomatrixserverlib.HeaderedEvent) response.Membership = make(map[string]*types.HeaderedEvent)
stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID]) stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID])
if err != nil { if err != nil {
return fmt.Errorf("unable to get state before event: %w", err) return fmt.Errorf("unable to get state before event: %w", err)
@ -307,7 +310,7 @@ func (r *Queryer) QueryMembershipAtEvent(
for i := range memberships { for i := range memberships {
ev := memberships[i] ev := memberships[i]
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(request.UserID) { if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(request.UserID) {
response.Membership[eventID] = ev.Event.Headered(info.RoomVersion) response.Membership[eventID] = &types.HeaderedEvent{Event: ev.Event}
} }
} }
} }
@ -518,17 +521,13 @@ func (r *Queryer) QueryMissingEvents(
return err return err
} }
response.Events = make([]*gomatrixserverlib.HeaderedEvent, 0, len(loadedEvents)-len(eventsToFilter)) response.Events = make([]*types.HeaderedEvent, 0, len(loadedEvents)-len(eventsToFilter))
for _, event := range loadedEvents { for _, event := range loadedEvents {
if !eventsToFilter[event.EventID()] { if !eventsToFilter[event.EventID()] {
roomVersion, verr := r.roomVersion(event.RoomID())
if verr != nil {
return verr
}
if _, ok := redactEventIDs[event.EventID()]; ok { if _, ok := redactEventIDs[event.EventID()]; ok {
event.Redact() event.Redact()
} }
response.Events = append(response.Events, event.Headered(roomVersion)) response.Events = append(response.Events, &types.HeaderedEvent{Event: event})
} }
} }
@ -561,7 +560,7 @@ func (r *Queryer) QueryStateAndAuthChain(
return err return err
} }
for _, event := range authEvents { for _, event := range authEvents {
response.AuthChainEvents = append(response.AuthChainEvents, event.Headered(info.RoomVersion)) response.AuthChainEvents = append(response.AuthChainEvents, &types.HeaderedEvent{Event: event})
} }
return nil return nil
} }
@ -589,19 +588,21 @@ func (r *Queryer) QueryStateAndAuthChain(
} }
if request.ResolveState { if request.ResolveState {
if stateEvents, err = gomatrixserverlib.ResolveConflicts( stateEventsPDU, err2 := gomatrixserverlib.ResolveConflicts(
info.RoomVersion, stateEvents, authEvents, info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents),
); err != nil { )
return err if err2 != nil {
return err2
} }
stateEvents = gomatrixserverlib.TempCastToEvents(stateEventsPDU)
} }
for _, event := range stateEvents { for _, event := range stateEvents {
response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion)) response.StateEvents = append(response.StateEvents, &types.HeaderedEvent{Event: event})
} }
for _, event := range authEvents { for _, event := range authEvents {
response.AuthChainEvents = append(response.AuthChainEvents, event.Headered(info.RoomVersion)) response.AuthChainEvents = append(response.AuthChainEvents, &types.HeaderedEvent{Event: event})
} }
return err return err
@ -696,34 +697,20 @@ func GetAuthChain(
} }
// QueryRoomVersionForRoom implements api.RoomserverInternalAPI // QueryRoomVersionForRoom implements api.RoomserverInternalAPI
func (r *Queryer) QueryRoomVersionForRoom( func (r *Queryer) QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) {
ctx context.Context, if roomVersion, ok := r.Cache.GetRoomVersion(roomID); ok {
request *api.QueryRoomVersionForRoomRequest, return roomVersion, nil
response *api.QueryRoomVersionForRoomResponse,
) error {
if roomVersion, ok := r.Cache.GetRoomVersion(request.RoomID); ok {
response.RoomVersion = roomVersion
return nil
} }
info, err := r.DB.RoomInfo(ctx, request.RoomID) info, err := r.DB.RoomInfo(ctx, roomID)
if err != nil { if err != nil {
return err return "", err
} }
if info == nil { if info == nil {
return fmt.Errorf("QueryRoomVersionForRoom: missing room info for room %s", request.RoomID) return "", fmt.Errorf("QueryRoomVersionForRoom: missing room info for room %s", roomID)
} }
response.RoomVersion = info.RoomVersion r.Cache.StoreRoomVersion(roomID, info.RoomVersion)
r.Cache.StoreRoomVersion(request.RoomID, response.RoomVersion) return info.RoomVersion, nil
return nil
}
func (r *Queryer) roomVersion(roomID string) (gomatrixserverlib.RoomVersion, error) {
var res api.QueryRoomVersionForRoomResponse
err := r.QueryRoomVersionForRoom(context.Background(), &api.QueryRoomVersionForRoomRequest{
RoomID: roomID,
}, &res)
return res.RoomVersion, err
} }
func (r *Queryer) QueryPublishedRooms( func (r *Queryer) QueryPublishedRooms(
@ -748,7 +735,7 @@ func (r *Queryer) QueryPublishedRooms(
} }
func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error { func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error {
res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent)
for _, tuple := range req.StateTuples { for _, tuple := range req.StateTuples {
if tuple.StateKey == "*" && req.AllowWildcards { if tuple.StateKey == "*" && req.AllowWildcards {
events, err := r.DB.GetStateEventsWithEventType(ctx, req.RoomID, tuple.EventType) events, err := r.DB.GetStateEventsWithEventType(ctx, req.RoomID, tuple.EventType)
@ -865,9 +852,9 @@ func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainReq
if err != nil { if err != nil {
return err return err
} }
hchain := make([]*gomatrixserverlib.HeaderedEvent, len(chain)) hchain := make([]*types.HeaderedEvent, len(chain))
for i := range chain { for i := range chain {
hchain[i] = chain[i].Headered(chain[i].Version()) hchain[i] = &types.HeaderedEvent{Event: chain[i]}
} }
res.AuthChain = hchain res.AuthChain = hchain
return nil return nil
@ -910,8 +897,8 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
if err = json.Unmarshal(joinRulesEvent.Content(), &joinRules); err != nil { if err = json.Unmarshal(joinRulesEvent.Content(), &joinRules); err != nil {
return fmt.Errorf("json.Unmarshal: %w", err) return fmt.Errorf("json.Unmarshal: %w", err)
} }
// If the join rule isn't "restricted" then there's nothing more to do. // If the join rule isn't "restricted" or "knock_restricted" then there's nothing more to do.
res.Restricted = joinRules.JoinRule == spec.Restricted res.Restricted = joinRules.JoinRule == spec.Restricted || joinRules.JoinRule == spec.KnockRestricted
if !res.Restricted { if !res.Restricted {
return nil return nil
} }
@ -932,9 +919,9 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
if err != nil { if err != nil {
return fmt.Errorf("r.DB.GetStateEvent: %w", err) return fmt.Errorf("r.DB.GetStateEvent: %w", err)
} }
var powerLevels gomatrixserverlib.PowerLevelContent powerLevels, err := powerLevelsEvent.PowerLevels()
if err = json.Unmarshal(powerLevelsEvent.Content(), &powerLevels); err != nil { if err != nil {
return fmt.Errorf("json.Unmarshal: %w", err) return fmt.Errorf("unable to get powerlevels: %w", err)
} }
// Step through the join rules and see if the user matches any of them. // Step through the join rules and see if the user matches any of them.
for _, rule := range joinRules.Allow { for _, rule := range joinRules.Allow {

View file

@ -74,7 +74,7 @@ func (r *RoomEventProducer) ProduceRoomEvents(roomID string, updates []api.Outpu
} }
if eventType == "m.room.server_acl" && update.NewRoomEvent.Event.StateKeyEquals("") { if eventType == "m.room.server_acl" && update.NewRoomEvent.Event.StateKeyEquals("") {
ev := update.NewRoomEvent.Event.Unwrap() ev := update.NewRoomEvent.Event.Event
defer r.ACLs.OnServerACLUpdate(ev) defer r.ACLs.OnServerACLUpdate(ev)
} }
} }

View file

@ -8,10 +8,13 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/version"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
@ -136,7 +139,7 @@ func testKickUsers(t *testing.T, rsAPI api.RoomserverInternalAPI, usrAPI userAPI
// revoke guest access // revoke guest access
revokeEvent := room.CreateAndInsert(t, alice, spec.MRoomGuestAccess, map[string]string{"guest_access": "forbidden"}, test.WithStateKey("")) revokeEvent := room.CreateAndInsert(t, alice, spec.MRoomGuestAccess, map[string]string{"guest_access": "forbidden"}, test.WithStateKey(""))
if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*gomatrixserverlib.HeaderedEvent{revokeEvent}, "test", "test", "test", nil, false); err != nil { if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*types.HeaderedEvent{revokeEvent}, "test", "test", "test", nil, false); err != nil {
t.Errorf("failed to send events: %v", err) t.Errorf("failed to send events: %v", err)
} }
@ -242,8 +245,8 @@ func TestPurgeRoom(t *testing.T) {
// this starts the JetStream consumers // this starts the JetStream consumers
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, caching.DisableMetrics) syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, caching.DisableMetrics)
federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, rsAPI, caches, nil, true) fsAPI := federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, rsAPI, caches, nil, true)
rsAPI.SetFederationAPI(nil, nil) rsAPI.SetFederationAPI(fsAPI, nil)
// Create the room // Create the room
if err = api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { if err = api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
@ -251,13 +254,9 @@ func TestPurgeRoom(t *testing.T) {
} }
// some dummy entries to validate after purging // some dummy entries to validate after purging
publishResp := &api.PerformPublishResponse{} if err = rsAPI.PerformPublish(ctx, &api.PerformPublishRequest{RoomID: room.ID, Visibility: spec.Public}); err != nil {
if err = rsAPI.PerformPublish(ctx, &api.PerformPublishRequest{RoomID: room.ID, Visibility: "public"}, publishResp); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if publishResp.Error != nil {
t.Fatal(publishResp.Error)
}
isPublished, err := db.GetPublishedRoom(ctx, room.ID) isPublished, err := db.GetPublishedRoom(ctx, room.ID)
if err != nil { if err != nil {
@ -325,8 +324,7 @@ func TestPurgeRoom(t *testing.T) {
} }
// purge the room from the database // purge the room from the database
purgeResp := &api.PerformAdminPurgeRoomResponse{} if err = rsAPI.PerformAdminPurgeRoom(ctx, room.ID); err != nil {
if err = rsAPI.PerformAdminPurgeRoom(ctx, &api.PerformAdminPurgeRoomRequest{RoomID: room.ID}, purgeResp); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -401,7 +399,7 @@ type fledglingEvent struct {
PrevEvents []interface{} PrevEvents []interface{}
} }
func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib.HeaderedEvent) { func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *types.HeaderedEvent) {
t.Helper() t.Helper()
roomVer := gomatrixserverlib.RoomVersionV9 roomVer := gomatrixserverlib.RoomVersionV9
seed := make([]byte, ed25519.SeedSize) // zero seed seed := make([]byte, ed25519.SeedSize) // zero seed
@ -423,7 +421,7 @@ func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib
if err != nil { if err != nil {
t.Fatalf("mustCreateEvent: failed to sign event: %s", err) t.Fatalf("mustCreateEvent: failed to sign event: %s", err)
} }
h := signedEvent.Headered(roomVer) h := &types.HeaderedEvent{Event: signedEvent}
return h return h
} }
@ -451,7 +449,7 @@ func TestRedaction(t *testing.T) {
Depth: redactedEvent.Depth() + 1, Depth: redactedEvent.Depth() + 1,
PrevEvents: []interface{}{redactedEvent.EventID()}, PrevEvents: []interface{}{redactedEvent.EventID()},
}) })
room.InsertEvent(t, builderEv.Headered(gomatrixserverlib.RoomVersionV9)) room.InsertEvent(t, builderEv)
}, },
}, },
{ {
@ -468,7 +466,7 @@ func TestRedaction(t *testing.T) {
Depth: redactedEvent.Depth() + 1, Depth: redactedEvent.Depth() + 1,
PrevEvents: []interface{}{redactedEvent.EventID()}, PrevEvents: []interface{}{redactedEvent.EventID()},
}) })
room.InsertEvent(t, builderEv.Headered(gomatrixserverlib.RoomVersionV9)) room.InsertEvent(t, builderEv)
}, },
}, },
{ {
@ -485,7 +483,7 @@ func TestRedaction(t *testing.T) {
Depth: redactedEvent.Depth() + 1, Depth: redactedEvent.Depth() + 1,
PrevEvents: []interface{}{redactedEvent.EventID()}, PrevEvents: []interface{}{redactedEvent.EventID()},
}) })
room.InsertEvent(t, builderEv.Headered(gomatrixserverlib.RoomVersionV9)) room.InsertEvent(t, builderEv)
}, },
}, },
{ {
@ -501,7 +499,7 @@ func TestRedaction(t *testing.T) {
Depth: redactedEvent.Depth() + 1, Depth: redactedEvent.Depth() + 1,
PrevEvents: []interface{}{redactedEvent.EventID()}, PrevEvents: []interface{}{redactedEvent.EventID()},
}) })
room.InsertEvent(t, builderEv.Headered(gomatrixserverlib.RoomVersionV9)) room.InsertEvent(t, builderEv)
}, },
}, },
} }
@ -582,3 +580,506 @@ func TestRedaction(t *testing.T) {
} }
}) })
} }
func TestQueryRestrictedJoinAllowed(t *testing.T) {
alice := test.NewUser(t)
bob := test.NewUser(t)
// a room we don't create in the database
allowedByRoomNotExists := test.NewRoom(t, alice)
// a room we create in the database, used for authorisation
allowedByRoomExists := test.NewRoom(t, alice)
allowedByRoomExists.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{
"membership": spec.Join,
}, test.WithStateKey(bob.ID))
testCases := []struct {
name string
prepareRoomFunc func(t *testing.T) *test.Room
wantResponse api.QueryRestrictedJoinAllowedResponse
}{
{
name: "public room unrestricted",
prepareRoomFunc: func(t *testing.T) *test.Room {
return test.NewRoom(t, alice)
},
wantResponse: api.QueryRestrictedJoinAllowedResponse{
Resident: true,
},
},
{
name: "room version without restrictions",
prepareRoomFunc: func(t *testing.T) *test.Room {
return test.NewRoom(t, alice, test.RoomVersion(gomatrixserverlib.RoomVersionV7))
},
},
{
name: "restricted only", // bob is not allowed to join
prepareRoomFunc: func(t *testing.T) *test.Room {
r := test.NewRoom(t, alice, test.RoomVersion(gomatrixserverlib.RoomVersionV8))
r.CreateAndInsert(t, alice, spec.MRoomJoinRules, map[string]interface{}{
"join_rule": spec.Restricted,
}, test.WithStateKey(""))
return r
},
wantResponse: api.QueryRestrictedJoinAllowedResponse{
Resident: true,
Restricted: true,
},
},
{
name: "knock_restricted",
prepareRoomFunc: func(t *testing.T) *test.Room {
r := test.NewRoom(t, alice, test.RoomVersion(gomatrixserverlib.RoomVersionV8))
r.CreateAndInsert(t, alice, spec.MRoomJoinRules, map[string]interface{}{
"join_rule": spec.KnockRestricted,
}, test.WithStateKey(""))
return r
},
wantResponse: api.QueryRestrictedJoinAllowedResponse{
Resident: true,
Restricted: true,
},
},
{
name: "restricted with pending invite", // bob should be allowed to join
prepareRoomFunc: func(t *testing.T) *test.Room {
r := test.NewRoom(t, alice, test.RoomVersion(gomatrixserverlib.RoomVersionV8))
r.CreateAndInsert(t, alice, spec.MRoomJoinRules, map[string]interface{}{
"join_rule": spec.Restricted,
}, test.WithStateKey(""))
r.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{
"membership": spec.Invite,
}, test.WithStateKey(bob.ID))
return r
},
wantResponse: api.QueryRestrictedJoinAllowedResponse{
Resident: true,
Restricted: true,
Allowed: true,
},
},
{
name: "restricted with allowed room_id, but missing room", // bob should not be allowed to join, as we don't know about the room
prepareRoomFunc: func(t *testing.T) *test.Room {
r := test.NewRoom(t, alice, test.RoomVersion(gomatrixserverlib.RoomVersionV10))
r.CreateAndInsert(t, alice, spec.MRoomJoinRules, map[string]interface{}{
"join_rule": spec.KnockRestricted,
"allow": []map[string]interface{}{
{
"room_id": allowedByRoomNotExists.ID,
"type": spec.MRoomMembership,
},
},
}, test.WithStateKey(""))
r.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{
"membership": spec.Join,
"join_authorised_via_users_server": alice.ID,
}, test.WithStateKey(bob.ID))
return r
},
wantResponse: api.QueryRestrictedJoinAllowedResponse{
Restricted: true,
},
},
{
name: "restricted with allowed room_id", // bob should be allowed to join, as we know about the room
prepareRoomFunc: func(t *testing.T) *test.Room {
r := test.NewRoom(t, alice, test.RoomVersion(gomatrixserverlib.RoomVersionV10))
r.CreateAndInsert(t, alice, spec.MRoomJoinRules, map[string]interface{}{
"join_rule": spec.KnockRestricted,
"allow": []map[string]interface{}{
{
"room_id": allowedByRoomExists.ID,
"type": spec.MRoomMembership,
},
},
}, test.WithStateKey(""))
r.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{
"membership": spec.Join,
"join_authorised_via_users_server": alice.ID,
}, test.WithStateKey(bob.ID))
return r
},
wantResponse: api.QueryRestrictedJoinAllowedResponse{
Resident: true,
Restricted: true,
Allowed: true,
AuthorisedVia: alice.ID,
},
},
}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
natsInstance := jetstream.NATSInstance{}
defer close()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.prepareRoomFunc == nil {
t.Fatal("missing prepareRoomFunc")
}
testRoom := tc.prepareRoomFunc(t)
// Create the room
if err := api.SendEvents(processCtx.Context(), rsAPI, api.KindNew, testRoom.Events(), "test", "test", "test", nil, false); err != nil {
t.Errorf("failed to send events: %v", err)
}
if err := api.SendEvents(processCtx.Context(), rsAPI, api.KindNew, allowedByRoomExists.Events(), "test", "test", "test", nil, false); err != nil {
t.Errorf("failed to send events: %v", err)
}
req := api.QueryRestrictedJoinAllowedRequest{
UserID: bob.ID,
RoomID: testRoom.ID,
}
res := api.QueryRestrictedJoinAllowedResponse{}
if err := rsAPI.QueryRestrictedJoinAllowed(processCtx.Context(), &req, &res); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(tc.wantResponse, res) {
t.Fatalf("unexpected response, want %#v - got %#v", tc.wantResponse, res)
}
})
}
})
}
func TestUpgrade(t *testing.T) {
alice := test.NewUser(t)
bob := test.NewUser(t)
charlie := test.NewUser(t)
ctx := context.Background()
spaceChild := test.NewRoom(t, alice)
validateTuples := []gomatrixserverlib.StateKeyTuple{
{EventType: spec.MRoomCreate},
{EventType: spec.MRoomPowerLevels},
{EventType: spec.MRoomJoinRules},
{EventType: spec.MRoomName},
{EventType: spec.MRoomCanonicalAlias},
{EventType: "m.room.tombstone"},
{EventType: "m.custom.event"},
{EventType: "m.space.child", StateKey: spaceChild.ID},
{EventType: "m.custom.event", StateKey: alice.ID},
{EventType: spec.MRoomMember, StateKey: charlie.ID}, // ban should be transferred
}
validate := func(t *testing.T, oldRoomID, newRoomID string, rsAPI api.RoomserverInternalAPI) {
oldRoomState := &api.QueryCurrentStateResponse{}
if err := rsAPI.QueryCurrentState(ctx, &api.QueryCurrentStateRequest{
RoomID: oldRoomID,
StateTuples: validateTuples,
}, oldRoomState); err != nil {
t.Fatal(err)
}
newRoomState := &api.QueryCurrentStateResponse{}
if err := rsAPI.QueryCurrentState(ctx, &api.QueryCurrentStateRequest{
RoomID: newRoomID,
StateTuples: validateTuples,
}, newRoomState); err != nil {
t.Fatal(err)
}
// the old room should have a tombstone event
ev := oldRoomState.StateEvents[gomatrixserverlib.StateKeyTuple{EventType: "m.room.tombstone"}]
replacementRoom := gjson.GetBytes(ev.Content(), "replacement_room").Str
if replacementRoom != newRoomID {
t.Fatalf("tombstone event has replacement_room '%s', expected '%s'", replacementRoom, newRoomID)
}
// the new room should have a predecessor equal to the old room
ev = newRoomState.StateEvents[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomCreate}]
predecessor := gjson.GetBytes(ev.Content(), "predecessor.room_id").Str
if predecessor != oldRoomID {
t.Fatalf("got predecessor room '%s', expected '%s'", predecessor, oldRoomID)
}
for _, tuple := range validateTuples {
// Skip create and powerlevel event (new room has e.g. predecessor event, old room has restricted powerlevels)
switch tuple.EventType {
case spec.MRoomCreate, spec.MRoomPowerLevels, spec.MRoomCanonicalAlias:
continue
}
oldEv, ok := oldRoomState.StateEvents[tuple]
if !ok {
t.Logf("skipping tuple %#v as it doesn't exist in the old room", tuple)
continue
}
newEv, ok := newRoomState.StateEvents[tuple]
if !ok {
t.Logf("skipping tuple %#v as it doesn't exist in the new room", tuple)
continue
}
if !reflect.DeepEqual(oldEv.Content(), newEv.Content()) {
t.Logf("OldEvent QueryCurrentState: %s", string(oldEv.Content()))
t.Logf("NewEvent QueryCurrentState: %s", string(newEv.Content()))
t.Errorf("event content mismatch")
}
}
}
testCases := []struct {
name string
upgradeUser string
roomFunc func(rsAPI api.RoomserverInternalAPI) string
validateFunc func(t *testing.T, oldRoomID, newRoomID string, rsAPI api.RoomserverInternalAPI)
wantNewRoom bool
}{
{
name: "invalid userID",
upgradeUser: "!notvalid:test",
roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
room := test.NewRoom(t, alice)
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
t.Errorf("failed to send events: %v", err)
}
return room.ID
},
},
{
name: "invalid roomID",
upgradeUser: alice.ID,
roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
return "!doesnotexist:test"
},
},
{
name: "powerlevel too low",
upgradeUser: bob.ID,
roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
room := test.NewRoom(t, alice)
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
t.Errorf("failed to send events: %v", err)
}
return room.ID
},
},
{
name: "successful upgrade on new room",
upgradeUser: alice.ID,
roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
room := test.NewRoom(t, alice)
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
t.Errorf("failed to send events: %v", err)
}
return room.ID
},
wantNewRoom: true,
validateFunc: validate,
},
{
name: "successful upgrade on new room with other state events",
upgradeUser: alice.ID,
roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
r := test.NewRoom(t, alice)
r.CreateAndInsert(t, alice, spec.MRoomName, map[string]interface{}{
"name": "my new name",
}, test.WithStateKey(""))
r.CreateAndInsert(t, alice, spec.MRoomCanonicalAlias, eventutil.CanonicalAliasContent{
Alias: "#myalias:test",
}, test.WithStateKey(""))
// this will be transferred
r.CreateAndInsert(t, alice, "m.custom.event", map[string]interface{}{
"random": "i should exist",
}, test.WithStateKey(""))
// the following will be ignored
r.CreateAndInsert(t, alice, "m.custom.event", map[string]interface{}{
"random": "i will be ignored",
}, test.WithStateKey(alice.ID))
if err := api.SendEvents(ctx, rsAPI, api.KindNew, r.Events(), "test", "test", "test", nil, false); err != nil {
t.Errorf("failed to send events: %v", err)
}
return r.ID
},
wantNewRoom: true,
validateFunc: validate,
},
{
name: "with published room",
upgradeUser: alice.ID,
roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
r := test.NewRoom(t, alice)
if err := api.SendEvents(ctx, rsAPI, api.KindNew, r.Events(), "test", "test", "test", nil, false); err != nil {
t.Errorf("failed to send events: %v", err)
}
if err := rsAPI.PerformPublish(ctx, &api.PerformPublishRequest{
RoomID: r.ID,
Visibility: spec.Public,
}); err != nil {
t.Fatal(err)
}
return r.ID
},
wantNewRoom: true,
validateFunc: func(t *testing.T, oldRoomID, newRoomID string, rsAPI api.RoomserverInternalAPI) {
validate(t, oldRoomID, newRoomID, rsAPI)
// check that the new room is published
res := &api.QueryPublishedRoomsResponse{}
if err := rsAPI.QueryPublishedRooms(ctx, &api.QueryPublishedRoomsRequest{RoomID: newRoomID}, res); err != nil {
t.Fatal(err)
}
if len(res.RoomIDs) == 0 {
t.Fatalf("expected room to be published, but wasn't: %#v", res.RoomIDs)
}
},
},
{
name: "with alias",
upgradeUser: alice.ID,
roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
r := test.NewRoom(t, alice)
if err := api.SendEvents(ctx, rsAPI, api.KindNew, r.Events(), "test", "test", "test", nil, false); err != nil {
t.Errorf("failed to send events: %v", err)
}
if err := rsAPI.SetRoomAlias(ctx, &api.SetRoomAliasRequest{
RoomID: r.ID,
Alias: "#myroomalias:test",
}, &api.SetRoomAliasResponse{}); err != nil {
t.Fatal(err)
}
return r.ID
},
wantNewRoom: true,
validateFunc: func(t *testing.T, oldRoomID, newRoomID string, rsAPI api.RoomserverInternalAPI) {
validate(t, oldRoomID, newRoomID, rsAPI)
// check that the old room has no aliases
res := &api.GetAliasesForRoomIDResponse{}
if err := rsAPI.GetAliasesForRoomID(ctx, &api.GetAliasesForRoomIDRequest{RoomID: oldRoomID}, res); err != nil {
t.Fatal(err)
}
if len(res.Aliases) != 0 {
t.Fatalf("expected old room aliases to be empty, but wasn't: %#v", res.Aliases)
}
// check that the new room has aliases
if err := rsAPI.GetAliasesForRoomID(ctx, &api.GetAliasesForRoomIDRequest{RoomID: newRoomID}, res); err != nil {
t.Fatal(err)
}
if len(res.Aliases) == 0 {
t.Fatalf("expected room aliases to be transferred, but wasn't: %#v", res.Aliases)
}
},
},
{
name: "bans are transferred",
upgradeUser: alice.ID,
roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
r := test.NewRoom(t, alice)
r.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{
"membership": spec.Ban,
}, test.WithStateKey(charlie.ID))
if err := api.SendEvents(ctx, rsAPI, api.KindNew, r.Events(), "test", "test", "test", nil, false); err != nil {
t.Errorf("failed to send events: %v", err)
}
return r.ID
},
wantNewRoom: true,
validateFunc: validate,
},
{
name: "space childs are transferred",
upgradeUser: alice.ID,
roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
r := test.NewRoom(t, alice)
r.CreateAndInsert(t, alice, "m.space.child", map[string]interface{}{}, test.WithStateKey(spaceChild.ID))
if err := api.SendEvents(ctx, rsAPI, api.KindNew, r.Events(), "test", "test", "test", nil, false); err != nil {
t.Errorf("failed to send events: %v", err)
}
return r.ID
},
wantNewRoom: true,
validateFunc: validate,
},
{
name: "custom state is not taken to the new room", // https://github.com/matrix-org/dendrite/issues/2912
upgradeUser: charlie.ID,
roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
r := test.NewRoom(t, alice, test.RoomVersion(gomatrixserverlib.RoomVersionV6))
// Bob and Charlie join
r.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{"membership": spec.Join}, test.WithStateKey(bob.ID))
r.CreateAndInsert(t, charlie, spec.MRoomMember, map[string]interface{}{"membership": spec.Join}, test.WithStateKey(charlie.ID))
// make Charlie an admin so the room can be upgraded
r.CreateAndInsert(t, alice, spec.MRoomPowerLevels, gomatrixserverlib.PowerLevelContent{
Users: map[string]int64{
charlie.ID: 100,
},
}, test.WithStateKey(""))
// Alice creates a custom event
r.CreateAndInsert(t, alice, "m.custom.event", map[string]interface{}{
"random": "data",
}, test.WithStateKey(alice.ID))
r.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{"membership": spec.Leave}, test.WithStateKey(alice.ID))
if err := api.SendEvents(ctx, rsAPI, api.KindNew, r.Events(), "test", "test", "test", nil, false); err != nil {
t.Errorf("failed to send events: %v", err)
}
return r.ID
},
wantNewRoom: true,
validateFunc: validate,
},
}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
natsInstance := jetstream.NATSInstance{}
defer close()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
rsAPI.SetFederationAPI(nil, nil)
rsAPI.SetUserAPI(userAPI)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.roomFunc == nil {
t.Fatalf("missing roomFunc")
}
if tc.upgradeUser == "" {
tc.upgradeUser = alice.ID
}
roomID := tc.roomFunc(rsAPI)
newRoomID, err := rsAPI.PerformRoomUpgrade(processCtx.Context(), roomID, tc.upgradeUser, version.DefaultRoomVersion())
if err != nil && tc.wantNewRoom {
t.Fatal(err)
}
if tc.wantNewRoom && newRoomID == "" {
t.Fatalf("expected a new room, but the upgrade failed")
}
if !tc.wantNewRoom && newRoomID != "" {
t.Fatalf("expected no new room, but the upgrade succeeded")
}
if tc.validateFunc != nil {
tc.validateFunc(t, roomID, newRoomID, rsAPI)
}
})
}
})
}

View file

@ -996,7 +996,7 @@ func (v *StateResolution) resolveConflictsV2(
// For each conflicted event, we will add a new set of auth events. Auth // For each conflicted event, we will add a new set of auth events. Auth
// events may be duplicated across these sets but that's OK. // events may be duplicated across these sets but that's OK.
authSets := make(map[string][]*gomatrixserverlib.Event, len(conflicted)) authSets := make(map[string][]*gomatrixserverlib.Event, len(conflicted))
authEvents := make([]*gomatrixserverlib.Event, 0, estimate*3) authEvents := make([]gomatrixserverlib.PDU, 0, estimate*3)
gotAuthEvents := make(map[string]struct{}, estimate*3) gotAuthEvents := make(map[string]struct{}, estimate*3)
knownAuthEvents := make(map[string]types.Event, estimate*3) knownAuthEvents := make(map[string]types.Event, estimate*3)
@ -1046,7 +1046,7 @@ func (v *StateResolution) resolveConflictsV2(
gotAuthEvents = nil // nolint:ineffassign gotAuthEvents = nil // nolint:ineffassign
// Resolve the conflicts. // Resolve the conflicts.
resolvedEvents := func() []*gomatrixserverlib.Event { resolvedEvents := func() []gomatrixserverlib.PDU {
resolvedTrace, _ := internal.StartRegion(ctx, "StateResolution.ResolveStateConflictsV2") resolvedTrace, _ := internal.StartRegion(ctx, "StateResolution.ResolveStateConflictsV2")
defer resolvedTrace.EndRegion() defer resolvedTrace.EndRegion()
@ -1119,11 +1119,11 @@ func (v *StateResolution) stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.E
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
func (v *StateResolution) loadStateEvents( func (v *StateResolution) loadStateEvents(
ctx context.Context, entries []types.StateEntry, ctx context.Context, entries []types.StateEntry,
) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { ) ([]gomatrixserverlib.PDU, map[string]types.StateEntry, error) {
trace, ctx := internal.StartRegion(ctx, "StateResolution.loadStateEvents") trace, ctx := internal.StartRegion(ctx, "StateResolution.loadStateEvents")
defer trace.EndRegion() defer trace.EndRegion()
result := make([]*gomatrixserverlib.Event, 0, len(entries)) result := make([]gomatrixserverlib.PDU, 0, len(entries))
eventEntries := make([]types.StateEntry, 0, len(entries)) eventEntries := make([]types.StateEntry, 0, len(entries))
eventNIDs := make(types.EventNIDs, 0, len(entries)) eventNIDs := make(types.EventNIDs, 0, len(entries))
for _, entry := range entries { for _, entry := range entries {
@ -1163,7 +1163,7 @@ type authEventLoader struct {
// loadAuthEvents loads all of the auth events for a given event recursively, // loadAuthEvents loads all of the auth events for a given event recursively,
// along with a map that contains state entries for all of the auth events. // along with a map that contains state entries for all of the auth events.
func (l *authEventLoader) loadAuthEvents( func (l *authEventLoader) loadAuthEvents(
ctx context.Context, roomInfo *types.RoomInfo, event *gomatrixserverlib.Event, eventMap map[string]types.Event, ctx context.Context, roomInfo *types.RoomInfo, event gomatrixserverlib.PDU, eventMap map[string]types.Event,
) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { ) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) {
l.Lock() l.Lock()
defer l.Unlock() defer l.Unlock()

View file

@ -77,7 +77,7 @@ type Database interface {
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error)
// Stores a matrix room event in the database. Returns the room NID, the state snapshot or an error. // Stores a matrix room event in the database. Returns the room NID, the state snapshot or an error.
StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error) StoreEvent(ctx context.Context, event gomatrixserverlib.PDU, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error)
// Look up the state entries for a list of string event IDs // Look up the state entries for a list of string event IDs
// Returns an error if the there is an error talking to the database // Returns an error if the there is an error talking to the database
// Returns a types.MissingEventError if the event IDs aren't in the database. // Returns a types.MissingEventError if the event IDs aren't in the database.
@ -139,7 +139,7 @@ type Database interface {
// not found. // not found.
// Returns an error if the retrieval went wrong. // Returns an error if the retrieval went wrong.
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
// Publish or unpublish a room from the room directory. // PerformPublish publishes or unpublishes a room from the room directory. Returns a database error, if any.
PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error
// Returns a list of room IDs for rooms which are published. // Returns a list of room IDs for rooms which are published.
GetPublishedRooms(ctx context.Context, networkID string, includeAllNetworks bool) ([]string, error) GetPublishedRooms(ctx context.Context, networkID string, includeAllNetworks bool) ([]string, error)
@ -151,8 +151,8 @@ type Database interface {
// GetStateEvent returns the state event of a given type for a given room with a given state key // GetStateEvent returns the state event of a given type for a given room with a given state key
// If no event could be found, returns nil // If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error // If there was an issue during the retrieval, returns an error
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error)
GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*types.HeaderedEvent, error)
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error)
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
@ -181,8 +181,8 @@ type Database interface {
// a membership of "leave" when calculating history visibility. // a membership of "leave" when calculating history visibility.
GetMembershipForHistoryVisibility( GetMembershipForHistoryVisibility(
ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string, ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string,
) (map[string]*gomatrixserverlib.HeaderedEvent, error) ) (map[string]*types.HeaderedEvent, error)
GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (*types.RoomInfo, error) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserverlib.PDU) (*types.RoomInfo, error)
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
MaybeRedactEvent( MaybeRedactEvent(
@ -207,10 +207,10 @@ type RoomDatabase interface {
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error)
GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (*types.RoomInfo, error) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserverlib.PDU) (*types.RoomInfo, error)
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error)
} }
type EventDatabase interface { type EventDatabase interface {
@ -230,5 +230,5 @@ type EventDatabase interface {
MaybeRedactEvent( MaybeRedactEvent(
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, plResolver state.PowerLevelResolver, ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, plResolver state.PowerLevelResolver,
) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error) ) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error)
StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error) StoreEvent(ctx context.Context, event gomatrixserverlib.PDU, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error)
} }

View file

@ -205,19 +205,19 @@ func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility(
func (s *stateSnapshotStatements) BulkSelectMembershipForHistoryVisibility( func (s *stateSnapshotStatements) BulkSelectMembershipForHistoryVisibility(
ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string, ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string,
) (map[string]*gomatrixserverlib.HeaderedEvent, error) { ) (map[string]*types.HeaderedEvent, error) {
stmt := sqlutil.TxStmt(txn, s.bulktSelectMembershipForHistoryVisibilityStmt) stmt := sqlutil.TxStmt(txn, s.bulktSelectMembershipForHistoryVisibilityStmt)
rows, err := stmt.QueryContext(ctx, userNID, pq.Array(eventIDs), roomInfo.RoomNID) rows, err := stmt.QueryContext(ctx, userNID, pq.Array(eventIDs), roomInfo.RoomNID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() // nolint: errcheck defer rows.Close() // nolint: errcheck
result := make(map[string]*gomatrixserverlib.HeaderedEvent, len(eventIDs)) result := make(map[string]*types.HeaderedEvent, len(eventIDs))
var evJson []byte var evJson []byte
var eventID string var eventID string
var membershipEventID string var membershipEventID string
knownEvents := make(map[string]*gomatrixserverlib.HeaderedEvent, len(eventIDs)) knownEvents := make(map[string]*types.HeaderedEvent, len(eventIDs))
verImpl, err := gomatrixserverlib.GetRoomVersion(roomInfo.RoomVersion) verImpl, err := gomatrixserverlib.GetRoomVersion(roomInfo.RoomVersion)
if err != nil { if err != nil {
return nil, err return nil, err
@ -228,7 +228,7 @@ func (s *stateSnapshotStatements) BulkSelectMembershipForHistoryVisibility(
return nil, err return nil, err
} }
if len(evJson) == 0 { if len(evJson) == 0 {
result[eventID] = &gomatrixserverlib.HeaderedEvent{} result[eventID] = &types.HeaderedEvent{}
continue continue
} }
// If we already know this event, don't try to marshal the json again // If we already know this event, don't try to marshal the json again
@ -238,11 +238,11 @@ func (s *stateSnapshotStatements) BulkSelectMembershipForHistoryVisibility(
} }
event, err := verImpl.NewEventFromTrustedJSON(evJson, false) event, err := verImpl.NewEventFromTrustedJSON(evJson, false)
if err != nil { if err != nil {
result[eventID] = &gomatrixserverlib.HeaderedEvent{} result[eventID] = &types.HeaderedEvent{}
// not fatal // not fatal
continue continue
} }
he := event.Headered(roomInfo.RoomVersion) he := &types.HeaderedEvent{Event: event}
result[eventID] = he result[eventID] = he
knownEvents[membershipEventID] = he knownEvents[membershipEventID] = he
} }

View file

@ -63,7 +63,7 @@ func (d *Database) SupportsConcurrentRoomInputs() bool {
func (d *Database) GetMembershipForHistoryVisibility( func (d *Database) GetMembershipForHistoryVisibility(
ctx context.Context, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string, ctx context.Context, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string,
) (map[string]*gomatrixserverlib.HeaderedEvent, error) { ) (map[string]*types.HeaderedEvent, error) {
return d.StateSnapshotTable.BulkSelectMembershipForHistoryVisibility(ctx, nil, userNID, roomInfo, eventIDs...) return d.StateSnapshotTable.BulkSelectMembershipForHistoryVisibility(ctx, nil, userNID, roomInfo, eventIDs...)
} }
@ -658,7 +658,7 @@ func (d *Database) IsEventRejected(ctx context.Context, roomNID types.RoomNID, e
} }
// GetOrCreateRoomInfo gets or creates a new RoomInfo, which is only safe to use with functions only needing a roomVersion or roomNID. // GetOrCreateRoomInfo gets or creates a new RoomInfo, which is only safe to use with functions only needing a roomVersion or roomNID.
func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (roomInfo *types.RoomInfo, err error) { func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserverlib.PDU) (roomInfo *types.RoomInfo, err error) {
// Get the default room version. If the client doesn't supply a room_version // Get the default room version. If the client doesn't supply a room_version
// then we will use our configured default to create the room. // then we will use our configured default to create the room.
// https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom // https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom
@ -669,13 +669,17 @@ func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserve
if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil { if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil {
return nil, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err) return nil, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err)
} }
if roomVersion == "" {
rv, ok := d.Cache.GetRoomVersion(event.RoomID()) roomNID, nidOK := d.Cache.GetRoomServerRoomNID(event.RoomID())
if ok { cachedRoomVersion, versionOK := d.Cache.GetRoomVersion(event.RoomID())
roomVersion = rv // if we found both, the roomNID and version in our cache, no need to query the database
} if nidOK && versionOK {
return &types.RoomInfo{
RoomNID: roomNID,
RoomVersion: cachedRoomVersion,
}, nil
} }
var roomNID types.RoomNID
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion) roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion)
if err != nil { if err != nil {
@ -721,7 +725,7 @@ func (d *Database) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKe
} }
func (d *EventDatabase) StoreEvent( func (d *EventDatabase) StoreEvent(
ctx context.Context, event *gomatrixserverlib.Event, ctx context.Context, event gomatrixserverlib.PDU,
roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID,
authEventNIDs []types.EventNID, isRejected bool, authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.StateAtEvent, error) { ) (types.EventNID, types.StateAtEvent, error) {
@ -905,7 +909,7 @@ func (d *EventDatabase) assignStateKeyNID(
return eventStateKeyNID, err return eventStateKeyNID, err
} }
func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) ( func extractRoomVersionFromCreateEvent(event gomatrixserverlib.PDU) (
gomatrixserverlib.RoomVersion, error, gomatrixserverlib.RoomVersion, error,
) { ) {
var err error var err error
@ -1152,7 +1156,7 @@ func (d *Database) GetHistoryVisibilityState(ctx context.Context, roomInfo *type
// GetStateEvent returns the current state event of a given type for a given room with a given state key // GetStateEvent returns the current state event of a given type for a given room with a given state key
// If no event could be found, returns nil // If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error // If there was an issue during the retrieval, returns an error
func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) { func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error) {
roomInfo, err := d.roomInfo(ctx, nil, roomID) roomInfo, err := d.roomInfo(ctx, nil, roomID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -1164,7 +1168,7 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
if roomInfo.IsStub() { if roomInfo.IsStub() {
return nil, nil return nil, nil
} }
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType) eventTypeNID, err := d.GetOrCreateEventTypeNID(ctx, evType)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// No rooms have an event of this type, otherwise we'd have an event type NID // No rooms have an event of this type, otherwise we'd have an event type NID
return nil, nil return nil, nil
@ -1172,7 +1176,7 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
if err != nil { if err != nil {
return nil, err return nil, err
} }
stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, stateKey) stateKeyNID, err := d.GetOrCreateEventStateKeyNID(ctx, &stateKey)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// No rooms have a state event with this state key, otherwise we'd have an state key NID // No rooms have a state event with this state key, otherwise we'd have an state key NID
return nil, nil return nil, nil
@ -1201,6 +1205,10 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
// return the event requested // return the event requested
for _, e := range entries { for _, e := range entries {
if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID { if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID {
cachedEvent, ok := d.Cache.GetRoomServerEvent(e.EventNID)
if ok {
return &types.HeaderedEvent{Event: cachedEvent}, nil
}
data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, []types.EventNID{e.EventNID}) data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, []types.EventNID{e.EventNID})
if err != nil { if err != nil {
return nil, err return nil, err
@ -1212,7 +1220,7 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
if err != nil { if err != nil {
return nil, err return nil, err
} }
return ev.Headered(roomInfo.RoomVersion), nil return &types.HeaderedEvent{Event: ev}, nil
} }
} }
@ -1221,7 +1229,7 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
// Same as GetStateEvent but returns all matching state events with this event type. Returns no error // Same as GetStateEvent but returns all matching state events with this event type. Returns no error
// if there are no events with this event type. // if there are no events with this event type.
func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) { func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*types.HeaderedEvent, error) {
roomInfo, err := d.roomInfo(ctx, nil, roomID) roomInfo, err := d.roomInfo(ctx, nil, roomID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -1267,13 +1275,13 @@ func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evTy
if err != nil { if err != nil {
return nil, err return nil, err
} }
var result []*gomatrixserverlib.HeaderedEvent var result []*types.HeaderedEvent
for _, pair := range eventPairs { for _, pair := range eventPairs {
ev, err := verImpl.NewEventFromTrustedJSONWithEventID(eventIDs[pair.EventNID], pair.EventJSON, false) ev, err := verImpl.NewEventFromTrustedJSONWithEventID(eventIDs[pair.EventNID], pair.EventJSON, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
result = append(result, ev.Headered(roomInfo.RoomVersion)) result = append(result, &types.HeaderedEvent{Event: ev})
} }
return result, nil return result, nil
@ -1324,7 +1332,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
} }
// we don't bother failing the request if we get asked for event types we don't know about, as all that would result in is no matches which // we don't bother failing the request if we get asked for event types we don't know about, as all that would result in is no matches which
// isn't a failure. // isn't a failure.
eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, nil, eventTypes) eventTypeNIDMap, err := d.eventTypeNIDs(ctx, nil, eventTypes)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to map event type nids: %w", err) return nil, fmt.Errorf("GetBulkStateContent: failed to map event type nids: %w", err)
} }
@ -1401,7 +1409,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
EventType: ev.Type(), EventType: ev.Type(),
RoomID: ev.RoomID(), RoomID: ev.RoomID(),
StateKey: *ev.StateKey(), StateKey: *ev.StateKey(),
ContentValue: tables.ExtractContentValue(ev.Headered(roomVer)), ContentValue: tables.ExtractContentValue(&types.HeaderedEvent{Event: ev}),
} }
} }

View file

@ -26,7 +26,6 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -153,7 +152,7 @@ func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility(
return nil, tables.OptimisationNotSupportedError return nil, tables.OptimisationNotSupportedError
} }
func (s *stateSnapshotStatements) BulkSelectMembershipForHistoryVisibility(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string) (map[string]*gomatrixserverlib.HeaderedEvent, error) { func (s *stateSnapshotStatements) BulkSelectMembershipForHistoryVisibility(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string) (map[string]*types.HeaderedEvent, error) {
return nil, tables.OptimisationNotSupportedError return nil, tables.OptimisationNotSupportedError
} }

View file

@ -95,7 +95,7 @@ type StateSnapshot interface {
BulkSelectMembershipForHistoryVisibility( BulkSelectMembershipForHistoryVisibility(
ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string, ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string,
) (map[string]*gomatrixserverlib.HeaderedEvent, error) ) (map[string]*types.HeaderedEvent, error)
} }
type StateBlock interface { type StateBlock interface {
@ -196,7 +196,7 @@ type StrippedEvent struct {
// ExtractContentValue from the given state event. For example, given an m.room.name event with: // ExtractContentValue from the given state event. For example, given an m.room.name event with:
// content: { name: "Foo" } // content: { name: "Foo" }
// this returns "Foo". // this returns "Foo".
func ExtractContentValue(ev *gomatrixserverlib.HeaderedEvent) string { func ExtractContentValue(ev *types.HeaderedEvent) string {
content := ev.Content() content := ev.Content()
key := "" key := ""
switch ev.Type() { switch ev.Type() {

View file

@ -0,0 +1,47 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package types
import (
"github.com/matrix-org/gomatrixserverlib"
)
// HeaderedEvent is an Event which serialises to the headered form, which includes
// _room_version and _event_id fields.
type HeaderedEvent struct {
*gomatrixserverlib.Event
Visibility gomatrixserverlib.HistoryVisibility
}
func (h *HeaderedEvent) MarshalJSON() ([]byte, error) {
return h.Event.ToHeaderedJSON()
}
func (j *HeaderedEvent) UnmarshalJSON(data []byte) error {
ev, err := gomatrixserverlib.NewEventFromHeaderedJSON(data, false)
if err != nil {
return err
}
j.Event = ev
return nil
}
func NewEventJSONsFromHeaderedEvents(hes []*HeaderedEvent) gomatrixserverlib.EventJSONs {
result := make(gomatrixserverlib.EventJSONs, len(hes))
for i := range hes {
result[i] = hes[i].JSON()
}
return result
}

View file

@ -33,6 +33,7 @@ import (
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
roomserver "github.com/matrix-org/dendrite/roomserver/api" roomserver "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/synctypes"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
@ -94,7 +95,7 @@ type MSC2836EventRelationshipsResponse struct {
func toClientResponse(res *MSC2836EventRelationshipsResponse) *EventRelationshipResponse { func toClientResponse(res *MSC2836EventRelationshipsResponse) *EventRelationshipResponse {
out := &EventRelationshipResponse{ out := &EventRelationshipResponse{
Events: synctypes.ToClientEvents(res.ParsedEvents, synctypes.FormatAll), Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll),
Limited: res.Limited, Limited: res.Limited,
NextBatch: res.NextBatch, NextBatch: res.NextBatch,
} }
@ -112,7 +113,7 @@ func Enable(
} }
hooks.Enable() hooks.Enable()
hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) { hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) {
he := headeredEvent.(*gomatrixserverlib.HeaderedEvent) he := headeredEvent.(*types.HeaderedEvent)
hookErr := db.StoreRelation(context.Background(), he) hookErr := db.StoreRelation(context.Background(), he)
if hookErr != nil { if hookErr != nil {
util.GetLogger(context.Background()).WithError(hookErr).WithField("event_id", he.EventID()).Error( util.GetLogger(context.Background()).WithError(hookErr).WithField("event_id", he.EventID()).Error(
@ -255,7 +256,7 @@ func federatedEventRelationship(
func (rc *reqCtx) process() (*MSC2836EventRelationshipsResponse, *util.JSONResponse) { func (rc *reqCtx) process() (*MSC2836EventRelationshipsResponse, *util.JSONResponse) {
var res MSC2836EventRelationshipsResponse var res MSC2836EventRelationshipsResponse
var returnEvents []*gomatrixserverlib.HeaderedEvent var returnEvents []*types.HeaderedEvent
// Can the user see (according to history visibility) event_id? If no, reject the request, else continue. // Can the user see (according to history visibility) event_id? If no, reject the request, else continue.
event := rc.getLocalEvent(rc.req.RoomID, rc.req.EventID) event := rc.getLocalEvent(rc.req.RoomID, rc.req.EventID)
if event == nil { if event == nil {
@ -299,7 +300,7 @@ func (rc *reqCtx) process() (*MSC2836EventRelationshipsResponse, *util.JSONRespo
for _, ev := range returnEvents { for _, ev := range returnEvents {
included[ev.EventID()] = true included[ev.EventID()] = true
} }
var events []*gomatrixserverlib.HeaderedEvent var events []*types.HeaderedEvent
events, walkLimited = walkThread( events, walkLimited = walkThread(
rc.ctx, rc.db, rc, included, remaining, rc.ctx, rc.db, rc, included, remaining,
) )
@ -309,7 +310,7 @@ func (rc *reqCtx) process() (*MSC2836EventRelationshipsResponse, *util.JSONRespo
for i, ev := range returnEvents { for i, ev := range returnEvents {
// for each event, extract the children_count | hash and add it as unsigned data. // for each event, extract the children_count | hash and add it as unsigned data.
rc.addChildMetadata(ev) rc.addChildMetadata(ev)
res.ParsedEvents[i] = ev.Unwrap() res.ParsedEvents[i] = ev.Event
} }
res.Limited = remaining == 0 || walkLimited res.Limited = remaining == 0 || walkLimited
return &res, nil return &res, nil
@ -318,17 +319,15 @@ func (rc *reqCtx) process() (*MSC2836EventRelationshipsResponse, *util.JSONRespo
// fetchUnknownEvent retrieves an unknown event from the room specified. This server must // fetchUnknownEvent retrieves an unknown event from the room specified. This server must
// be joined to the room in question. This has the side effect of injecting surround threaded // be joined to the room in question. This has the side effect of injecting surround threaded
// events into the roomserver. // events into the roomserver.
func (rc *reqCtx) fetchUnknownEvent(eventID, roomID string) *gomatrixserverlib.HeaderedEvent { func (rc *reqCtx) fetchUnknownEvent(eventID, roomID string) *types.HeaderedEvent {
if rc.isFederatedRequest || roomID == "" { if rc.isFederatedRequest || roomID == "" {
// we don't do fed hits for fed requests, and we can't ask servers without a room ID! // we don't do fed hits for fed requests, and we can't ask servers without a room ID!
return nil return nil
} }
logger := util.GetLogger(rc.ctx).WithField("room_id", roomID) logger := util.GetLogger(rc.ctx).WithField("room_id", roomID)
// if they supplied a room_id, check the room exists. // if they supplied a room_id, check the room exists.
var queryVerRes roomserver.QueryRoomVersionForRoomResponse
err := rc.rsAPI.QueryRoomVersionForRoom(rc.ctx, &roomserver.QueryRoomVersionForRoomRequest{ roomVersion, err := rc.rsAPI.QueryRoomVersionForRoom(rc.ctx, roomID)
RoomID: roomID,
}, &queryVerRes)
if err != nil { if err != nil {
logger.WithError(err).Warn("failed to query room version for room, does this room exist?") logger.WithError(err).Warn("failed to query room version for room, does this room exist?")
return nil return nil
@ -367,14 +366,14 @@ func (rc *reqCtx) fetchUnknownEvent(eventID, roomID string) *gomatrixserverlib.H
// Inject the response into the roomserver to remember the event across multiple calls and to set // Inject the response into the roomserver to remember the event across multiple calls and to set
// unexplored flags correctly. // unexplored flags correctly.
for _, srv := range serversToQuery { for _, srv := range serversToQuery {
res, err := rc.MSC2836EventRelationships(eventID, srv, queryVerRes.RoomVersion) res, err := rc.MSC2836EventRelationships(eventID, srv, roomVersion)
if err != nil { if err != nil {
continue continue
} }
rc.injectResponseToRoomserver(res) rc.injectResponseToRoomserver(res)
for _, ev := range res.ParsedEvents { for _, ev := range res.ParsedEvents {
if ev.EventID() == eventID { if ev.EventID() == eventID {
return ev.Headered(ev.Version()) return &types.HeaderedEvent{Event: ev}
} }
} }
} }
@ -384,7 +383,7 @@ func (rc *reqCtx) fetchUnknownEvent(eventID, roomID string) *gomatrixserverlib.H
// If include_parent: true and there is a valid m.relationship field in the event, // If include_parent: true and there is a valid m.relationship field in the event,
// retrieve the referenced event. Apply history visibility check to that event and if it passes, add it to the response array. // retrieve the referenced event. Apply history visibility check to that event and if it passes, add it to the response array.
func (rc *reqCtx) includeParent(childEvent *gomatrixserverlib.HeaderedEvent) (parent *gomatrixserverlib.HeaderedEvent) { func (rc *reqCtx) includeParent(childEvent *types.HeaderedEvent) (parent *types.HeaderedEvent) {
parentID, _, _ := parentChildEventIDs(childEvent) parentID, _, _ := parentChildEventIDs(childEvent)
if parentID == "" { if parentID == "" {
return nil return nil
@ -395,7 +394,7 @@ func (rc *reqCtx) includeParent(childEvent *gomatrixserverlib.HeaderedEvent) (pa
// If include_children: true, lookup all events which have event_id as an m.relationship // If include_children: true, lookup all events which have event_id as an m.relationship
// Apply history visibility checks to all these events and add the ones which pass into the response array, // Apply history visibility checks to all these events and add the ones which pass into the response array,
// honouring the recent_first flag and the limit. // honouring the recent_first flag and the limit.
func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recentFirst bool) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) { func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recentFirst bool) ([]*types.HeaderedEvent, *util.JSONResponse) {
if rc.hasUnexploredChildren(parentID) { if rc.hasUnexploredChildren(parentID) {
// we need to do a remote request to pull in the children as we are missing them locally. // we need to do a remote request to pull in the children as we are missing them locally.
serversToQuery := rc.getServersForEventID(parentID) serversToQuery := rc.getServersForEventID(parentID)
@ -432,7 +431,7 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen
resErr := jsonerror.InternalServerError() resErr := jsonerror.InternalServerError()
return nil, &resErr return nil, &resErr
} }
var childEvents []*gomatrixserverlib.HeaderedEvent var childEvents []*types.HeaderedEvent
for _, child := range children { for _, child := range children {
childEvent := rc.lookForEvent(child.EventID) childEvent := rc.lookForEvent(child.EventID)
if childEvent != nil { if childEvent != nil {
@ -449,8 +448,8 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen
// honouring the limit, max_depth and max_breadth values according to the following rules // honouring the limit, max_depth and max_breadth values according to the following rules
func walkThread( func walkThread(
ctx context.Context, db Database, rc *reqCtx, included map[string]bool, limit int, ctx context.Context, db Database, rc *reqCtx, included map[string]bool, limit int,
) ([]*gomatrixserverlib.HeaderedEvent, bool) { ) ([]*types.HeaderedEvent, bool) {
var result []*gomatrixserverlib.HeaderedEvent var result []*types.HeaderedEvent
eventWalker := walker{ eventWalker := walker{
ctx: ctx, ctx: ctx,
req: rc.req, req: rc.req,
@ -512,7 +511,7 @@ func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv spec.ServerName,
// authorisedToSeeEvent checks that the user or server is allowed to see this event. Returns true if allowed to // authorisedToSeeEvent checks that the user or server is allowed to see this event. Returns true if allowed to
// see this request. This only needs to be done once per room at present as we just check for joined status. // see this request. This only needs to be done once per room at present as we just check for joined status.
func (rc *reqCtx) authorisedToSeeEvent(event *gomatrixserverlib.HeaderedEvent) bool { func (rc *reqCtx) authorisedToSeeEvent(event *types.HeaderedEvent) bool {
if rc.isFederatedRequest { if rc.isFederatedRequest {
// make sure the server is in this room // make sure the server is in this room
var res fs.QueryJoinedHostServerNamesInRoomResponse var res fs.QueryJoinedHostServerNamesInRoomResponse
@ -595,7 +594,7 @@ func (rc *reqCtx) remoteEventRelationships(eventID string) *MSC2836EventRelation
// lookForEvent returns the event for the event ID given, by trying to query remote servers // lookForEvent returns the event for the event ID given, by trying to query remote servers
// if the event ID is unknown via /event_relationships. // if the event ID is unknown via /event_relationships.
func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent { func (rc *reqCtx) lookForEvent(eventID string) *types.HeaderedEvent {
event := rc.getLocalEvent(rc.req.RoomID, eventID) event := rc.getLocalEvent(rc.req.RoomID, eventID)
if event == nil { if event == nil {
queryRes := rc.remoteEventRelationships(eventID) queryRes := rc.remoteEventRelationships(eventID)
@ -604,7 +603,7 @@ func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent
rc.injectResponseToRoomserver(queryRes) rc.injectResponseToRoomserver(queryRes)
for _, ev := range queryRes.ParsedEvents { for _, ev := range queryRes.ParsedEvents {
if ev.EventID() == eventID && rc.req.RoomID == ev.RoomID() { if ev.EventID() == eventID && rc.req.RoomID == ev.RoomID() {
return ev.Headered(ev.Version()) return &types.HeaderedEvent{Event: ev}
} }
} }
} }
@ -626,7 +625,7 @@ func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent
return nil return nil
} }
func (rc *reqCtx) getLocalEvent(roomID, eventID string) *gomatrixserverlib.HeaderedEvent { func (rc *reqCtx) getLocalEvent(roomID, eventID string) *types.HeaderedEvent {
var queryEventsRes roomserver.QueryEventsByIDResponse var queryEventsRes roomserver.QueryEventsByIDResponse
err := rc.rsAPI.QueryEventsByID(rc.ctx, &roomserver.QueryEventsByIDRequest{ err := rc.rsAPI.QueryEventsByID(rc.ctx, &roomserver.QueryEventsByIDRequest{
RoomID: roomID, RoomID: roomID,
@ -647,7 +646,7 @@ func (rc *reqCtx) getLocalEvent(roomID, eventID string) *gomatrixserverlib.Heade
// into the roomserver as KindOutlier, with auth chains. // into the roomserver as KindOutlier, with auth chains.
func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsResponse) { func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsResponse) {
var stateEvents gomatrixserverlib.EventJSONs var stateEvents gomatrixserverlib.EventJSONs
var messageEvents []*gomatrixserverlib.Event var messageEvents []gomatrixserverlib.PDU
for _, ev := range res.ParsedEvents { for _, ev := range res.ParsedEvents {
if ev.StateKey() != nil { if ev.StateKey() != nil {
stateEvents = append(stateEvents, ev.JSON()) stateEvents = append(stateEvents, ev.JSON())
@ -666,7 +665,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo
for _, outlier := range append(eventsInOrder, messageEvents...) { for _, outlier := range append(eventsInOrder, messageEvents...) {
ires = append(ires, roomserver.InputRoomEvent{ ires = append(ires, roomserver.InputRoomEvent{
Kind: roomserver.KindOutlier, Kind: roomserver.KindOutlier,
Event: outlier.Headered(outlier.Version()), Event: &types.HeaderedEvent{Event: outlier.(*gomatrixserverlib.Event)},
}) })
} }
// we've got the data by this point so use a background context // we've got the data by this point so use a background context
@ -685,7 +684,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo
} }
} }
func (rc *reqCtx) addChildMetadata(ev *gomatrixserverlib.HeaderedEvent) { func (rc *reqCtx) addChildMetadata(ev *types.HeaderedEvent) {
count, hash := rc.getChildMetadata(ev.EventID()) count, hash := rc.getChildMetadata(ev.EventID())
if count == 0 { if count == 0 {
return return

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
roomserver "github.com/matrix-org/dendrite/roomserver/api" roomserver "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/mscs/msc2836" "github.com/matrix-org/dendrite/setup/mscs/msc2836"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
@ -171,7 +172,7 @@ func TestMSC2836(t *testing.T) {
bob: {roomID}, bob: {roomID},
charlie: {roomID}, charlie: {roomID},
}, },
events: map[string]*gomatrixserverlib.HeaderedEvent{ events: map[string]*types.HeaderedEvent{
eventA.EventID(): eventA, eventA.EventID(): eventA,
eventB.EventID(): eventB, eventB.EventID(): eventB,
eventC.EventID(): eventC, eventC.EventID(): eventC,
@ -182,7 +183,7 @@ func TestMSC2836(t *testing.T) {
eventH.EventID(): eventH, eventH.EventID(): eventH,
}, },
} }
router := injectEvents(t, nopUserAPI, nopRsAPI, []*gomatrixserverlib.HeaderedEvent{ router := injectEvents(t, nopUserAPI, nopRsAPI, []*types.HeaderedEvent{
eventA, eventB, eventC, eventD, eventE, eventF, eventG, eventH, eventA, eventB, eventC, eventD, eventE, eventF, eventG, eventH,
}) })
cancel := runServer(t, router) cancel := runServer(t, router)
@ -521,7 +522,7 @@ type testRoomserverAPI struct {
// We'll override the functions we care about. // We'll override the functions we care about.
roomserver.RoomserverInternalAPI roomserver.RoomserverInternalAPI
userToJoinedRooms map[string][]string userToJoinedRooms map[string][]string
events map[string]*gomatrixserverlib.HeaderedEvent events map[string]*types.HeaderedEvent
} }
func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error { func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error {
@ -547,7 +548,7 @@ func (r *testRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *roo
return nil return nil
} }
func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserver.RoomserverInternalAPI, events []*gomatrixserverlib.HeaderedEvent) *mux.Router { func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserver.RoomserverInternalAPI, events []*types.HeaderedEvent) *mux.Router {
t.Helper() t.Helper()
cfg := &config.Dendrite{} cfg := &config.Dendrite{}
cfg.Defaults(config.DefaultOpts{ cfg.Defaults(config.DefaultOpts{
@ -579,7 +580,7 @@ type fledglingEvent struct {
RoomID string RoomID string
} }
func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib.HeaderedEvent) { func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *types.HeaderedEvent) {
t.Helper() t.Helper()
roomVer := gomatrixserverlib.RoomVersionV6 roomVer := gomatrixserverlib.RoomVersionV6
seed := make([]byte, ed25519.SeedSize) // zero seed seed := make([]byte, ed25519.SeedSize) // zero seed
@ -601,6 +602,6 @@ func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib
if err != nil { if err != nil {
t.Fatalf("mustCreateEvent: failed to sign event: %s", err) t.Fatalf("mustCreateEvent: failed to sign event: %s", err)
} }
h := signedEvent.Headered(roomVer) h := &types.HeaderedEvent{Event: signedEvent}
return h return h
} }

View file

@ -8,8 +8,8 @@ import (
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -23,7 +23,7 @@ type eventInfo struct {
type Database interface { type Database interface {
// StoreRelation stores the parent->child and child->parent relationship for later querying. // StoreRelation stores the parent->child and child->parent relationship for later querying.
// Also stores the event metadata e.g timestamp // Also stores the event metadata e.g timestamp
StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error StoreRelation(ctx context.Context, ev *types.HeaderedEvent) error
// ChildrenForParent returns the events who have the given `eventID` as an m.relationship with the // ChildrenForParent returns the events who have the given `eventID` as an m.relationship with the
// provided `relType`. The returned slice is sorted by origin_server_ts according to whether // provided `relType`. The returned slice is sorted by origin_server_ts according to whether
// `recentFirst` is true or false. // `recentFirst` is true or false.
@ -35,7 +35,7 @@ type Database interface {
// UpdateChildMetadata persists the children_count and children_hash from this event if and only if // UpdateChildMetadata persists the children_count and children_hash from this event if and only if
// the count is greater than what was previously there. If the count is updated, the event will be // the count is greater than what was previously there. If the count is updated, the event will be
// updated to be unexplored. // updated to be unexplored.
UpdateChildMetadata(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error UpdateChildMetadata(ctx context.Context, ev *types.HeaderedEvent) error
// ChildMetadata returns the children_count and children_hash for the event ID in question. // ChildMetadata returns the children_count and children_hash for the event ID in question.
// Also returns the `explored` flag, which is set to true when MarkChildrenExplored is called and is set // Also returns the `explored` flag, which is set to true when MarkChildrenExplored is called and is set
// back to `false` when a larger count is inserted via UpdateChildMetadata. // back to `false` when a larger count is inserted via UpdateChildMetadata.
@ -222,7 +222,7 @@ func newSQLiteDatabase(conMan sqlutil.Connections, dbOpts *config.DatabaseOption
return &d, nil return &d, nil
} }
func (p *DB) StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error { func (p *DB) StoreRelation(ctx context.Context, ev *types.HeaderedEvent) error {
parent, child, relType := parentChildEventIDs(ev) parent, child, relType := parentChildEventIDs(ev)
if parent == "" || child == "" { if parent == "" || child == "" {
return nil return nil
@ -244,7 +244,7 @@ func (p *DB) StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEv
}) })
} }
func (p *DB) UpdateChildMetadata(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error { func (p *DB) UpdateChildMetadata(ctx context.Context, ev *types.HeaderedEvent) error {
eventCount, eventHash := extractChildMetadata(ev) eventCount, eventHash := extractChildMetadata(ev)
if eventCount == 0 { if eventCount == 0 {
return nil // nothing to update with return nil // nothing to update with
@ -315,7 +315,7 @@ func (p *DB) ParentForChild(ctx context.Context, eventID, relType string) (*even
return &ei, nil return &ei, nil
} }
func parentChildEventIDs(ev *gomatrixserverlib.HeaderedEvent) (parent, child, relType string) { func parentChildEventIDs(ev *types.HeaderedEvent) (parent, child, relType string) {
if ev == nil { if ev == nil {
return return
} }
@ -334,7 +334,7 @@ func parentChildEventIDs(ev *gomatrixserverlib.HeaderedEvent) (parent, child, re
return body.Relationship.EventID, ev.EventID(), body.Relationship.RelType return body.Relationship.EventID, ev.EventID(), body.Relationship.RelType
} }
func roomIDAndServers(ev *gomatrixserverlib.HeaderedEvent) (roomID string, servers []string) { func roomIDAndServers(ev *types.HeaderedEvent) (roomID string, servers []string) {
servers = []string{} servers = []string{}
if ev == nil { if ev == nil {
return return
@ -349,7 +349,7 @@ func roomIDAndServers(ev *gomatrixserverlib.HeaderedEvent) (roomID string, serve
return body.RoomID, body.Servers return body.RoomID, body.Servers
} }
func extractChildMetadata(ev *gomatrixserverlib.HeaderedEvent) (count int, hash []byte) { func extractChildMetadata(ev *types.HeaderedEvent) (count int, hash []byte) {
unsigned := struct { unsigned := struct {
Counts map[string]int `json:"children"` Counts map[string]int `json:"children"`
Hash spec.Base64Bytes `json:"children_hash"` Hash spec.Base64Bytes `json:"children_hash"`

View file

@ -33,6 +33,7 @@ import (
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
roomserver "github.com/matrix-org/dendrite/roomserver/api" roomserver "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -388,7 +389,7 @@ func (w *walker) walk() util.JSONResponse {
} }
} }
func (w *walker) stateEvent(roomID, evType, stateKey string) *gomatrixserverlib.HeaderedEvent { func (w *walker) stateEvent(roomID, evType, stateKey string) *types.HeaderedEvent {
var queryRes roomserver.QueryCurrentStateResponse var queryRes roomserver.QueryCurrentStateResponse
tuple := gomatrixserverlib.StateKeyTuple{ tuple := gomatrixserverlib.StateKeyTuple{
EventType: evType, EventType: evType,
@ -636,7 +637,7 @@ func (w *walker) authorisedUser(roomID, parentRoomID string) (authed bool, isJoi
return false, false return false, false
} }
func (w *walker) restrictedJoinRuleAllowedRooms(joinRuleEv *gomatrixserverlib.HeaderedEvent, allowType string) (allows []string) { func (w *walker) restrictedJoinRuleAllowedRooms(joinRuleEv *types.HeaderedEvent, allowType string) (allows []string) {
rule, _ := joinRuleEv.JoinRule() rule, _ := joinRuleEv.JoinRule()
if rule != spec.Restricted { if rule != spec.Restricted {
return nil return nil

View file

@ -21,7 +21,6 @@ import (
"fmt" "fmt"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -31,6 +30,7 @@ import (
"github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/fulltext"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
rstypes "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
@ -318,7 +318,7 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent(
pduPos, err := s.db.WriteEvent( pduPos, err := s.db.WriteEvent(
ctx, ctx,
ev, ev,
[]*gomatrixserverlib.HeaderedEvent{}, []*rstypes.HeaderedEvent{},
[]string{}, // adds no state []string{}, // adds no state
[]string{}, // removes no state []string{}, // removes no state
nil, // no transaction nil, // no transaction
@ -362,7 +362,7 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent(
return nil return nil
} }
func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, sp types.StreamPosition) (types.StreamPosition, error) { func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *rstypes.HeaderedEvent, sp types.StreamPosition) (types.StreamPosition, error) {
if ev.Type() != spec.MRoomMember { if ev.Type() != spec.MRoomMember {
return sp, nil return sp, nil
} }
@ -496,7 +496,7 @@ func (s *OutputRoomEventConsumer) onPurgeRoom(
} }
} }
func (s *OutputRoomEventConsumer) updateStateEvent(event *gomatrixserverlib.HeaderedEvent) (*gomatrixserverlib.HeaderedEvent, error) { func (s *OutputRoomEventConsumer) updateStateEvent(event *rstypes.HeaderedEvent) (*rstypes.HeaderedEvent, error) {
if event.StateKey() == nil { if event.StateKey() == nil {
return event, nil return event, nil
} }
@ -531,7 +531,7 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *gomatrixserverlib.Head
return event, err return event, err
} }
func (s *OutputRoomEventConsumer) writeFTS(ev *gomatrixserverlib.HeaderedEvent, pduPosition types.StreamPosition) error { func (s *OutputRoomEventConsumer) writeFTS(ev *rstypes.HeaderedEvent, pduPosition types.StreamPosition) error {
if !s.cfg.Fulltext.Enabled { if !s.cfg.Fulltext.Enabled {
return nil return nil
} }

Some files were not shown because too many files have changed in this diff Show more