diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index 9aae34e38..92c283b52 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -36,6 +36,7 @@ import ( "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/mscs" "github.com/matrix-org/dendrite/signingkeyserver" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" @@ -130,6 +131,8 @@ func main() { cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName)) cfg.Global.Kafka.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-naffka.db", *instanceName)) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-e2ekey.db", *instanceName)) + cfg.MSCs.MSCs = []string{"msc2836"} + cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", *instanceName)) if err = cfg.Derive(); err != nil { panic(err) } @@ -190,6 +193,9 @@ func main() { base.Base.PublicKeyAPIMux, base.Base.PublicMediaAPIMux, ) + if err := mscs.Enable(&base.Base, &monolith); err != nil { + logrus.WithError(err).Fatalf("Failed to enable MSCs") + } httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.Base.InternalAPIMux) diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 39643cc2f..16f92cfa1 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -39,6 +39,7 @@ import ( "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/mscs" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" @@ -83,6 +84,8 @@ func main() { cfg.FederationSender.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", *instanceName)) cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName)) cfg.Global.Kafka.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-naffka.db", *instanceName)) + cfg.MSCs.MSCs = []string{"msc2836"} + cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", *instanceName)) if err = cfg.Derive(); err != nil { panic(err) } @@ -151,6 +154,9 @@ func main() { base.PublicKeyAPIMux, base.PublicMediaAPIMux, ) + if err := mscs.Enable(base, &monolith); err != nil { + logrus.WithError(err).Fatalf("Failed to enable MSCs") + } httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index 3233cf809..c4cb4abbb 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -63,6 +63,8 @@ func main() { if *defaultsForCI { cfg.ClientAPI.RateLimiting.Enabled = false cfg.FederationSender.DisableTLSValidation = true + cfg.MSCs.MSCs = []string{"msc2836"} + cfg.Logging[0].Level = "trace" } j, err := yaml.Marshal(cfg) diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index a9ee78830..8bdf54c4a 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -9,7 +9,6 @@ import ( "time" eduAPI "github.com/matrix-org/dendrite/eduserver/api" - fsAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" @@ -85,6 +84,7 @@ func (o *testEDUProducer) InputReceiptEvent( } type testRoomserverAPI struct { + api.RoomserverInternalAPITrace inputRoomEvents []api.InputRoomEvent queryMissingAuthPrevEvents func(*api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse @@ -92,12 +92,6 @@ type testRoomserverAPI struct { queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse } -func (t *testRoomserverAPI) PerformForget(ctx context.Context, req *api.PerformForgetRequest, resp *api.PerformForgetResponse) error { - return nil -} - -func (t *testRoomserverAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) {} - func (t *testRoomserverAPI) InputRoomEvents( ctx context.Context, request *api.InputRoomEventsRequest, @@ -109,50 +103,6 @@ func (t *testRoomserverAPI) InputRoomEvents( } } -func (t *testRoomserverAPI) PerformInvite( - ctx context.Context, - req *api.PerformInviteRequest, - res *api.PerformInviteResponse, -) error { - return nil -} - -func (t *testRoomserverAPI) PerformJoin( - ctx context.Context, - req *api.PerformJoinRequest, - res *api.PerformJoinResponse, -) { -} - -func (t *testRoomserverAPI) PerformPeek( - ctx context.Context, - req *api.PerformPeekRequest, - res *api.PerformPeekResponse, -) { -} - -func (t *testRoomserverAPI) PerformUnpeek( - ctx context.Context, - req *api.PerformUnpeekRequest, - res *api.PerformUnpeekResponse, -) { -} - -func (t *testRoomserverAPI) PerformPublish( - ctx context.Context, - req *api.PerformPublishRequest, - res *api.PerformPublishResponse, -) { -} - -func (t *testRoomserverAPI) PerformLeave( - ctx context.Context, - req *api.PerformLeaveRequest, - res *api.PerformLeaveResponse, -) error { - return nil -} - // Query the latest events and state for a room from the room server. func (t *testRoomserverAPI) QueryLatestEventsAndState( ctx context.Context, diff --git a/federationsender/api/api.go b/federationsender/api/api.go index a4d15f1f5..e4d176b16 100644 --- a/federationsender/api/api.go +++ b/federationsender/api/api.go @@ -21,6 +21,7 @@ type FederationClient interface { QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) GetServerKeys(ctx context.Context, matrixServer gomatrixserverlib.ServerName) (gomatrixserverlib.ServerKeys, error) + MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) } diff --git a/federationsender/internal/api.go b/federationsender/internal/api.go index a14cf3caa..407e7ffec 100644 --- a/federationsender/internal/api.go +++ b/federationsender/internal/api.go @@ -229,3 +229,18 @@ func (a *FederationSenderInternalAPI) LookupServerKeys( } return ires.([]gomatrixserverlib.ServerKeys), nil } + +func (a *FederationSenderInternalAPI) MSC2836EventRelationships( + ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, + roomVersion gomatrixserverlib.RoomVersion, +) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + ires, err := a.doRequest(s, func() (interface{}, error) { + return a.federation.MSC2836EventRelationships(ctx, s, r, roomVersion) + }) + if err != nil { + return res, err + } + return ires.(gomatrixserverlib.MSC2836EventRelationshipsResponse), nil +} diff --git a/federationsender/inthttp/client.go b/federationsender/inthttp/client.go index e0783ee1b..fe98ff33d 100644 --- a/federationsender/inthttp/client.go +++ b/federationsender/inthttp/client.go @@ -23,15 +23,16 @@ const ( FederationSenderPerformServersAlivePath = "/federationsender/performServersAlive" FederationSenderPerformBroadcastEDUPath = "/federationsender/performBroadcastEDU" - FederationSenderGetUserDevicesPath = "/federationsender/client/getUserDevices" - FederationSenderClaimKeysPath = "/federationsender/client/claimKeys" - FederationSenderQueryKeysPath = "/federationsender/client/queryKeys" - FederationSenderBackfillPath = "/federationsender/client/backfill" - FederationSenderLookupStatePath = "/federationsender/client/lookupState" - FederationSenderLookupStateIDsPath = "/federationsender/client/lookupStateIDs" - FederationSenderGetEventPath = "/federationsender/client/getEvent" - FederationSenderGetServerKeysPath = "/federationsender/client/getServerKeys" - FederationSenderLookupServerKeysPath = "/federationsender/client/lookupServerKeys" + FederationSenderGetUserDevicesPath = "/federationsender/client/getUserDevices" + FederationSenderClaimKeysPath = "/federationsender/client/claimKeys" + FederationSenderQueryKeysPath = "/federationsender/client/queryKeys" + FederationSenderBackfillPath = "/federationsender/client/backfill" + FederationSenderLookupStatePath = "/federationsender/client/lookupState" + FederationSenderLookupStateIDsPath = "/federationsender/client/lookupStateIDs" + FederationSenderGetEventPath = "/federationsender/client/getEvent" + FederationSenderGetServerKeysPath = "/federationsender/client/getServerKeys" + FederationSenderLookupServerKeysPath = "/federationsender/client/lookupServerKeys" + FederationSenderEventRelationshipsPath = "/federationsender/client/msc2836eventRelationships" ) // NewFederationSenderClient creates a FederationSenderInternalAPI implemented by talking to a HTTP POST API. @@ -416,3 +417,35 @@ func (h *httpFederationSenderInternalAPI) LookupServerKeys( } return response.ServerKeys, nil } + +type eventRelationships struct { + S gomatrixserverlib.ServerName + Req gomatrixserverlib.MSC2836EventRelationshipsRequest + RoomVer gomatrixserverlib.RoomVersion + Res gomatrixserverlib.MSC2836EventRelationshipsResponse + Err *api.FederationClientError +} + +func (h *httpFederationSenderInternalAPI) MSC2836EventRelationships( + ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, + roomVersion gomatrixserverlib.RoomVersion, +) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2836EventRelationships") + defer span.Finish() + + request := eventRelationships{ + S: s, + Req: r, + RoomVer: roomVersion, + } + var response eventRelationships + apiURL := h.federationSenderURL + FederationSenderEventRelationshipsPath + err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) + if err != nil { + return res, err + } + if response.Err != nil { + return res, response.Err + } + return response.Res, nil +} diff --git a/federationsender/inthttp/server.go b/federationsender/inthttp/server.go index 53e1183e4..293fb4209 100644 --- a/federationsender/inthttp/server.go +++ b/federationsender/inthttp/server.go @@ -307,4 +307,26 @@ func AddRoutes(intAPI api.FederationSenderInternalAPI, internalAPIMux *mux.Route return util.JSONResponse{Code: http.StatusOK, JSON: request} }), ) + internalAPIMux.Handle( + FederationSenderEventRelationshipsPath, + httputil.MakeInternalAPI("MSC2836EventRelationships", func(req *http.Request) util.JSONResponse { + var request eventRelationships + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + res, err := intAPI.MSC2836EventRelationships(req.Context(), request.S, request.Req, request.RoomVer) + if err != nil { + ferr, ok := err.(*api.FederationClientError) + if ok { + request.Err = ferr + } else { + request.Err = &api.FederationClientError{ + Err: err.Error(), + } + } + } + request.Res = res + return util.JSONResponse{Code: http.StatusOK, JSON: request} + }), + ) } diff --git a/go.mod b/go.mod index 54a139321..0d39cf03b 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/gologme/log v1.2.0 github.com/gorilla/mux v1.8.0 github.com/hashicorp/golang-lru v0.5.4 + github.com/konsorten/go-windows-terminal-sequences v1.0.3 // indirect github.com/lib/pq v1.8.0 github.com/libp2p/go-libp2p v0.11.0 github.com/libp2p/go-libp2p-circuit v0.3.1 @@ -32,15 +33,17 @@ require ( github.com/pkg/errors v0.9.1 github.com/pressly/goose v2.7.0-rc5+incompatible github.com/prometheus/client_golang v1.7.1 - github.com/sirupsen/logrus v1.6.0 + github.com/sirupsen/logrus v1.7.0 github.com/tidwall/gjson v1.6.3 - github.com/tidwall/sjson v1.1.1 + github.com/tidwall/match v1.0.2 // indirect + github.com/tidwall/sjson v1.1.2 github.com/uber/jaeger-client-go v2.25.0+incompatible github.com/uber/jaeger-lib v2.2.0+incompatible github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20201006093556-760d9a7fd5ee go.uber.org/atomic v1.6.0 - golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a + golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9 golang.org/x/net v0.0.0-20200528225125-3c3fba18258b + golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 // indirect gopkg.in/h2non/bimg.v1 v1.1.4 gopkg.in/yaml.v2 v2.3.0 ) diff --git a/go.sum b/go.sum index 9780a688d..3bd186ce9 100644 --- a/go.sum +++ b/go.sum @@ -781,6 +781,8 @@ github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4 github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= +github.com/sirupsen/logrus v1.7.0 h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM= +github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/smola/gocompat v0.2.0/go.mod h1:1B0MlxbmoZNo3h8guHp8HztB3BSYR5itql9qtVc0ypY= github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= @@ -812,10 +814,13 @@ github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpP github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/tidwall/gjson v1.6.0 h1:9VEQWz6LLMUsUl6PueE49ir4Ka6CzLymOAZDxpFsTDc= github.com/tidwall/gjson v1.6.0/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls= +github.com/tidwall/gjson v1.6.1/go.mod h1:BaHyNc5bjzYkPqgLq7mdVzeiRtULKULXLgZFKsxEHI0= github.com/tidwall/gjson v1.6.3 h1:aHoiiem0dr7GHkW001T1SMTJ7X5PvyekH5WX0whWGnI= github.com/tidwall/gjson v1.6.3/go.mod h1:BaHyNc5bjzYkPqgLq7mdVzeiRtULKULXLgZFKsxEHI0= github.com/tidwall/match v1.0.1 h1:PnKP62LPNxHKTwvHHZZzdOAOCtsJTjo6dZLCwpKm5xc= github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E= +github.com/tidwall/match v1.0.2 h1:uuqvHuBGSedK7awZ2YoAtpnimfwBGFjHuWLuLqQj+bU= +github.com/tidwall/match v1.0.2/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/tidwall/pretty v1.0.1 h1:WE4RBSZ1x6McVVC8S/Md+Qse8YUv6HRObAx6ke00NY8= github.com/tidwall/pretty v1.0.1/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= @@ -825,6 +830,8 @@ github.com/tidwall/sjson v1.0.3 h1:DeF+0LZqvIt4fKYw41aPB29ZGlvwVkHKktoXJ1YW9Y8= github.com/tidwall/sjson v1.0.3/go.mod h1:bURseu1nuBkFpIES5cz6zBtjmYeOQmEESshn7VpF15Y= github.com/tidwall/sjson v1.1.1 h1:7h1vk049Jnd5EH9NyzNiEuwYW4b5qgreBbqRC19AS3U= github.com/tidwall/sjson v1.1.1/go.mod h1:yvVuSnpEQv5cYIrO+AT6kw4QVfd5SDZoGIS7/5+fZFs= +github.com/tidwall/sjson v1.1.2 h1:NC5okI+tQ8OG/oyzchvwXXxRxCV/FVdhODbPKkQ25jQ= +github.com/tidwall/sjson v1.1.2/go.mod h1:SEzaDwxiPzKzNfUEO4HbYF/m4UCSJDsGgNqsS1LvdoY= github.com/uber/jaeger-client-go v2.25.0+incompatible h1:IxcNZ7WRY1Y3G4poYlx24szfsn/3LvK9QHCq9oQw8+U= github.com/uber/jaeger-client-go v2.25.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk= github.com/uber/jaeger-lib v2.2.0+incompatible h1:MxZXOiR2JuoANZ3J6DE/U0kSFv/eJ/GfSYVCjK7dyaw= @@ -907,6 +914,8 @@ golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37 h1:cg5LA/zNPRzIXIWSCxQW10 golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a h1:vclmkQCjlDX5OydZ9wv8rBCcS0QyQY66Mpf/7BZbInM= golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9 h1:phUcVbl53swtrUN8kQEXFhUxPlIlWyBfKmidCu7P95o= +golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -982,6 +991,7 @@ golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191003212358-c178f38b412c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -994,6 +1004,9 @@ golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1 h1:ogLJMz+qpzav7lGMh10LMvAkM/fAoGlaiiHYiFYdm80= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/roomserver/api/api.go b/roomserver/api/api.go index bef2bb3fa..ebc068ac8 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -132,6 +132,15 @@ type RoomserverInternalAPI interface { response *QueryStateAndAuthChainResponse, ) error + // QueryAuthChain returns the entire auth chain for the event IDs given. + // The response includes the events in the request. + // Omits without error for any missing auth events. There will be no duplicates. + QueryAuthChain( + ctx context.Context, + request *QueryAuthChainRequest, + response *QueryAuthChainResponse, + ) error + // QueryCurrentState retrieves the requested state events. If state events are not found, they will be missing from // the response. QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index eb2b2e1d4..c279807e5 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -324,6 +324,16 @@ func (t *RoomserverInternalAPITrace) QueryServerBannedFromRoom(ctx context.Conte return err } +func (t *RoomserverInternalAPITrace) QueryAuthChain( + ctx context.Context, + request *QueryAuthChainRequest, + response *QueryAuthChainResponse, +) error { + err := t.Impl.QueryAuthChain(ctx, request, response) + util.GetLogger(ctx).WithError(err).Infof("QueryAuthChain req=%+v res=%+v", js(request), js(response)) + return err +} + func js(thing interface{}) string { b, err := json.Marshal(thing) if err != nil { diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 55922b7fe..43e562a98 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -275,6 +275,14 @@ type QueryPublishedRoomsResponse struct { RoomIDs []string } +type QueryAuthChainRequest struct { + EventIDs []string +} + +type QueryAuthChainResponse struct { + AuthChain []*gomatrixserverlib.HeaderedEvent +} + type QuerySharedUsersRequest struct { UserID string ExcludeRoomIDs []string diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 79dc2fe14..2bed0c7f4 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -62,10 +62,10 @@ func (w *inputWorker) start() { for { select { case task := <-w.input: - hooks.Run(hooks.KindNewEventReceived, &task.event.Event) + hooks.Run(hooks.KindNewEventReceived, task.event.Event) _, task.err = w.r.processRoomEvent(task.ctx, task.event) if task.err == nil { - hooks.Run(hooks.KindNewEventPersisted, &task.event.Event) + hooks.Run(hooks.KindNewEventPersisted, task.event.Event) } task.wg.Done() case <-time.After(time.Second * 5): diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index c9940c3c2..7346c7a77 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -716,3 +716,16 @@ func (r *Queryer) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryS res.Banned = r.ServerACLs.IsServerBannedFromRoom(req.ServerName, req.RoomID) return nil } + +func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse) error { + chain, err := getAuthChain(ctx, r.DB.EventsFromIDs, req.EventIDs) + if err != nil { + return err + } + hchain := make([]*gomatrixserverlib.HeaderedEvent, len(chain)) + for i := range chain { + hchain[i] = chain[i].Headered(chain[i].Version()) + } + res.AuthChain = hchain + return nil +} diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index e496b81e0..8a1c91d28 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -54,6 +54,7 @@ const ( RoomserverQuerySharedUsersPath = "/roomserver/querySharedUsers" RoomserverQueryKnownUsersPath = "/roomserver/queryKnownUsers" RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom" + RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain" ) type httpRoomserverInternalAPI struct { @@ -502,6 +503,16 @@ func (h *httpRoomserverInternalAPI) QueryKnownUsers( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } +func (h *httpRoomserverInternalAPI) QueryAuthChain( + ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAuthChain") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryAuthChainPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom( ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse, ) error { diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index ac1fc25b6..f9c8ef9fd 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -452,4 +452,17 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(RoomserverQueryAuthChainPath, + httputil.MakeInternalAPI("queryAuthChain", func(req *http.Request) util.JSONResponse { + request := api.QueryAuthChainRequest{} + response := api.QueryAuthChainResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.QueryAuthChain(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 33a65c8f4..95473f97c 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -18,10 +18,13 @@ package msc2836 import ( "bytes" "context" + "crypto/sha256" "encoding/json" "fmt" "io" "net/http" + "sort" + "strings" "time" "github.com/matrix-org/dendrite/clientapi/jsonerror" @@ -36,13 +39,12 @@ import ( ) const ( - constRelType = "m.reference" - constRoomIDKey = "relationship_room_id" - constRoomServers = "relationship_servers" + constRelType = "m.reference" ) type EventRelationshipRequest struct { EventID string `json:"event_id"` + RoomID string `json:"room_id"` MaxDepth int `json:"max_depth"` MaxBreadth int `json:"max_breadth"` Limit int `json:"limit"` @@ -52,7 +54,6 @@ type EventRelationshipRequest struct { IncludeChildren bool `json:"include_children"` Direction string `json:"direction"` Batch string `json:"batch"` - AutoJoin bool `json:"auto_join"` } func NewEventRelationshipRequest(body io.Reader) (*EventRelationshipRequest, error) { @@ -81,8 +82,16 @@ type EventRelationshipResponse struct { Limited bool `json:"limited"` } +func toClientResponse(res *gomatrixserverlib.MSC2836EventRelationshipsResponse) *EventRelationshipResponse { + out := &EventRelationshipResponse{ + Events: gomatrixserverlib.ToClientEvents(res.Events, gomatrixserverlib.FormatAll), + Limited: res.Limited, + NextBatch: res.NextBatch, + } + return out +} + // Enable this MSC -// nolint:gocyclo func Enable( base *setup.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationSenderInternalAPI, userAPI userapi.UserInternalAPI, keyRing gomatrixserverlib.JSONVerifier, @@ -96,63 +105,22 @@ func Enable( he := headeredEvent.(*gomatrixserverlib.HeaderedEvent) hookErr := db.StoreRelation(context.Background(), he) if hookErr != nil { - util.GetLogger(context.Background()).WithError(hookErr).Error( + util.GetLogger(context.Background()).WithError(hookErr).WithField("event_id", he.EventID()).Error( "failed to StoreRelation", ) } - }) - hooks.Attach(hooks.KindNewEventReceived, func(headeredEvent interface{}) { - he := headeredEvent.(*gomatrixserverlib.HeaderedEvent) - ctx := context.Background() - // we only inject metadata for events our server sends - userID := he.Sender() - _, domain, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - return - } - if domain != base.Cfg.Global.ServerName { - return - } - // if this event has an m.relationship, add on the room_id and servers to unsigned - parent, child, relType := parentChildEventIDs(he) - if parent == "" || child == "" || relType == "" { - return - } - event, joinedToRoom := getEventIfVisible(ctx, rsAPI, parent, userID) - if !joinedToRoom { - return - } - err = he.SetUnsignedField(constRoomIDKey, event.RoomID()) - if err != nil { - util.GetLogger(context.Background()).WithError(err).Warn("Failed to SetUnsignedField") - return - } - - var servers []gomatrixserverlib.ServerName - if fsAPI != nil { - var res fs.QueryJoinedHostServerNamesInRoomResponse - err = fsAPI.QueryJoinedHostServerNamesInRoom(ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{ - RoomID: event.RoomID(), - }, &res) - if err != nil { - util.GetLogger(context.Background()).WithError(err).Warn("Failed to QueryJoinedHostServerNamesInRoom") - return - } - servers = res.ServerNames - } else { - servers = []gomatrixserverlib.ServerName{ - base.Cfg.Global.ServerName, - } - } - err = he.SetUnsignedField(constRoomServers, servers) - if err != nil { - util.GetLogger(context.Background()).WithError(err).Warn("Failed to SetUnsignedField") - return + // we need to update child metadata here as well as after doing remote /event_relationships requests + // so we catch child metadata originating from /send transactions + hookErr = db.UpdateChildMetadata(context.Background(), he) + if hookErr != nil { + util.GetLogger(context.Background()).WithError(err).WithField("event_id", he.EventID()).Warn( + "failed to update child metadata for event", + ) } }) base.PublicClientAPIMux.Handle("/unstable/event_relationships", - httputil.MakeAuthAPI("eventRelationships", userAPI, eventRelationshipHandler(db, rsAPI)), + httputil.MakeAuthAPI("eventRelationships", userAPI, eventRelationshipHandler(db, rsAPI, fsAPI)), ).Methods(http.MethodPost, http.MethodOptions) base.PublicFederationAPIMux.Handle("/unstable/event_relationships", httputil.MakeExternalAPI( @@ -163,22 +131,27 @@ func Enable( if fedReq == nil { return errResp } - return federatedEventRelationship(req.Context(), fedReq, db, rsAPI) + return federatedEventRelationship(req.Context(), fedReq, db, rsAPI, fsAPI) }, )).Methods(http.MethodPost, http.MethodOptions) return nil } type reqCtx struct { - ctx context.Context - rsAPI roomserver.RoomserverInternalAPI - db Database - req *EventRelationshipRequest - userID string + ctx context.Context + rsAPI roomserver.RoomserverInternalAPI + db Database + req *EventRelationshipRequest + userID string + roomVersion gomatrixserverlib.RoomVersion + + // federated request args isFederatedRequest bool + serverName gomatrixserverlib.ServerName + fsAPI fs.FederationSenderInternalAPI } -func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAPI) func(*http.Request, *userapi.Device) util.JSONResponse { +func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationSenderInternalAPI) func(*http.Request, *userapi.Device) util.JSONResponse { return func(req *http.Request, device *userapi.Device) util.JSONResponse { relation, err := NewEventRelationshipRequest(req.Body) if err != nil { @@ -193,6 +166,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP req: relation, userID: device.UserID, rsAPI: rsAPI, + fsAPI: fsAPI, isFederatedRequest: false, db: db, } @@ -203,12 +177,14 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP return util.JSONResponse{ Code: 200, - JSON: res, + JSON: toClientResponse(res), } } } -func federatedEventRelationship(ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, db Database, rsAPI roomserver.RoomserverInternalAPI) util.JSONResponse { +func federatedEventRelationship( + ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, db Database, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationSenderInternalAPI, +) util.JSONResponse { relation, err := NewEventRelationshipRequest(bytes.NewBuffer(fedReq.Content())) if err != nil { util.GetLogger(ctx).WithError(err).Error("failed to decode HTTP request as JSON") @@ -218,17 +194,43 @@ func federatedEventRelationship(ctx context.Context, fedReq *gomatrixserverlib.F } } rc := reqCtx{ - ctx: ctx, - req: relation, - userID: "", - rsAPI: rsAPI, + ctx: ctx, + req: relation, + rsAPI: rsAPI, + db: db, + // federation args isFederatedRequest: true, - db: db, + fsAPI: fsAPI, + serverName: fedReq.Origin(), } res, resErr := rc.process() if resErr != nil { return *resErr } + // add auth chain information + requiredAuthEventsSet := make(map[string]bool) + var requiredAuthEvents []string + for _, ev := range res.Events { + for _, a := range ev.AuthEventIDs() { + if requiredAuthEventsSet[a] { + continue + } + requiredAuthEvents = append(requiredAuthEvents, a) + requiredAuthEventsSet[a] = true + } + } + var queryRes roomserver.QueryAuthChainResponse + err = rsAPI.QueryAuthChain(ctx, &roomserver.QueryAuthChainRequest{ + EventIDs: requiredAuthEvents, + }, &queryRes) + if err != nil { + // they may already have the auth events so don't fail this request + util.GetLogger(ctx).WithError(err).Error("Failed to QueryAuthChain") + } + res.AuthChain = make([]*gomatrixserverlib.Event, len(queryRes.AuthChain)) + for i := range queryRes.AuthChain { + res.AuthChain[i] = queryRes.AuthChain[i].Unwrap() + } return util.JSONResponse{ Code: 200, @@ -236,18 +238,25 @@ func federatedEventRelationship(ctx context.Context, fedReq *gomatrixserverlib.F } } -func (rc *reqCtx) process() (*EventRelationshipResponse, *util.JSONResponse) { - var res EventRelationshipResponse +// nolint:gocyclo +func (rc *reqCtx) process() (*gomatrixserverlib.MSC2836EventRelationshipsResponse, *util.JSONResponse) { + var res gomatrixserverlib.MSC2836EventRelationshipsResponse var returnEvents []*gomatrixserverlib.HeaderedEvent // Can the user see (according to history visibility) event_id? If no, reject the request, else continue. - // We should have the event being referenced so don't give any claimed room ID / servers - event := rc.getEventIfVisible(rc.req.EventID, "", nil) + event := rc.getLocalEvent(rc.req.EventID) if event == nil { + event = rc.fetchUnknownEvent(rc.req.EventID, rc.req.RoomID) + } + if rc.req.RoomID == "" && event != nil { + rc.req.RoomID = event.RoomID() + } + if event == nil || !rc.authorisedToSeeEvent(event) { return nil, &util.JSONResponse{ Code: 403, JSON: jsonerror.Forbidden("Event does not exist or you are not authorised to see it"), } } + rc.roomVersion = event.Version() // Retrieve the event. Add it to response array. returnEvents = append(returnEvents, event) @@ -282,29 +291,122 @@ func (rc *reqCtx) process() (*EventRelationshipResponse, *util.JSONResponse) { ) returnEvents = append(returnEvents, events...) } - res.Events = make([]gomatrixserverlib.ClientEvent, len(returnEvents)) + res.Events = make([]*gomatrixserverlib.Event, len(returnEvents)) for i, ev := range returnEvents { - res.Events[i] = gomatrixserverlib.HeaderedToClientEvent(ev, gomatrixserverlib.FormatAll) + // for each event, extract the children_count | hash and add it as unsigned data. + rc.addChildMetadata(ev) + res.Events[i] = ev.Unwrap() } res.Limited = remaining == 0 || walkLimited return &res, nil } +// fetchUnknownEvent retrieves an unknown event from the room specified. This server must +// be joined to the room in question. This has the side effect of injecting surround threaded +// events into the roomserver. +func (rc *reqCtx) fetchUnknownEvent(eventID, roomID string) *gomatrixserverlib.HeaderedEvent { + if rc.isFederatedRequest || roomID == "" { + // we don't do fed hits for fed requests, and we can't ask servers without a room ID! + return nil + } + logger := util.GetLogger(rc.ctx).WithField("room_id", roomID) + // if they supplied a room_id, check the room exists. + var queryVerRes roomserver.QueryRoomVersionForRoomResponse + err := rc.rsAPI.QueryRoomVersionForRoom(rc.ctx, &roomserver.QueryRoomVersionForRoomRequest{ + RoomID: roomID, + }, &queryVerRes) + if err != nil { + logger.WithError(err).Warn("failed to query room version for room, does this room exist?") + return nil + } + + // check the user is joined to that room + var queryMemRes roomserver.QueryMembershipForUserResponse + err = rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{ + RoomID: roomID, + UserID: rc.userID, + }, &queryMemRes) + if err != nil { + logger.WithError(err).Warn("failed to query membership for user in room") + return nil + } + if !queryMemRes.IsInRoom { + return nil + } + + // ask one of the servers in the room for the event + var queryRes fs.QueryJoinedHostServerNamesInRoomResponse + err = rc.fsAPI.QueryJoinedHostServerNamesInRoom(rc.ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{ + RoomID: roomID, + }, &queryRes) + if err != nil { + logger.WithError(err).Error("failed to QueryJoinedHostServerNamesInRoom") + return nil + } + // query up to 5 servers + serversToQuery := queryRes.ServerNames + if len(serversToQuery) > 5 { + serversToQuery = serversToQuery[:5] + } + + // fetch the event, along with some of the surrounding thread (if it's threaded) and the auth chain. + // Inject the response into the roomserver to remember the event across multiple calls and to set + // unexplored flags correctly. + for _, srv := range serversToQuery { + res, err := rc.MSC2836EventRelationships(eventID, srv, queryVerRes.RoomVersion) + if err != nil { + continue + } + rc.injectResponseToRoomserver(res) + for _, ev := range res.Events { + if ev.EventID() == eventID { + return ev.Headered(ev.Version()) + } + } + } + logger.WithField("servers", serversToQuery).Warn("failed to query event relationships") + return nil +} + // If include_parent: true and there is a valid m.relationship field in the event, // retrieve the referenced event. Apply history visibility check to that event and if it passes, add it to the response array. -func (rc *reqCtx) includeParent(event *gomatrixserverlib.HeaderedEvent) (parent *gomatrixserverlib.HeaderedEvent) { - parentID, _, _ := parentChildEventIDs(event) +func (rc *reqCtx) includeParent(childEvent *gomatrixserverlib.HeaderedEvent) (parent *gomatrixserverlib.HeaderedEvent) { + parentID, _, _ := parentChildEventIDs(childEvent) if parentID == "" { return nil } - claimedRoomID, claimedServers := roomIDAndServers(event) - return rc.getEventIfVisible(parentID, claimedRoomID, claimedServers) + return rc.lookForEvent(parentID) } // If include_children: true, lookup all events which have event_id as an m.relationship // Apply history visibility checks to all these events and add the ones which pass into the response array, // honouring the recent_first flag and the limit. func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recentFirst bool) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) { + if rc.hasUnexploredChildren(parentID) { + // we need to do a remote request to pull in the children as we are missing them locally. + serversToQuery := rc.getServersForEventID(parentID) + var result *gomatrixserverlib.MSC2836EventRelationshipsResponse + for _, srv := range serversToQuery { + res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ + EventID: parentID, + Direction: "down", + Limit: 100, + MaxBreadth: -1, + MaxDepth: 1, // we just want the children from this parent + RecentFirst: true, + }, rc.roomVersion) + if err != nil { + util.GetLogger(rc.ctx).WithError(err).WithField("server", srv).Error("includeChildren: failed to call MSC2836EventRelationships") + } else { + result = &res + break + } + } + if result != nil { + rc.injectResponseToRoomserver(result) + } + // fallthrough to pull these new events from the DB + } children, err := db.ChildrenForParent(rc.ctx, parentID, constRelType, recentFirst) if err != nil { util.GetLogger(rc.ctx).WithError(err).Error("failed to get ChildrenForParent") @@ -313,8 +415,7 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen } var childEvents []*gomatrixserverlib.HeaderedEvent for _, child := range children { - // in order for us to even know about the children the server must be joined to those rooms, hence pass no claimed room ID or servers. - childEvent := rc.getEventIfVisible(child.EventID, "", nil) + childEvent := rc.lookForEvent(child.EventID) if childEvent != nil { childEvents = append(childEvents, childEvent) } @@ -327,14 +428,9 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen // Begin to walk the thread DAG in the direction specified, either depth or breadth first according to the depth_first flag, // honouring the limit, max_depth and max_breadth values according to the following rules -// nolint: unparam func walkThread( ctx context.Context, db Database, rc *reqCtx, included map[string]bool, limit int, ) ([]*gomatrixserverlib.HeaderedEvent, bool) { - if rc.req.Direction != "down" { - util.GetLogger(ctx).Error("not implemented: direction=up") - return nil, false - } var result []*gomatrixserverlib.HeaderedEvent eventWalker := walker{ ctx: ctx, @@ -352,8 +448,11 @@ func walkThread( } // Process the event. - // TODO: Include edge information: room ID and servers - event := rc.getEventIfVisible(wi.EventID, "", nil) + // if event is not found, use remoteEventRelationships to explore that part of the thread remotely. + // This will probably be easiest if the event relationships response is directly pumped into the database + // so the next walk will do the right thing. This requires those events to be authed and likely injected as + // outliers into the roomserver DB, which will de-dupe appropriately. + event := rc.lookForEvent(wi.EventID) if event != nil { result = append(result, event) } @@ -368,74 +467,280 @@ func walkThread( return result, limited } -func (rc *reqCtx) getEventIfVisible(eventID string, claimedRoomID string, claimedServers []string) *gomatrixserverlib.HeaderedEvent { - event, joinedToRoom := getEventIfVisible(rc.ctx, rc.rsAPI, eventID, rc.userID) - if event != nil && joinedToRoom { - return event +// MSC2836EventRelationships performs an /event_relationships request to a remote server +func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverlib.ServerName, ver gomatrixserverlib.RoomVersion) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) { + res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ + EventID: eventID, + DepthFirst: rc.req.DepthFirst, + Direction: rc.req.Direction, + Limit: rc.req.Limit, + MaxBreadth: rc.req.MaxBreadth, + MaxDepth: rc.req.MaxDepth, + RecentFirst: rc.req.RecentFirst, + }, ver) + if err != nil { + util.GetLogger(rc.ctx).WithError(err).Error("Failed to call MSC2836EventRelationships") + return nil, err } - // either we don't have the event or we aren't joined to the room, regardless we should try joining if auto join is enabled - if !rc.req.AutoJoin { - return nil - } - // if we're doing this on behalf of a random server don't auto-join rooms regardless of what the request says - if rc.isFederatedRequest { - return nil - } - roomID := claimedRoomID - var servers []gomatrixserverlib.ServerName - if event != nil { - roomID = event.RoomID() - } - for _, s := range claimedServers { - servers = append(servers, gomatrixserverlib.ServerName(s)) - } - var joinRes roomserver.PerformJoinResponse - rc.rsAPI.PerformJoin(rc.ctx, &roomserver.PerformJoinRequest{ - UserID: rc.userID, - Content: map[string]interface{}{}, - RoomIDOrAlias: roomID, - ServerNames: servers, - }, &joinRes) - if joinRes.Error != nil { - util.GetLogger(rc.ctx).WithError(joinRes.Error).WithField("room_id", roomID).Error("Failed to auto-join room") - return nil - } - if event != nil { - return event - } - // TODO: hit /event_relationships on the server we joined via - util.GetLogger(rc.ctx).Infof("joined room but need to fetch event TODO") - return nil + return &res, nil + } -func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, eventID, userID string) (*gomatrixserverlib.HeaderedEvent, bool) { - var queryEventsRes roomserver.QueryEventsByIDResponse - err := rsAPI.QueryEventsByID(ctx, &roomserver.QueryEventsByIDRequest{ - EventIDs: []string{eventID}, - }, &queryEventsRes) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryEventsByID") - return nil, false +// authorisedToSeeEvent checks that the user or server is allowed to see this event. Returns true if allowed to +// see this request. This only needs to be done once per room at present as we just check for joined status. +func (rc *reqCtx) authorisedToSeeEvent(event *gomatrixserverlib.HeaderedEvent) bool { + if rc.isFederatedRequest { + // make sure the server is in this room + var res fs.QueryJoinedHostServerNamesInRoomResponse + err := rc.fsAPI.QueryJoinedHostServerNamesInRoom(rc.ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{ + RoomID: event.RoomID(), + }, &res) + if err != nil { + util.GetLogger(rc.ctx).WithError(err).Error("authorisedToSeeEvent: failed to QueryJoinedHostServerNamesInRoom") + return false + } + for _, srv := range res.ServerNames { + if srv == rc.serverName { + return true + } + } + return false } - if len(queryEventsRes.Events) == 0 { - util.GetLogger(ctx).Infof("event does not exist") - return nil, false // event does not exist - } - event := queryEventsRes.Events[0] - + // make sure the user is in this room // Allow events if the member is in the room // TODO: This does not honour history_visibility // TODO: This does not honour m.room.create content var queryMembershipRes roomserver.QueryMembershipForUserResponse - err = rsAPI.QueryMembershipForUser(ctx, &roomserver.QueryMembershipForUserRequest{ + err := rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{ RoomID: event.RoomID(), - UserID: userID, + UserID: rc.userID, }, &queryMembershipRes) if err != nil { - util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryMembershipForUser") - return nil, false + util.GetLogger(rc.ctx).WithError(err).Error("authorisedToSeeEvent: failed to QueryMembershipForUser") + return false } - return event, queryMembershipRes.IsInRoom + return queryMembershipRes.IsInRoom +} + +func (rc *reqCtx) getServersForEventID(eventID string) []gomatrixserverlib.ServerName { + if rc.req.RoomID == "" { + util.GetLogger(rc.ctx).WithField("event_id", eventID).Error( + "getServersForEventID: event exists in unknown room", + ) + return nil + } + if rc.roomVersion == "" { + util.GetLogger(rc.ctx).WithField("event_id", eventID).Errorf( + "getServersForEventID: event exists in %s with unknown room version", rc.req.RoomID, + ) + return nil + } + var queryRes fs.QueryJoinedHostServerNamesInRoomResponse + err := rc.fsAPI.QueryJoinedHostServerNamesInRoom(rc.ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{ + RoomID: rc.req.RoomID, + }, &queryRes) + if err != nil { + util.GetLogger(rc.ctx).WithError(err).Error("getServersForEventID: failed to QueryJoinedHostServerNamesInRoom") + return nil + } + // query up to 5 servers + serversToQuery := queryRes.ServerNames + if len(serversToQuery) > 5 { + serversToQuery = serversToQuery[:5] + } + return serversToQuery +} + +func (rc *reqCtx) remoteEventRelationships(eventID string) *gomatrixserverlib.MSC2836EventRelationshipsResponse { + if rc.isFederatedRequest { + return nil // we don't query remote servers for remote requests + } + serversToQuery := rc.getServersForEventID(eventID) + var res *gomatrixserverlib.MSC2836EventRelationshipsResponse + var err error + for _, srv := range serversToQuery { + res, err = rc.MSC2836EventRelationships(eventID, srv, rc.roomVersion) + if err != nil { + util.GetLogger(rc.ctx).WithError(err).WithField("server", srv).Error("remoteEventRelationships: failed to call MSC2836EventRelationships") + } else { + break + } + } + return res +} + +// lookForEvent returns the event for the event ID given, by trying to query remote servers +// if the event ID is unknown via /event_relationships. +func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent { + event := rc.getLocalEvent(eventID) + if event == nil { + queryRes := rc.remoteEventRelationships(eventID) + if queryRes != nil { + // inject all the events into the roomserver then return the event in question + rc.injectResponseToRoomserver(queryRes) + for _, ev := range queryRes.Events { + if ev.EventID() == eventID && rc.req.RoomID == ev.RoomID() { + return ev.Headered(ev.Version()) + } + } + } + } else if rc.hasUnexploredChildren(eventID) { + // we have the local event but we may need to do a remote hit anyway if we are exploring the thread and have unknown children. + // If we don't do this then we risk never fetching the children. + queryRes := rc.remoteEventRelationships(eventID) + if queryRes != nil { + rc.injectResponseToRoomserver(queryRes) + err := rc.db.MarkChildrenExplored(context.Background(), eventID) + if err != nil { + util.GetLogger(rc.ctx).WithError(err).Warnf("failed to mark children of %s as explored", eventID) + } + } + } + if rc.req.RoomID == event.RoomID() { + return event + } + return nil +} + +func (rc *reqCtx) getLocalEvent(eventID string) *gomatrixserverlib.HeaderedEvent { + var queryEventsRes roomserver.QueryEventsByIDResponse + err := rc.rsAPI.QueryEventsByID(rc.ctx, &roomserver.QueryEventsByIDRequest{ + EventIDs: []string{eventID}, + }, &queryEventsRes) + if err != nil { + util.GetLogger(rc.ctx).WithError(err).Error("getLocalEvent: failed to QueryEventsByID") + return nil + } + if len(queryEventsRes.Events) == 0 { + util.GetLogger(rc.ctx).WithField("event_id", eventID).Infof("getLocalEvent: event does not exist") + return nil // event does not exist + } + return queryEventsRes.Events[0] +} + +// injectResponseToRoomserver injects the events +// into the roomserver as KindOutlier, with auth chains. +func (rc *reqCtx) injectResponseToRoomserver(res *gomatrixserverlib.MSC2836EventRelationshipsResponse) { + var stateEvents []*gomatrixserverlib.Event + var messageEvents []*gomatrixserverlib.Event + for _, ev := range res.Events { + if ev.StateKey() != nil { + stateEvents = append(stateEvents, ev) + } else { + messageEvents = append(messageEvents, ev) + } + } + respState := gomatrixserverlib.RespState{ + AuthEvents: res.AuthChain, + StateEvents: stateEvents, + } + eventsInOrder, err := respState.Events() + if err != nil { + util.GetLogger(rc.ctx).WithError(err).Error("failed to calculate order to send events in MSC2836EventRelationshipsResponse") + return + } + // everything gets sent as an outlier because auth chain events may be disjoint from the DAG + // as may the threaded events. + var ires []roomserver.InputRoomEvent + for _, outlier := range append(eventsInOrder, messageEvents...) { + ires = append(ires, roomserver.InputRoomEvent{ + Kind: roomserver.KindOutlier, + Event: outlier.Headered(outlier.Version()), + AuthEventIDs: outlier.AuthEventIDs(), + }) + } + // we've got the data by this point so use a background context + err = roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires) + if err != nil { + util.GetLogger(rc.ctx).WithError(err).Error("failed to inject MSC2836EventRelationshipsResponse into the roomserver") + } + // update the child count / hash columns for these nodes. We need to do this here because not all events will make it + // through to the KindNewEventPersisted hook because the roomserver will ignore duplicates. Duplicates have meaning though + // as the `unsigned` field may differ (if the number of children changes). + for _, ev := range ires { + err = rc.db.UpdateChildMetadata(context.Background(), ev.Event) + if err != nil { + util.GetLogger(rc.ctx).WithError(err).WithField("event_id", ev.Event.EventID()).Warn("failed to update child metadata for event") + } + } +} + +func (rc *reqCtx) addChildMetadata(ev *gomatrixserverlib.HeaderedEvent) { + count, hash := rc.getChildMetadata(ev.EventID()) + if count == 0 { + return + } + err := ev.SetUnsignedField("children_hash", gomatrixserverlib.Base64Bytes(hash)) + if err != nil { + util.GetLogger(rc.ctx).WithError(err).Warn("Failed to set children_hash") + } + err = ev.SetUnsignedField("children", map[string]int{ + constRelType: count, + }) + if err != nil { + util.GetLogger(rc.ctx).WithError(err).Warn("Failed to set children count") + } +} + +func (rc *reqCtx) getChildMetadata(eventID string) (count int, hash []byte) { + children, err := rc.db.ChildrenForParent(rc.ctx, eventID, constRelType, false) + if err != nil { + util.GetLogger(rc.ctx).WithError(err).Warn("Failed to get ChildrenForParent for getting child metadata") + return + } + if len(children) == 0 { + return + } + // sort it lexiographically + sort.Slice(children, func(i, j int) bool { + return children[i].EventID < children[j].EventID + }) + // hash it + var eventIDs strings.Builder + for _, c := range children { + _, _ = eventIDs.WriteString(c.EventID) + } + hashValBytes := sha256.Sum256([]byte(eventIDs.String())) + + count = len(children) + hash = hashValBytes[:] + return +} + +// hasUnexploredChildren returns true if this event has unexplored children. +// "An event has unexplored children if the `unsigned` child count on the parent does not match +// how many children the server believes the parent to have. In addition, if the counts match but +// the hashes do not match, then the event is unexplored." +func (rc *reqCtx) hasUnexploredChildren(eventID string) bool { + if rc.isFederatedRequest { + return false // we only explore children for clients, not servers. + } + // extract largest child count from event + eventCount, eventHash, explored, err := rc.db.ChildMetadata(rc.ctx, eventID) + if err != nil { + util.GetLogger(rc.ctx).WithError(err).WithField("event_id", eventID).Warn( + "failed to get ChildMetadata from db", + ) + return false + } + // if there are no recorded children then we know we have >= children. + // if the event has already been explored (read: we hit /event_relationships successfully) + // then don't do it again. We'll only re-do this if we get an even bigger children count, + // see Database.UpdateChildMetadata + if eventCount == 0 || explored { + return false // short-circuit + } + + // calculate child count for event + calcCount, calcHash := rc.getChildMetadata(eventID) + + if eventCount < calcCount { + return false // we have more children + } else if eventCount > calcCount { + return true // the event has more children than we know about + } + // we have the same count, so a mismatched hash means some children are different + return !bytes.Equal(eventHash, calcHash) } type walkInfo struct { @@ -453,9 +758,9 @@ type walker struct { // WalkFrom the event ID given func (w *walker) WalkFrom(eventID string) (limited bool, err error) { - children, err := w.db.ChildrenForParent(w.ctx, eventID, constRelType, w.req.RecentFirst) + children, err := w.childrenForParent(eventID) if err != nil { - util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk") + util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() childrenForParent failed, cannot walk") return false, err } var next *walkInfo @@ -467,9 +772,9 @@ func (w *walker) WalkFrom(eventID string) (limited bool, err error) { return true, nil } // find the children's children - children, err = w.db.ChildrenForParent(w.ctx, next.EventID, constRelType, w.req.RecentFirst) + children, err = w.childrenForParent(next.EventID) if err != nil { - util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk") + util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() childrenForParent failed, cannot walk") return false, err } toWalk = w.addChildren(toWalk, children, next.Depth+1) @@ -528,3 +833,20 @@ func (w *walker) nextChild(toWalk []walkInfo) (*walkInfo, []walkInfo) { child, toWalk = toWalk[0], toWalk[1:] return &child, toWalk } + +// childrenForParent returns the children events for this event ID, honouring the direction: up|down flags +// meaning this can actually be returning the parent for the event instead of the children. +func (w *walker) childrenForParent(eventID string) ([]eventInfo, error) { + if w.req.Direction == "down" { + return w.db.ChildrenForParent(w.ctx, eventID, constRelType, w.req.RecentFirst) + } + // find the event to pull out the parent + ei, err := w.db.ParentForChild(w.ctx, eventID, constRelType) + if err != nil { + return nil, err + } + if ei != nil { + return []eventInfo{*ei}, nil + } + return nil, nil +} diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index 996cc79f0..4eb5708c1 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -4,10 +4,14 @@ import ( "bytes" "context" "crypto/ed25519" + "crypto/sha256" + "encoding/base64" "encoding/json" "fmt" "io/ioutil" "net/http" + "sort" + "strings" "testing" "time" @@ -43,9 +47,7 @@ func TestMSC2836(t *testing.T) { alice := "@alice:localhost" bob := "@bob:localhost" charlie := "@charlie:localhost" - roomIDA := "!alice:localhost" - roomIDB := "!bob:localhost" - roomIDC := "!charlie:localhost" + roomID := "!alice:localhost" // give access tokens to all three users nopUserAPI := &testUserAPI{ accessTokens: make(map[string]userapi.Device), @@ -66,7 +68,7 @@ func TestMSC2836(t *testing.T) { UserID: charlie, } eventA := mustCreateEvent(t, fledglingEvent{ - RoomID: roomIDA, + RoomID: roomID, Sender: alice, Type: "m.room.message", Content: map[string]interface{}{ @@ -74,7 +76,7 @@ func TestMSC2836(t *testing.T) { }, }) eventB := mustCreateEvent(t, fledglingEvent{ - RoomID: roomIDB, + RoomID: roomID, Sender: bob, Type: "m.room.message", Content: map[string]interface{}{ @@ -86,7 +88,7 @@ func TestMSC2836(t *testing.T) { }, }) eventC := mustCreateEvent(t, fledglingEvent{ - RoomID: roomIDB, + RoomID: roomID, Sender: bob, Type: "m.room.message", Content: map[string]interface{}{ @@ -98,7 +100,7 @@ func TestMSC2836(t *testing.T) { }, }) eventD := mustCreateEvent(t, fledglingEvent{ - RoomID: roomIDA, + RoomID: roomID, Sender: alice, Type: "m.room.message", Content: map[string]interface{}{ @@ -110,7 +112,7 @@ func TestMSC2836(t *testing.T) { }, }) eventE := mustCreateEvent(t, fledglingEvent{ - RoomID: roomIDB, + RoomID: roomID, Sender: bob, Type: "m.room.message", Content: map[string]interface{}{ @@ -122,7 +124,7 @@ func TestMSC2836(t *testing.T) { }, }) eventF := mustCreateEvent(t, fledglingEvent{ - RoomID: roomIDC, + RoomID: roomID, Sender: charlie, Type: "m.room.message", Content: map[string]interface{}{ @@ -134,7 +136,7 @@ func TestMSC2836(t *testing.T) { }, }) eventG := mustCreateEvent(t, fledglingEvent{ - RoomID: roomIDA, + RoomID: roomID, Sender: alice, Type: "m.room.message", Content: map[string]interface{}{ @@ -146,7 +148,7 @@ func TestMSC2836(t *testing.T) { }, }) eventH := mustCreateEvent(t, fledglingEvent{ - RoomID: roomIDB, + RoomID: roomID, Sender: bob, Type: "m.room.message", Content: map[string]interface{}{ @@ -160,9 +162,9 @@ func TestMSC2836(t *testing.T) { // make everyone joined to each other's rooms nopRsAPI := &testRoomserverAPI{ userToJoinedRooms: map[string][]string{ - alice: []string{roomIDA, roomIDB, roomIDC}, - bob: []string{roomIDA, roomIDB, roomIDC}, - charlie: []string{roomIDA, roomIDB, roomIDC}, + alice: []string{roomID}, + bob: []string{roomID}, + charlie: []string{roomID}, }, events: map[string]*gomatrixserverlib.HeaderedEvent{ eventA.EventID(): eventA, @@ -198,21 +200,6 @@ func TestMSC2836(t *testing.T) { "include_parent": true, })) }) - t.Run("omits parent if not joined to the room of parent of event", func(t *testing.T) { - nopUserAPI.accessTokens["frank2"] = userapi.Device{ - AccessToken: "frank2", - DisplayName: "Frank2 Not In Room", - UserID: "@frank2:localhost", - } - // Event B is in roomB, Event A is in roomA, so make frank2 joined to roomB - nopRsAPI.userToJoinedRooms["@frank2:localhost"] = []string{roomIDB} - body := postRelationships(t, 200, "frank2", newReq(t, map[string]interface{}{ - "event_id": eventB.EventID(), - "limit": 1, - "include_parent": true, - })) - assertContains(t, body, []string{eventB.EventID()}) - }) t.Run("returns the parent if include_parent is true", func(t *testing.T) { body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ "event_id": eventB.EventID(), @@ -349,6 +336,39 @@ func TestMSC2836(t *testing.T) { })) assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID(), eventH.EventID()}) }) + t.Run("can navigate up the graph with direction: up", func(t *testing.T) { + // A4 + // | + // B3 + // / \ + // C D2 + // /| \ + // E F1 G + // | + // H + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventF.EventID(), + "recent_first": false, + "depth_first": true, + "direction": "up", + })) + assertContains(t, body, []string{eventF.EventID(), eventD.EventID(), eventB.EventID(), eventA.EventID()}) + }) + t.Run("includes children and children_hash in unsigned", func(t *testing.T) { + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": false, + "depth_first": false, + "limit": 3, + })) + // event B has C,D as children + // event C has no children + // event D has 3 children (not included in response) + assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID()}) + assertUnsignedChildren(t, body.Events[0], "m.reference", 2, []string{eventC.EventID(), eventD.EventID()}) + assertUnsignedChildren(t, body.Events[1], "", 0, nil) + assertUnsignedChildren(t, body.Events[2], "m.reference", 3, []string{eventE.EventID(), eventF.EventID(), eventG.EventID()}) + }) } // TODO: TestMSC2836TerminatesLoops (short and long) @@ -411,8 +431,12 @@ func postRelationships(t *testing.T, expectCode int, accessToken string, req *ms } if res.StatusCode == 200 { var result msc2836.EventRelationshipResponse - if err := json.NewDecoder(res.Body).Decode(&result); err != nil { - t.Fatalf("response 200 OK but failed to deserialise JSON : %s", err) + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("response 200 OK but failed to read response body: %s", err) + } + if err := json.Unmarshal(body, &result); err != nil { + t.Fatalf("response 200 OK but failed to deserialise JSON : %s\nbody: %s", err, string(body)) } return &result } @@ -435,6 +459,43 @@ func assertContains(t *testing.T, result *msc2836.EventRelationshipResponse, wan } } +func assertUnsignedChildren(t *testing.T, ev gomatrixserverlib.ClientEvent, relType string, wantCount int, childrenEventIDs []string) { + t.Helper() + unsigned := struct { + Children map[string]int `json:"children"` + Hash string `json:"children_hash"` + }{} + if err := json.Unmarshal(ev.Unsigned, &unsigned); err != nil { + if wantCount == 0 { + return // no children so possible there is no unsigned field at all + } + t.Fatalf("Failed to unmarshal unsigned field: %s", err) + } + // zero checks + if wantCount == 0 { + if len(unsigned.Children) != 0 || unsigned.Hash != "" { + t.Fatalf("want 0 children but got unsigned fields %+v", unsigned) + } + return + } + gotCount := unsigned.Children[relType] + if gotCount != wantCount { + t.Errorf("Got %d count, want %d count for rel_type %s", gotCount, wantCount, relType) + } + // work out the hash + sort.Strings(childrenEventIDs) + var b strings.Builder + for _, s := range childrenEventIDs { + b.WriteString(s) + } + t.Logf("hashing %s", b.String()) + hashValBytes := sha256.Sum256([]byte(b.String())) + wantHash := base64.RawStdEncoding.EncodeToString(hashValBytes[:]) + if wantHash != unsigned.Hash { + t.Errorf("Got unsigned hash %s want hash %s", unsigned.Hash, wantHash) + } +} + type testUserAPI struct { accessTokens map[string]userapi.Device } diff --git a/setup/mscs/msc2836/storage.go b/setup/mscs/msc2836/storage.go index 72ea5195b..72523916b 100644 --- a/setup/mscs/msc2836/storage.go +++ b/setup/mscs/msc2836/storage.go @@ -1,20 +1,22 @@ package msc2836 import ( + "bytes" "context" "database/sql" + "encoding/base64" "encoding/json" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" ) type eventInfo struct { EventID string OriginServerTS gomatrixserverlib.Timestamp RoomID string - Servers []string } type Database interface { @@ -25,6 +27,21 @@ type Database interface { // provided `relType`. The returned slice is sorted by origin_server_ts according to whether // `recentFirst` is true or false. ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error) + // ParentForChild returns the parent event for the given child `eventID`. The eventInfo should be nil if + // there is no parent for this child event, with no error. The parent eventInfo can be missing the + // timestamp if the event is not known to the server. + ParentForChild(ctx context.Context, eventID, relType string) (*eventInfo, error) + // UpdateChildMetadata persists the children_count and children_hash from this event if and only if + // the count is greater than what was previously there. If the count is updated, the event will be + // updated to be unexplored. + UpdateChildMetadata(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error + // ChildMetadata returns the children_count and children_hash for the event ID in question. + // Also returns the `explored` flag, which is set to true when MarkChildrenExplored is called and is set + // back to `false` when a larger count is inserted via UpdateChildMetadata. + // Returns nil error if the event ID does not exist. + ChildMetadata(ctx context.Context, eventID string) (count int, hash []byte, explored bool, err error) + // MarkChildrenExplored sets the 'explored' flag on this event to `true`. + MarkChildrenExplored(ctx context.Context, eventID string) error } type DB struct { @@ -34,6 +51,10 @@ type DB struct { insertNodeStmt *sql.Stmt selectChildrenForParentOldestFirstStmt *sql.Stmt selectChildrenForParentRecentFirstStmt *sql.Stmt + selectParentForChildStmt *sql.Stmt + updateChildMetadataStmt *sql.Stmt + selectChildMetadataStmt *sql.Stmt + updateChildMetadataExploredStmt *sql.Stmt } // NewDatabase loads the database for msc2836 @@ -65,19 +86,26 @@ func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) { CREATE TABLE IF NOT EXISTS msc2836_nodes ( event_id TEXT PRIMARY KEY NOT NULL, origin_server_ts BIGINT NOT NULL, - room_id TEXT NOT NULL + room_id TEXT NOT NULL, + unsigned_children_count BIGINT NOT NULL, + unsigned_children_hash TEXT NOT NULL, + explored SMALLINT NOT NULL ); `) if err != nil { return nil, err } if d.insertEdgeStmt, err = d.db.Prepare(` - INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers) VALUES($1, $2, $3, $4, $5) ON CONFLICT DO NOTHING + INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers) + VALUES($1, $2, $3, $4, $5) + ON CONFLICT DO NOTHING `); err != nil { return nil, err } if d.insertNodeStmt, err = d.db.Prepare(` - INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id) VALUES($1, $2, $3) ON CONFLICT DO NOTHING + INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id, unsigned_children_count, unsigned_children_hash, explored) + VALUES($1, $2, $3, $4, $5, $6) + ON CONFLICT DO NOTHING `); err != nil { return nil, err } @@ -93,6 +121,27 @@ func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) { if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil { return nil, err } + if d.selectParentForChildStmt, err = d.db.Prepare(` + SELECT parent_event_id, parent_room_id FROM msc2836_edges + WHERE child_event_id = $1 AND rel_type = $2 + `); err != nil { + return nil, err + } + if d.updateChildMetadataStmt, err = d.db.Prepare(` + UPDATE msc2836_nodes SET unsigned_children_count=$1, unsigned_children_hash=$2, explored=$3 WHERE event_id=$4 + `); err != nil { + return nil, err + } + if d.selectChildMetadataStmt, err = d.db.Prepare(` + SELECT unsigned_children_count, unsigned_children_hash, explored FROM msc2836_nodes WHERE event_id=$1 + `); err != nil { + return nil, err + } + if d.updateChildMetadataExploredStmt, err = d.db.Prepare(` + UPDATE msc2836_nodes SET explored=$1 WHERE event_id=$2 + `); err != nil { + return nil, err + } return &d, err } @@ -117,19 +166,26 @@ func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) { CREATE TABLE IF NOT EXISTS msc2836_nodes ( event_id TEXT PRIMARY KEY NOT NULL, origin_server_ts BIGINT NOT NULL, - room_id TEXT NOT NULL + room_id TEXT NOT NULL, + unsigned_children_count BIGINT NOT NULL, + unsigned_children_hash TEXT NOT NULL, + explored SMALLINT NOT NULL ); `) if err != nil { return nil, err } if d.insertEdgeStmt, err = d.db.Prepare(` - INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers) VALUES($1, $2, $3, $4, $5) ON CONFLICT (parent_event_id, child_event_id, rel_type) DO NOTHING + INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers) + VALUES($1, $2, $3, $4, $5) + ON CONFLICT (parent_event_id, child_event_id, rel_type) DO NOTHING `); err != nil { return nil, err } if d.insertNodeStmt, err = d.db.Prepare(` - INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id) VALUES($1, $2, $3) ON CONFLICT DO NOTHING + INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id, unsigned_children_count, unsigned_children_hash, explored) + VALUES($1, $2, $3, $4, $5, $6) + ON CONFLICT DO NOTHING `); err != nil { return nil, err } @@ -145,6 +201,27 @@ func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) { if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil { return nil, err } + if d.selectParentForChildStmt, err = d.db.Prepare(` + SELECT parent_event_id, parent_room_id FROM msc2836_edges + WHERE child_event_id = $1 AND rel_type = $2 + `); err != nil { + return nil, err + } + if d.updateChildMetadataStmt, err = d.db.Prepare(` + UPDATE msc2836_nodes SET unsigned_children_count=$1, unsigned_children_hash=$2, explored=$3 WHERE event_id=$4 + `); err != nil { + return nil, err + } + if d.selectChildMetadataStmt, err = d.db.Prepare(` + SELECT unsigned_children_count, unsigned_children_hash, explored FROM msc2836_nodes WHERE event_id=$1 + `); err != nil { + return nil, err + } + if d.updateChildMetadataExploredStmt, err = d.db.Prepare(` + UPDATE msc2836_nodes SET explored=$1 WHERE event_id=$2 + `); err != nil { + return nil, err + } return &d, nil } @@ -158,16 +235,55 @@ func (p *DB) StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEv if err != nil { return err } + count, hash := extractChildMetadata(ev) return p.writer.Do(p.db, nil, func(txn *sql.Tx) error { _, err := txn.Stmt(p.insertEdgeStmt).ExecContext(ctx, parent, child, relType, relationRoomID, string(relationServersJSON)) if err != nil { return err } - _, err = txn.Stmt(p.insertNodeStmt).ExecContext(ctx, ev.EventID(), ev.OriginServerTS(), ev.RoomID()) + util.GetLogger(ctx).Infof("StoreRelation child=%s parent=%s rel_type=%s", child, parent, relType) + _, err = txn.Stmt(p.insertNodeStmt).ExecContext(ctx, ev.EventID(), ev.OriginServerTS(), ev.RoomID(), count, base64.RawStdEncoding.EncodeToString(hash), 0) return err }) } +func (p *DB) UpdateChildMetadata(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error { + eventCount, eventHash := extractChildMetadata(ev) + if eventCount == 0 { + return nil // nothing to update with + } + + // extract current children count/hash, if they are less than the current event then update the columns and set to unexplored + count, hash, _, err := p.ChildMetadata(ctx, ev.EventID()) + if err != nil { + return err + } + if eventCount > count || (eventCount == count && !bytes.Equal(hash, eventHash)) { + _, err = p.updateChildMetadataStmt.ExecContext(ctx, eventCount, base64.RawStdEncoding.EncodeToString(eventHash), 0, ev.EventID()) + return err + } + return nil +} + +func (p *DB) ChildMetadata(ctx context.Context, eventID string) (count int, hash []byte, explored bool, err error) { + var b64hash string + var exploredInt int + if err = p.selectChildMetadataStmt.QueryRowContext(ctx, eventID).Scan(&count, &b64hash, &exploredInt); err != nil { + if err == sql.ErrNoRows { + err = nil + } + return + } + hash, err = base64.RawStdEncoding.DecodeString(b64hash) + explored = exploredInt > 0 + return +} + +func (p *DB) MarkChildrenExplored(ctx context.Context, eventID string) error { + _, err := p.updateChildMetadataExploredStmt.ExecContext(ctx, 1, eventID) + return err +} + func (p *DB) ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error) { var rows *sql.Rows var err error @@ -191,6 +307,17 @@ func (p *DB) ChildrenForParent(ctx context.Context, eventID, relType string, rec return children, nil } +func (p *DB) ParentForChild(ctx context.Context, eventID, relType string) (*eventInfo, error) { + var ei eventInfo + err := p.selectParentForChildStmt.QueryRowContext(ctx, eventID, relType).Scan(&ei.EventID, &ei.RoomID) + if err == sql.ErrNoRows { + return nil, nil + } else if err != nil { + return nil, err + } + return &ei, nil +} + func parentChildEventIDs(ev *gomatrixserverlib.HeaderedEvent) (parent, child, relType string) { if ev == nil { return @@ -224,3 +351,19 @@ func roomIDAndServers(ev *gomatrixserverlib.HeaderedEvent) (roomID string, serve } return body.RoomID, body.Servers } + +func extractChildMetadata(ev *gomatrixserverlib.HeaderedEvent) (count int, hash []byte) { + unsigned := struct { + Counts map[string]int `json:"children"` + Hash gomatrixserverlib.Base64Bytes `json:"children_hash"` + }{} + if err := json.Unmarshal(ev.Unsigned(), &unsigned); err != nil { + // expected if there is no unsigned field at all + return + } + for _, c := range unsigned.Counts { + count += c + } + hash = unsigned.Hash + return +} diff --git a/setup/mscs/mscs.go b/setup/mscs/mscs.go index 8b0498ced..a8e5668ea 100644 --- a/setup/mscs/mscs.go +++ b/setup/mscs/mscs.go @@ -16,15 +16,18 @@ package mscs import ( + "context" "fmt" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/mscs/msc2836" + "github.com/matrix-org/util" ) // Enable MSCs - returns an error on unknown MSCs func Enable(base *setup.BaseDendrite, monolith *setup.Monolith) error { for _, msc := range base.Cfg.MSCs.MSCs { + util.GetLogger(context.Background()).WithField("msc", msc).Info("Enabling MSC") if err := EnableMSC(base, monolith, msc); err != nil { return err }