Merge branch 'master' into matthew/peeking-over-fed

This commit is contained in:
Neil Alexander 2020-12-10 10:34:14 +00:00 committed by GitHub
commit 45f0fdd7ef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
56 changed files with 1441 additions and 609 deletions

View file

@ -1,5 +1,20 @@
# Changelog # Changelog
## Dendrite 0.3.3 (2020-12-09)
### Features
* Federation sender should now use considerably less CPU cycles and RAM when sending events into large rooms
* The roomserver now uses considerably less CPU cycles by not calculating event IDs so often
* Experimental support for [MSC2836](https://github.com/matrix-org/matrix-doc/pull/2836) (threading) has been merged
* Dendrite will no longer hold federation HTTP connections open unnecessarily, which should help to reduce ambient CPU/RAM usage and hold fewer long-term file descriptors
### Fixes
* A bug in the latest event updater has been fixed, which should prevent the roomserver from losing forward extremities in some rare cases
* A panic has been fixed when federation is disabled (contributed by [kraem](https://github.com/kraem))
* The response format of the `/joined_members` endpoint has been fixed (contributed by [alexkursell](https://github.com/alexkursell))
## Dendrite 0.3.2 (2020-12-02) ## Dendrite 0.3.2 (2020-12-02)
### Features ### Features

View file

@ -8,7 +8,6 @@ It intends to provide an **efficient**, **reliable** and **scalable** alternativ
a [brand new Go test suite](https://github.com/matrix-org/complement). a [brand new Go test suite](https://github.com/matrix-org/complement).
- Scalable: can run on multiple machines and eventually scale to massive homeserver deployments. - Scalable: can run on multiple machines and eventually scale to massive homeserver deployments.
As of October 2020, Dendrite has now entered **beta** which means: As of October 2020, Dendrite has now entered **beta** which means:
- Dendrite is ready for early adopters. We recommend running in Monolith mode with a PostgreSQL database. - Dendrite is ready for early adopters. We recommend running in Monolith mode with a PostgreSQL database.
- Dendrite has periodic semver releases. We intend to release new versions as we land significant features. - Dendrite has periodic semver releases. We intend to release new versions as we land significant features.
@ -24,7 +23,7 @@ This does not mean:
Currently, we expect Dendrite to function well for small (10s/100s of users) homeserver deployments as well as P2P Matrix nodes in-browser or on mobile devices. Currently, we expect Dendrite to function well for small (10s/100s of users) homeserver deployments as well as P2P Matrix nodes in-browser or on mobile devices.
In the future, we will be able to scale up to gigantic servers (equivalent to matrix.org) via polylith mode. In the future, we will be able to scale up to gigantic servers (equivalent to matrix.org) via polylith mode.
Join us in: If you have further questions, please take a look at [our FAQ](docs/FAQ.md) or join us in:
- **[#dendrite:matrix.org](https://matrix.to/#/#dendrite:matrix.org)** - General chat about the Dendrite project, for users and server admins alike - **[#dendrite:matrix.org](https://matrix.to/#/#dendrite:matrix.org)** - General chat about the Dendrite project, for users and server admins alike
- **[#dendrite-dev:matrix.org](https://matrix.to/#/#dendrite-dev:matrix.org)** - The place for developers, where all Dendrite development discussion happens - **[#dendrite-dev:matrix.org](https://matrix.to/#/#dendrite-dev:matrix.org)** - The place for developers, where all Dendrite development discussion happens

View file

@ -8,7 +8,7 @@
# - `DENDRITE_LINT_CONCURRENCY` - number of concurrent linters to run, # - `DENDRITE_LINT_CONCURRENCY` - number of concurrent linters to run,
# golangci-lint defaults this to NumCPU # golangci-lint defaults this to NumCPU
# - `GOGC` - how often to perform garbage collection during golangci-lint runs. # - `GOGC` - how often to perform garbage collection during golangci-lint runs.
# Essentially a ratio of memory/speed. See https://github.com/golangci/golangci-lint#memory-usage-of-golangci-lint # Essentially a ratio of memory/speed. See https://golangci-lint.run/usage/performance/#memory-usage
# for more info. # for more info.

View file

@ -36,6 +36,7 @@ import (
"github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/mscs"
"github.com/matrix-org/dendrite/signingkeyserver" "github.com/matrix-org/dendrite/signingkeyserver"
"github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -130,6 +131,8 @@ func main() {
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName)) cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName))
cfg.Global.Kafka.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-naffka.db", *instanceName)) cfg.Global.Kafka.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-naffka.db", *instanceName))
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-e2ekey.db", *instanceName)) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-e2ekey.db", *instanceName))
cfg.MSCs.MSCs = []string{"msc2836"}
cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", *instanceName))
if err = cfg.Derive(); err != nil { if err = cfg.Derive(); err != nil {
panic(err) panic(err)
} }
@ -190,6 +193,9 @@ func main() {
base.Base.PublicKeyAPIMux, base.Base.PublicKeyAPIMux,
base.Base.PublicMediaAPIMux, base.Base.PublicMediaAPIMux,
) )
if err := mscs.Enable(&base.Base, &monolith); err != nil {
logrus.WithError(err).Fatalf("Failed to enable MSCs")
}
httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath()
httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.Base.InternalAPIMux) httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.Base.InternalAPIMux)

View file

@ -39,6 +39,7 @@ import (
"github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/mscs"
"github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -83,6 +84,8 @@ func main() {
cfg.FederationSender.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", *instanceName)) cfg.FederationSender.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", *instanceName))
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName)) cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName))
cfg.Global.Kafka.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-naffka.db", *instanceName)) cfg.Global.Kafka.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-naffka.db", *instanceName))
cfg.MSCs.MSCs = []string{"msc2836"}
cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", *instanceName))
if err = cfg.Derive(); err != nil { if err = cfg.Derive(); err != nil {
panic(err) panic(err)
} }
@ -151,6 +154,9 @@ func main() {
base.PublicKeyAPIMux, base.PublicKeyAPIMux,
base.PublicMediaAPIMux, base.PublicMediaAPIMux,
) )
if err := mscs.Enable(base, &monolith); err != nil {
logrus.WithError(err).Fatalf("Failed to enable MSCs")
}
httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath()
httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux)

View file

@ -63,6 +63,8 @@ func main() {
if *defaultsForCI { if *defaultsForCI {
cfg.ClientAPI.RateLimiting.Enabled = false cfg.ClientAPI.RateLimiting.Enabled = false
cfg.FederationSender.DisableTLSValidation = true cfg.FederationSender.DisableTLSValidation = true
cfg.MSCs.MSCs = []string{"msc2836"}
cfg.Logging[0].Level = "trace"
} }
j, err := yaml.Marshal(cfg) j, err := yaml.Marshal(cfg)

View file

@ -89,7 +89,7 @@ global:
# Naffka database options. Not required when using Kafka. # Naffka database options. Not required when using Kafka.
naffka_database: naffka_database:
connection_string: file:naffka.db connection_string: file:naffka.db
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -110,7 +110,7 @@ app_service_api:
connect: http://localhost:7777 connect: http://localhost:7777
database: database:
connection_string: file:appservice.db connection_string: file:appservice.db
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -185,7 +185,7 @@ federation_sender:
connect: http://localhost:7775 connect: http://localhost:7775
database: database:
connection_string: file:federationsender.db connection_string: file:federationsender.db
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -211,7 +211,7 @@ key_server:
connect: http://localhost:7779 connect: http://localhost:7779
database: database:
connection_string: file:keyserver.db connection_string: file:keyserver.db
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -224,7 +224,7 @@ media_api:
listen: http://[::]:8074 listen: http://[::]:8074
database: database:
connection_string: file:mediaapi.db connection_string: file:mediaapi.db
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -260,7 +260,7 @@ room_server:
connect: http://localhost:7770 connect: http://localhost:7770
database: database:
connection_string: file:roomserver.db connection_string: file:roomserver.db
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -271,7 +271,7 @@ signing_key_server:
connect: http://localhost:7780 connect: http://localhost:7780
database: database:
connection_string: file:signingkeyserver.db connection_string: file:signingkeyserver.db
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -300,7 +300,7 @@ sync_api:
listen: http://[::]:8073 listen: http://[::]:8073
database: database:
connection_string: file:syncapi.db connection_string: file:syncapi.db
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -316,12 +316,12 @@ user_api:
connect: http://localhost:7781 connect: http://localhost:7781
account_database: account_database:
connection_string: file:userapi_accounts.db connection_string: file:userapi_accounts.db
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
device_database: device_database:
connection_string: file:userapi_devices.db connection_string: file:userapi_devices.db
max_open_conns: 100 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1

View file

@ -2,13 +2,13 @@
In addition to standard Go code style (`gofmt`, `goimports`), we use `golangci-lint` In addition to standard Go code style (`gofmt`, `goimports`), we use `golangci-lint`
to run a number of linters, the exact list can be found under linters in [.golangci.yml](.golangci.yml). to run a number of linters, the exact list can be found under linters in [.golangci.yml](.golangci.yml).
[Installation](https://github.com/golangci/golangci-lint#install) and [Editor [Installation](https://github.com/golangci/golangci-lint#install-golangci-lint) and [Editor
Integration](https://github.com/golangci/golangci-lint#editor-integration) for Integration](https://golangci-lint.run/usage/integrations/#editor-integration) for
it can be found in the readme of golangci-lint. it can be found in the readme of golangci-lint.
For rare cases where a linter is giving a spurious warning, it can be disabled For rare cases where a linter is giving a spurious warning, it can be disabled
for that line or statement using a [comment for that line or statement using a [comment
directive](https://github.com/golangci/golangci-lint#nolint), e.g. `var directive](https://golangci-lint.run/usage/false-positives/#nolint), e.g. `var
bad_name int //nolint:golint,unused`. This should be used sparingly and only bad_name int //nolint:golint,unused`. This should be used sparingly and only
when its clear that the lint warning is spurious. when its clear that the lint warning is spurious.

View file

@ -37,7 +37,7 @@ If a job fails, click the "details" button and you should be taken to the job's
logs. logs.
![Click the details button on the failing build ![Click the details button on the failing build
step](docs/images/details-button-location.jpg) step](https://raw.githubusercontent.com/matrix-org/dendrite/master/docs/images/details-button-location.jpg)
Scroll down to the failing step and you should see some log output. Scan the Scroll down to the failing step and you should see some log output. Scan the
logs until you find what it's complaining about, fix it, submit a new commit, logs until you find what it's complaining about, fix it, submit a new commit,

View file

@ -12,6 +12,10 @@ No, although a good portion of the Matrix specification has been implemented. Mo
No, not at present. There will be in the future when Dendrite reaches version 1.0. No, not at present. There will be in the future when Dendrite reaches version 1.0.
### Should I run a monolith or a polylith deployment?
Monolith deployments are always preferred where possible, and at this time, are far better tested than polylith deployments are. The only reason to consider a polylith deployment is if you wish to run different Dendrite components on separate physical machines.
### I've installed Dendrite but federation isn't working ### I've installed Dendrite but federation isn't working
Check the [Federation Tester](https://federationtester.matrix.org). You need at least: Check the [Federation Tester](https://federationtester.matrix.org). You need at least:

View file

@ -9,7 +9,6 @@ import (
"time" "time"
eduAPI "github.com/matrix-org/dendrite/eduserver/api" eduAPI "github.com/matrix-org/dendrite/eduserver/api"
fsAPI "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/internal/test"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -85,6 +84,7 @@ func (o *testEDUProducer) InputReceiptEvent(
} }
type testRoomserverAPI struct { type testRoomserverAPI struct {
api.RoomserverInternalAPITrace
inputRoomEvents []api.InputRoomEvent inputRoomEvents []api.InputRoomEvent
queryMissingAuthPrevEvents func(*api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse queryMissingAuthPrevEvents func(*api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse
queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse
@ -92,12 +92,6 @@ type testRoomserverAPI struct {
queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse
} }
func (t *testRoomserverAPI) PerformForget(ctx context.Context, req *api.PerformForgetRequest, resp *api.PerformForgetResponse) error {
return nil
}
func (t *testRoomserverAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) {}
func (t *testRoomserverAPI) InputRoomEvents( func (t *testRoomserverAPI) InputRoomEvents(
ctx context.Context, ctx context.Context,
request *api.InputRoomEventsRequest, request *api.InputRoomEventsRequest,
@ -109,58 +103,6 @@ func (t *testRoomserverAPI) InputRoomEvents(
} }
} }
func (t *testRoomserverAPI) PerformInvite(
ctx context.Context,
req *api.PerformInviteRequest,
res *api.PerformInviteResponse,
) error {
return nil
}
func (t *testRoomserverAPI) PerformJoin(
ctx context.Context,
req *api.PerformJoinRequest,
res *api.PerformJoinResponse,
) {
}
func (t *testRoomserverAPI) PerformPeek(
ctx context.Context,
req *api.PerformPeekRequest,
res *api.PerformPeekResponse,
) {
}
func (t *testRoomserverAPI) PerformUnpeek(
ctx context.Context,
req *api.PerformUnpeekRequest,
res *api.PerformUnpeekResponse,
) {
}
func (t *testRoomserverAPI) PerformPublish(
ctx context.Context,
req *api.PerformPublishRequest,
res *api.PerformPublishResponse,
) {
}
func (t *testRoomserverAPI) PerformLeave(
ctx context.Context,
req *api.PerformLeaveRequest,
res *api.PerformLeaveResponse,
) error {
return nil
}
func (t *testRoomserverAPI) PerformInboundPeek(
ctx context.Context,
req *api.PerformInboundPeekRequest,
res *api.PerformInboundPeekResponse,
) error {
return nil
}
// Query the latest events and state for a room from the room server. // Query the latest events and state for a room from the room server.
func (t *testRoomserverAPI) QueryLatestEventsAndState( func (t *testRoomserverAPI) QueryLatestEventsAndState(
ctx context.Context, ctx context.Context,

View file

@ -21,6 +21,7 @@ type FederationClient interface {
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
GetServerKeys(ctx context.Context, matrixServer gomatrixserverlib.ServerName) (gomatrixserverlib.ServerKeys, error) GetServerKeys(ctx context.Context, matrixServer gomatrixserverlib.ServerName) (gomatrixserverlib.ServerKeys, error)
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
} }

View file

@ -94,6 +94,12 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
} }
if err := s.processMessage(*output.NewRoomEvent); err != nil { if err := s.processMessage(*output.NewRoomEvent); err != nil {
switch err.(type) {
case *queue.ErrorFederationDisabled:
log.WithField("error", output.Type).Info(
err.Error(),
)
default:
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event": string(ev.JSON()), "event": string(ev.JSON()),
@ -101,6 +107,7 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
"del": output.NewRoomEvent.RemovesStateEventIDs, "del": output.NewRoomEvent.RemovesStateEventIDs,
log.ErrorKey: err, log.ErrorKey: err,
}).Panicf("roomserver output log: write room event failure") }).Panicf("roomserver output log: write room event failure")
}
return nil return nil
} }
case api.OutputTypeNewInboundPeek: case api.OutputTypeNewInboundPeek:

View file

@ -46,7 +46,7 @@ func NewInternalAPI(
) api.FederationSenderInternalAPI { ) api.FederationSenderInternalAPI {
cfg := &base.Cfg.FederationSender cfg := &base.Cfg.FederationSender
federationSenderDB, err := storage.NewDatabase(&cfg.Database) federationSenderDB, err := storage.NewDatabase(&cfg.Database, base.Caches)
if err != nil { if err != nil {
logrus.WithError(err).Panic("failed to connect to federation sender db") logrus.WithError(err).Panic("failed to connect to federation sender db")
} }

View file

@ -229,3 +229,18 @@ func (a *FederationSenderInternalAPI) LookupServerKeys(
} }
return ires.([]gomatrixserverlib.ServerKeys), nil return ires.([]gomatrixserverlib.ServerKeys), nil
} }
func (a *FederationSenderInternalAPI) MSC2836EventRelationships(
ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest,
roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.MSC2836EventRelationships(ctx, s, r, roomVersion)
})
if err != nil {
return res, err
}
return ires.(gomatrixserverlib.MSC2836EventRelationshipsResponse), nil
}

View file

@ -33,6 +33,7 @@ const (
FederationSenderGetEventPath = "/federationsender/client/getEvent" FederationSenderGetEventPath = "/federationsender/client/getEvent"
FederationSenderGetServerKeysPath = "/federationsender/client/getServerKeys" FederationSenderGetServerKeysPath = "/federationsender/client/getServerKeys"
FederationSenderLookupServerKeysPath = "/federationsender/client/lookupServerKeys" FederationSenderLookupServerKeysPath = "/federationsender/client/lookupServerKeys"
FederationSenderEventRelationshipsPath = "/federationsender/client/msc2836eventRelationships"
) )
// NewFederationSenderClient creates a FederationSenderInternalAPI implemented by talking to a HTTP POST API. // NewFederationSenderClient creates a FederationSenderInternalAPI implemented by talking to a HTTP POST API.
@ -430,3 +431,35 @@ func (h *httpFederationSenderInternalAPI) LookupServerKeys(
} }
return response.ServerKeys, nil return response.ServerKeys, nil
} }
type eventRelationships struct {
S gomatrixserverlib.ServerName
Req gomatrixserverlib.MSC2836EventRelationshipsRequest
RoomVer gomatrixserverlib.RoomVersion
Res gomatrixserverlib.MSC2836EventRelationshipsResponse
Err *api.FederationClientError
}
func (h *httpFederationSenderInternalAPI) MSC2836EventRelationships(
ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest,
roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2836EventRelationships")
defer span.Finish()
request := eventRelationships{
S: s,
Req: r,
RoomVer: roomVersion,
}
var response eventRelationships
apiURL := h.federationSenderURL + FederationSenderEventRelationshipsPath
err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return res, err
}
if response.Err != nil {
return res, response.Err
}
return response.Res, nil
}

View file

@ -307,4 +307,26 @@ func AddRoutes(intAPI api.FederationSenderInternalAPI, internalAPIMux *mux.Route
return util.JSONResponse{Code: http.StatusOK, JSON: request} return util.JSONResponse{Code: http.StatusOK, JSON: request}
}), }),
) )
internalAPIMux.Handle(
FederationSenderEventRelationshipsPath,
httputil.MakeInternalAPI("MSC2836EventRelationships", func(req *http.Request) util.JSONResponse {
var request eventRelationships
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.MSC2836EventRelationships(req.Context(), request.S, request.Req, request.RoomVer)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
)
} }

View file

@ -35,6 +35,8 @@ import (
const ( const (
maxPDUsPerTransaction = 50 maxPDUsPerTransaction = 50
maxEDUsPerTransaction = 50 maxEDUsPerTransaction = 50
maxPDUsInMemory = 128
maxEDUsInMemory = 128
queueIdleTimeout = time.Second * 30 queueIdleTimeout = time.Second * 30
) )
@ -51,54 +53,56 @@ type destinationQueue struct {
destination gomatrixserverlib.ServerName // destination of requests destination gomatrixserverlib.ServerName // destination of requests
running atomic.Bool // is the queue worker running? running atomic.Bool // is the queue worker running?
backingOff atomic.Bool // true if we're backing off backingOff atomic.Bool // true if we're backing off
overflowed atomic.Bool // the queues exceed maxPDUsInMemory/maxEDUsInMemory, so we should consult the database for more
statistics *statistics.ServerStatistics // statistics about this remote server statistics *statistics.ServerStatistics // statistics about this remote server
transactionIDMutex sync.Mutex // protects transactionID transactionIDMutex sync.Mutex // protects transactionID
transactionID gomatrixserverlib.TransactionID // last transaction ID transactionID gomatrixserverlib.TransactionID // last transaction ID if retrying, or "" if last txn was successful
transactionCount atomic.Int32 // how many events in this transaction so far notify chan struct{} // interrupts idle wait pending PDUs/EDUs
notifyPDUs chan bool // interrupts idle wait for PDUs pendingPDUs []*queuedPDU // PDUs waiting to be sent
notifyEDUs chan bool // interrupts idle wait for EDUs pendingEDUs []*queuedEDU // EDUs waiting to be sent
pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs
interruptBackoff chan bool // interrupts backoff interruptBackoff chan bool // interrupts backoff
} }
// Send event adds the event to the pending queue for the destination. // Send event adds the event to the pending queue for the destination.
// If the queue is empty then it starts a background goroutine to // If the queue is empty then it starts a background goroutine to
// start sending events to that destination. // start sending events to that destination.
func (oq *destinationQueue) sendEvent(receipt *shared.Receipt) { func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, receipt *shared.Receipt) {
// Create a transaction ID. We'll either do this if we don't have if event == nil {
// one made up yet, or if we've exceeded the number of maximum log.Errorf("attempt to send nil PDU with destination %q", oq.destination)
// events allowed in a single tranaction. We'll reset the counter return
// when we do.
oq.transactionIDMutex.Lock()
if oq.transactionID == "" || oq.transactionCount.Load() >= maxPDUsPerTransaction {
now := gomatrixserverlib.AsTimestamp(time.Now())
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
oq.transactionCount.Store(0)
} }
oq.transactionIDMutex.Unlock()
// Create a database entry that associates the given PDU NID with // Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU // this destination queue. We'll then be able to retrieve the PDU
// later. // later.
if err := oq.db.AssociatePDUWithDestination( if err := oq.db.AssociatePDUWithDestination(
context.TODO(), context.TODO(),
oq.transactionID, // the current transaction ID "", // TODO: remove this, as we don't need to persist the transaction ID
oq.destination, // the destination server name oq.destination, // the destination server name
receipt, // NIDs from federationsender_queue_json table receipt, // NIDs from federationsender_queue_json table
); err != nil { ); err != nil {
log.WithError(err).Errorf("failed to associate PDU receipt %q with destination %q", receipt.String(), oq.destination) log.WithError(err).Errorf("failed to associate PDU %q with destination %q", event.EventID(), oq.destination)
return return
} }
// We've successfully added a PDU to the transaction so increase
// the counter.
oq.transactionCount.Add(1)
// Check if the destination is blacklisted. If it isn't then wake // Check if the destination is blacklisted. If it isn't then wake
// up the queue. // up the queue.
if !oq.statistics.Blacklisted() { if !oq.statistics.Blacklisted() {
// If there's room in memory to hold the event then add it to the
// list.
oq.pendingMutex.Lock()
if len(oq.pendingPDUs) < maxPDUsInMemory {
oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{
pdu: event,
receipt: receipt,
})
} else {
oq.overflowed.Store(true)
}
oq.pendingMutex.Unlock()
// Wake up the queue if it's asleep. // Wake up the queue if it's asleep.
oq.wakeQueueIfNeeded() oq.wakeQueueIfNeeded()
// If we're blocking on waiting PDUs then tell the queue that we
// have work to do.
select { select {
case oq.notifyPDUs <- true: case oq.notify <- struct{}{}:
default: default:
} }
} }
@ -107,7 +111,11 @@ func (oq *destinationQueue) sendEvent(receipt *shared.Receipt) {
// sendEDU adds the EDU event to the pending queue for the destination. // sendEDU adds the EDU event to the pending queue for the destination.
// If the queue is empty then it starts a background goroutine to // If the queue is empty then it starts a background goroutine to
// start sending events to that destination. // start sending events to that destination.
func (oq *destinationQueue) sendEDU(receipt *shared.Receipt) { func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *shared.Receipt) {
if event == nil {
log.Errorf("attempt to send nil EDU with destination %q", oq.destination)
return
}
// Create a database entry that associates the given PDU NID with // Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU // this destination queue. We'll then be able to retrieve the PDU
// later. // later.
@ -116,21 +124,28 @@ func (oq *destinationQueue) sendEDU(receipt *shared.Receipt) {
oq.destination, // the destination server name oq.destination, // the destination server name
receipt, // NIDs from federationsender_queue_json table receipt, // NIDs from federationsender_queue_json table
); err != nil { ); err != nil {
log.WithError(err).Errorf("failed to associate EDU receipt %q with destination %q", receipt.String(), oq.destination) log.WithError(err).Errorf("failed to associate EDU with destination %q", oq.destination)
return return
} }
// We've successfully added an EDU to the transaction so increase
// the counter.
oq.transactionCount.Add(1)
// Check if the destination is blacklisted. If it isn't then wake // Check if the destination is blacklisted. If it isn't then wake
// up the queue. // up the queue.
if !oq.statistics.Blacklisted() { if !oq.statistics.Blacklisted() {
// If there's room in memory to hold the event then add it to the
// list.
oq.pendingMutex.Lock()
if len(oq.pendingEDUs) < maxEDUsInMemory {
oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{
edu: event,
receipt: receipt,
})
} else {
oq.overflowed.Store(true)
}
oq.pendingMutex.Unlock()
// Wake up the queue if it's asleep. // Wake up the queue if it's asleep.
oq.wakeQueueIfNeeded() oq.wakeQueueIfNeeded()
// If we're blocking on waiting EDUs then tell the queue that we
// have work to do.
select { select {
case oq.notifyEDUs <- true: case oq.notify <- struct{}{}:
default: default:
} }
} }
@ -152,48 +167,71 @@ func (oq *destinationQueue) wakeQueueIfNeeded() {
} }
} }
// waitForPDUs returns a channel for pending PDUs, which will be // getPendingFromDatabase will look at the database and see if
// used in backgroundSend select. It returns a closed channel if // there are any persisted events that haven't been sent to this
// there is something pending right now, or an open channel if // destination yet. If so, they will be queued up.
// we're waiting for something. // nolint:gocyclo
func (oq *destinationQueue) waitForPDUs() chan bool { func (oq *destinationQueue) getPendingFromDatabase() {
pendingPDUs, err := oq.db.GetPendingPDUCount(context.TODO(), oq.destination) // Check to see if there's anything to do for this server
if err != nil { // in the database.
log.WithError(err).Errorf("Failed to get pending PDU count on queue %q", oq.destination) retrieved := false
} ctx := context.Background()
// If there are PDUs pending right now then we'll return a closed oq.pendingMutex.Lock()
// channel. This will mean that the backgroundSend will not block. defer oq.pendingMutex.Unlock()
if pendingPDUs > 0 {
ch := make(chan bool, 1)
close(ch)
return ch
}
// If there are no PDUs pending right now then instead we'll return
// the notify channel, so that backgroundSend can pick up normal
// notifications from sendEvent.
return oq.notifyPDUs
}
// waitForEDUs returns a channel for pending EDUs, which will be // Take a note of all of the PDUs and EDUs that we already
// used in backgroundSend select. It returns a closed channel if // have cached. We will index them based on the receipt,
// there is something pending right now, or an open channel if // which ultimately just contains the index of the PDU/EDU
// we're waiting for something. // in the database.
func (oq *destinationQueue) waitForEDUs() chan bool { gotPDUs := map[string]struct{}{}
pendingEDUs, err := oq.db.GetPendingEDUCount(context.TODO(), oq.destination) gotEDUs := map[string]struct{}{}
if err != nil { for _, pdu := range oq.pendingPDUs {
log.WithError(err).Errorf("Failed to get pending EDU count on queue %q", oq.destination) gotPDUs[pdu.receipt.String()] = struct{}{}
}
for _, edu := range oq.pendingEDUs {
gotEDUs[edu.receipt.String()] = struct{}{}
}
if pduCapacity := maxPDUsInMemory - len(oq.pendingPDUs); pduCapacity > 0 {
// We have room in memory for some PDUs - let's request no more than that.
if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, pduCapacity); err == nil {
for receipt, pdu := range pdus {
if _, ok := gotPDUs[receipt.String()]; ok {
continue
}
oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{receipt, pdu})
retrieved = true
}
} else {
logrus.WithError(err).Errorf("Failed to get pending PDUs for %q", oq.destination)
}
}
if eduCapacity := maxEDUsInMemory - len(oq.pendingEDUs); eduCapacity > 0 {
// We have room in memory for some EDUs - let's request no more than that.
if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, eduCapacity); err == nil {
for receipt, edu := range edus {
if _, ok := gotEDUs[receipt.String()]; ok {
continue
}
oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{receipt, edu})
retrieved = true
}
} else {
logrus.WithError(err).Errorf("Failed to get pending EDUs for %q", oq.destination)
}
}
// If we've retrieved all of the events from the database with room to spare
// in memory then we'll no longer consider this queue to be overflowed.
if len(oq.pendingPDUs) < maxPDUsInMemory && len(oq.pendingEDUs) < maxEDUsInMemory {
oq.overflowed.Store(false)
}
// If we've retrieved some events then notify the destination queue goroutine.
if retrieved {
select {
case oq.notify <- struct{}{}:
default:
} }
// If there are EDUs pending right now then we'll return a closed
// channel. This will mean that the backgroundSend will not block.
if pendingEDUs > 0 {
ch := make(chan bool, 1)
close(ch)
return ch
} }
// If there are no EDUs pending right now then instead we'll return
// the notify channel, so that backgroundSend can pick up normal
// notifications from sendEvent.
return oq.notifyEDUs
} }
// backgroundSend is the worker goroutine for sending events. // backgroundSend is the worker goroutine for sending events.
@ -206,25 +244,28 @@ func (oq *destinationQueue) backgroundSend() {
} }
defer oq.running.Store(false) defer oq.running.Store(false)
// Mark the queue as overflowed, so we will consult the database
// to see if there's anything new to send.
oq.overflowed.Store(true)
for { for {
pendingPDUs, pendingEDUs := false, false // If we are overflowing memory and have sent things out to the
// database then we can look up what those things are.
if oq.overflowed.Load() {
oq.getPendingFromDatabase()
}
// If we have nothing to do then wait either for incoming events, or // If we have nothing to do then wait either for incoming events, or
// until we hit an idle timeout. // until we hit an idle timeout.
select { select {
case <-oq.waitForPDUs(): case <-oq.notify:
// We were woken up because there are new PDUs waiting in the // There's work to do, either because getPendingFromDatabase
// database. // told us there is, or because a new event has come in via
pendingPDUs = true // sendEvent/sendEDU.
case <-oq.waitForEDUs():
// We were woken up because there are new PDUs waiting in the
// database.
pendingEDUs = true
case <-time.After(queueIdleTimeout): case <-time.After(queueIdleTimeout):
// The worker is idle so stop the goroutine. It'll get // The worker is idle so stop the goroutine. It'll get
// restarted automatically the next time we have an event to // restarted automatically the next time we have an event to
// send. // send.
log.Tracef("Queue %q has been idle for %s, going to sleep", oq.destination, queueIdleTimeout)
return return
} }
@ -237,6 +278,16 @@ func (oq *destinationQueue) backgroundSend() {
// has exceeded a maximum allowable value. Clean up the in-memory // has exceeded a maximum allowable value. Clean up the in-memory
// buffers at this point. The PDU clean-up is already on a defer. // buffers at this point. The PDU clean-up is already on a defer.
log.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination) log.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination)
oq.pendingMutex.Lock()
for i := range oq.pendingPDUs {
oq.pendingPDUs[i] = nil
}
for i := range oq.pendingEDUs {
oq.pendingEDUs[i] = nil
}
oq.pendingPDUs = nil
oq.pendingEDUs = nil
oq.pendingMutex.Unlock()
return return
} }
if until != nil && until.After(time.Now()) { if until != nil && until.After(time.Now()) {
@ -250,18 +301,41 @@ func (oq *destinationQueue) backgroundSend() {
} }
} }
// Work out which PDUs/EDUs to include in the next transaction.
oq.pendingMutex.RLock()
pduCount := len(oq.pendingPDUs)
eduCount := len(oq.pendingEDUs)
if pduCount > maxPDUsPerTransaction {
pduCount = maxPDUsPerTransaction
}
if eduCount > maxEDUsPerTransaction {
eduCount = maxEDUsPerTransaction
}
toSendPDUs := oq.pendingPDUs[:pduCount]
toSendEDUs := oq.pendingEDUs[:eduCount]
oq.pendingMutex.RUnlock()
// If we have pending PDUs or EDUs then construct a transaction. // If we have pending PDUs or EDUs then construct a transaction.
if pendingPDUs || pendingEDUs {
// Try sending the next transaction and see what happens. // Try sending the next transaction and see what happens.
transaction, terr := oq.nextTransaction() transaction, pc, ec, terr := oq.nextTransaction(toSendPDUs, toSendEDUs)
if terr != nil { if terr != nil {
// We failed to send the transaction. Mark it as a failure. // We failed to send the transaction. Mark it as a failure.
oq.statistics.Failure() oq.statistics.Failure()
} else if transaction { } else if transaction {
// If we successfully sent the transaction then clear out // If we successfully sent the transaction then clear out
// the pending events and EDUs, and wipe our transaction ID. // the pending events and EDUs, and wipe our transaction ID.
oq.statistics.Success() oq.statistics.Success()
oq.pendingMutex.Lock()
for i := range oq.pendingPDUs[:pc] {
oq.pendingPDUs[i] = nil
} }
for i := range oq.pendingEDUs[:ec] {
oq.pendingEDUs[i] = nil
}
oq.pendingPDUs = oq.pendingPDUs[pc:]
oq.pendingEDUs = oq.pendingEDUs[ec:]
oq.pendingMutex.Unlock()
} }
} }
} }
@ -270,16 +344,20 @@ func (oq *destinationQueue) backgroundSend() {
// queue and sends it. Returns true if a transaction was sent or // queue and sends it. Returns true if a transaction was sent or
// false otherwise. // false otherwise.
// nolint:gocyclo // nolint:gocyclo
func (oq *destinationQueue) nextTransaction() (bool, error) { func (oq *destinationQueue) nextTransaction(
// Before we do anything, we need to roll over the transaction pdus []*queuedPDU,
// ID that is being used to coalesce events into the next TX. edus []*queuedEDU,
// Otherwise it's possible that we'll pick up an incomplete ) (bool, int, int, error) {
// transaction and end up nuking the rest of the events at the // If there's no projected transaction ID then generate one. If
// cleanup stage. // the transaction succeeds then we'll set it back to "" so that
// we generate a new one next time. If it fails, we'll preserve
// it so that we retry with the same transaction ID.
oq.transactionIDMutex.Lock() oq.transactionIDMutex.Lock()
oq.transactionID = "" if oq.transactionID == "" {
now := gomatrixserverlib.AsTimestamp(time.Now())
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
}
oq.transactionIDMutex.Unlock() oq.transactionIDMutex.Unlock()
oq.transactionCount.Store(0)
// Create the transaction. // Create the transaction.
t := gomatrixserverlib.Transaction{ t := gomatrixserverlib.Transaction{
@ -289,58 +367,36 @@ func (oq *destinationQueue) nextTransaction() (bool, error) {
t.Origin = oq.origin t.Origin = oq.origin
t.Destination = oq.destination t.Destination = oq.destination
t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now()) t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now())
t.TransactionID = oq.transactionID
// Ask the database for any pending PDUs from the next transaction.
// maxPDUsPerTransaction is an upper limit but we probably won't
// actually retrieve that many events.
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
txid, pdus, pduReceipt, err := oq.db.GetNextTransactionPDUs(
ctx, // context
oq.destination, // server name
maxPDUsPerTransaction, // max events to retrieve
)
if err != nil {
log.WithError(err).Errorf("failed to get next transaction PDUs for server %q", oq.destination)
return false, fmt.Errorf("oq.db.GetNextTransactionPDUs: %w", err)
}
edus, eduReceipt, err := oq.db.GetNextTransactionEDUs(
ctx, // context
oq.destination, // server name
maxEDUsPerTransaction, // max events to retrieve
)
if err != nil {
log.WithError(err).Errorf("failed to get next transaction EDUs for server %q", oq.destination)
return false, fmt.Errorf("oq.db.GetNextTransactionEDUs: %w", err)
}
// If we didn't get anything from the database and there are no // If we didn't get anything from the database and there are no
// pending EDUs then there's nothing to do - stop here. // pending EDUs then there's nothing to do - stop here.
if len(pdus) == 0 && len(edus) == 0 { if len(pdus) == 0 && len(edus) == 0 {
return false, nil return false, 0, 0, nil
} }
// Pick out the transaction ID from the database. If we didn't var pduReceipts []*shared.Receipt
// get a transaction ID (i.e. because there are no PDUs but only var eduReceipts []*shared.Receipt
// EDUs) then generate a transaction ID.
t.TransactionID = txid
if t.TransactionID == "" {
now := gomatrixserverlib.AsTimestamp(time.Now())
t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
}
// Go through PDUs that we retrieved from the database, if any, // Go through PDUs that we retrieved from the database, if any,
// and add them into the transaction. // and add them into the transaction.
for _, pdu := range pdus { for _, pdu := range pdus {
if pdu == nil || pdu.pdu == nil {
continue
}
// Append the JSON of the event, since this is a json.RawMessage type in the // Append the JSON of the event, since this is a json.RawMessage type in the
// gomatrixserverlib.Transaction struct // gomatrixserverlib.Transaction struct
t.PDUs = append(t.PDUs, (*pdu).JSON()) t.PDUs = append(t.PDUs, pdu.pdu.JSON())
pduReceipts = append(pduReceipts, pdu.receipt)
} }
// Do the same for pending EDUS in the queue. // Do the same for pending EDUS in the queue.
for _, edu := range edus { for _, edu := range edus {
t.EDUs = append(t.EDUs, *edu) if edu == nil || edu.edu == nil {
continue
}
t.EDUs = append(t.EDUs, *edu.edu)
eduReceipts = append(eduReceipts, edu.receipt)
} }
logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs)) logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs))
@ -349,34 +405,38 @@ func (oq *destinationQueue) nextTransaction() (bool, error) {
// TODO: we should check for 500-ish fails vs 400-ish here, // TODO: we should check for 500-ish fails vs 400-ish here,
// since we shouldn't queue things indefinitely in response // since we shouldn't queue things indefinitely in response
// to a 400-ish error // to a 400-ish error
ctx, cancel = context.WithTimeout(context.Background(), time.Minute*5) ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
defer cancel() defer cancel()
_, err = oq.client.SendTransaction(ctx, t) _, err := oq.client.SendTransaction(ctx, t)
switch err.(type) { switch err.(type) {
case nil: case nil:
// Clean up the transaction in the database. // Clean up the transaction in the database.
if pduReceipt != nil { if pduReceipts != nil {
//logrus.Infof("Cleaning PDUs %q", pduReceipt.String()) //logrus.Infof("Cleaning PDUs %q", pduReceipt.String())
if err = oq.db.CleanPDUs(context.Background(), oq.destination, pduReceipt); err != nil { if err = oq.db.CleanPDUs(context.Background(), oq.destination, pduReceipts); err != nil {
log.WithError(err).Errorf("failed to clean PDUs %q for server %q", pduReceipt.String(), t.Destination) log.WithError(err).Errorf("Failed to clean PDUs for server %q", t.Destination)
} }
} }
if eduReceipt != nil { if eduReceipts != nil {
//logrus.Infof("Cleaning EDUs %q", eduReceipt.String()) //logrus.Infof("Cleaning EDUs %q", eduReceipt.String())
if err = oq.db.CleanEDUs(context.Background(), oq.destination, eduReceipt); err != nil { if err = oq.db.CleanEDUs(context.Background(), oq.destination, eduReceipts); err != nil {
log.WithError(err).Errorf("failed to clean EDUs %q for server %q", eduReceipt.String(), t.Destination) log.WithError(err).Errorf("Failed to clean EDUs for server %q", t.Destination)
} }
} }
return true, nil // Reset the transaction ID.
oq.transactionIDMutex.Lock()
oq.transactionID = ""
oq.transactionIDMutex.Unlock()
return true, len(t.PDUs), len(t.EDUs), nil
case gomatrix.HTTPError: case gomatrix.HTTPError:
// Report that we failed to send the transaction and we // Report that we failed to send the transaction and we
// will retry again, subject to backoff. // will retry again, subject to backoff.
return false, err return false, 0, 0, err
default: default:
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"destination": oq.destination, "destination": oq.destination,
log.ErrorKey: err, log.ErrorKey: err,
}).Info("problem sending transaction") }).Infof("Failed to send transaction %q", t.TransactionID)
return false, err return false, 0, 0, err
} }
} }

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/federationsender/statistics" "github.com/matrix-org/dendrite/federationsender/statistics"
"github.com/matrix-org/dendrite/federationsender/storage" "github.com/matrix-org/dendrite/federationsender/storage"
"github.com/matrix-org/dendrite/federationsender/storage/shared"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -83,8 +84,8 @@ func NewOutgoingQueues(
log.WithError(err).Error("Failed to get EDU server names for destination queue hydration") log.WithError(err).Error("Failed to get EDU server names for destination queue hydration")
} }
for serverName := range serverNames { for serverName := range serverNames {
if !queues.getQueue(serverName).statistics.Blacklisted() { if queue := queues.getQueue(serverName); !queue.statistics.Blacklisted() {
queues.getQueue(serverName).wakeQueueIfNeeded() queue.wakeQueueIfNeeded()
} }
} }
}) })
@ -100,6 +101,16 @@ type SigningInfo struct {
PrivateKey ed25519.PrivateKey PrivateKey ed25519.PrivateKey
} }
type queuedPDU struct {
receipt *shared.Receipt
pdu *gomatrixserverlib.HeaderedEvent
}
type queuedEDU struct {
receipt *shared.Receipt
edu *gomatrixserverlib.EDU
}
func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *destinationQueue { func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *destinationQueue {
oqs.queuesMutex.Lock() oqs.queuesMutex.Lock()
defer oqs.queuesMutex.Unlock() defer oqs.queuesMutex.Unlock()
@ -112,8 +123,7 @@ func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *d
destination: destination, destination: destination,
client: oqs.client, client: oqs.client,
statistics: oqs.statistics.ForServer(destination), statistics: oqs.statistics.ForServer(destination),
notifyPDUs: make(chan bool, 1), notify: make(chan struct{}, 1),
notifyEDUs: make(chan bool, 1),
interruptBackoff: make(chan bool), interruptBackoff: make(chan bool),
signing: oqs.signing, signing: oqs.signing,
} }
@ -122,13 +132,23 @@ func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *d
return oq return oq
} }
type ErrorFederationDisabled struct {
Message string
}
func (e *ErrorFederationDisabled) Error() string {
return e.Message
}
// SendEvent sends an event to the destinations // SendEvent sends an event to the destinations
func (oqs *OutgoingQueues) SendEvent( func (oqs *OutgoingQueues) SendEvent(
ev *gomatrixserverlib.HeaderedEvent, origin gomatrixserverlib.ServerName, ev *gomatrixserverlib.HeaderedEvent, origin gomatrixserverlib.ServerName,
destinations []gomatrixserverlib.ServerName, destinations []gomatrixserverlib.ServerName,
) error { ) error {
if oqs.disabled { if oqs.disabled {
return fmt.Errorf("federation is disabled") return &ErrorFederationDisabled{
Message: "Federation disabled",
}
} }
if origin != oqs.origin { if origin != oqs.origin {
// TODO: Support virtual hosting; gh issue #577. // TODO: Support virtual hosting; gh issue #577.
@ -178,7 +198,7 @@ func (oqs *OutgoingQueues) SendEvent(
} }
for destination := range destmap { for destination := range destmap {
oqs.getQueue(destination).sendEvent(nid) oqs.getQueue(destination).sendEvent(ev, nid)
} }
return nil return nil
@ -190,7 +210,9 @@ func (oqs *OutgoingQueues) SendEDU(
destinations []gomatrixserverlib.ServerName, destinations []gomatrixserverlib.ServerName,
) error { ) error {
if oqs.disabled { if oqs.disabled {
return fmt.Errorf("federation is disabled") return &ErrorFederationDisabled{
Message: "Federation disabled",
}
} }
if origin != oqs.origin { if origin != oqs.origin {
// TODO: Support virtual hosting; gh issue #577. // TODO: Support virtual hosting; gh issue #577.
@ -246,7 +268,7 @@ func (oqs *OutgoingQueues) SendEDU(
} }
for destination := range destmap { for destination := range destmap {
oqs.getQueue(destination).sendEDU(nid) oqs.getQueue(destination).sendEDU(e, nid)
} }
return nil return nil

View file

@ -36,14 +36,14 @@ type Database interface {
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)
GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error)
GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error)
AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error
AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error
GetNextTransactionPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (gomatrixserverlib.TransactionID, []*gomatrixserverlib.HeaderedEvent, *shared.Receipt, error) CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
GetNextTransactionEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) ([]*gomatrixserverlib.EDU, *shared.Receipt, error) CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error
GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)

View file

@ -45,16 +45,10 @@ const insertQueuePDUSQL = "" +
const deleteQueuePDUSQL = "" + const deleteQueuePDUSQL = "" +
"DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid = ANY($2)" "DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid = ANY($2)"
const selectQueuePDUNextTransactionIDSQL = "" + const selectQueuePDUsSQL = "" +
"SELECT transaction_id FROM federationsender_queue_pdus" +
" WHERE server_name = $1" +
" ORDER BY transaction_id ASC" +
" LIMIT 1"
const selectQueuePDUsByTransactionSQL = "" +
"SELECT json_nid FROM federationsender_queue_pdus" + "SELECT json_nid FROM federationsender_queue_pdus" +
" WHERE server_name = $1 AND transaction_id = $2" + " WHERE server_name = $1" +
" LIMIT $3" " LIMIT $2"
const selectQueuePDUReferenceJSONCountSQL = "" + const selectQueuePDUReferenceJSONCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_pdus" + "SELECT COUNT(*) FROM federationsender_queue_pdus" +
@ -71,8 +65,7 @@ type queuePDUsStatements struct {
db *sql.DB db *sql.DB
insertQueuePDUStmt *sql.Stmt insertQueuePDUStmt *sql.Stmt
deleteQueuePDUsStmt *sql.Stmt deleteQueuePDUsStmt *sql.Stmt
selectQueuePDUNextTransactionIDStmt *sql.Stmt selectQueuePDUsStmt *sql.Stmt
selectQueuePDUsByTransactionStmt *sql.Stmt
selectQueuePDUReferenceJSONCountStmt *sql.Stmt selectQueuePDUReferenceJSONCountStmt *sql.Stmt
selectQueuePDUsCountStmt *sql.Stmt selectQueuePDUsCountStmt *sql.Stmt
selectQueuePDUServerNamesStmt *sql.Stmt selectQueuePDUServerNamesStmt *sql.Stmt
@ -92,10 +85,7 @@ func NewPostgresQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
if s.deleteQueuePDUsStmt, err = s.db.Prepare(deleteQueuePDUSQL); err != nil { if s.deleteQueuePDUsStmt, err = s.db.Prepare(deleteQueuePDUSQL); err != nil {
return return
} }
if s.selectQueuePDUNextTransactionIDStmt, err = s.db.Prepare(selectQueuePDUNextTransactionIDSQL); err != nil { if s.selectQueuePDUsStmt, err = s.db.Prepare(selectQueuePDUsSQL); err != nil {
return
}
if s.selectQueuePDUsByTransactionStmt, err = s.db.Prepare(selectQueuePDUsByTransactionSQL); err != nil {
return return
} }
if s.selectQueuePDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueuePDUReferenceJSONCountSQL); err != nil { if s.selectQueuePDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueuePDUReferenceJSONCountSQL); err != nil {
@ -137,18 +127,6 @@ func (s *queuePDUsStatements) DeleteQueuePDUs(
return err return err
} }
func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (gomatrixserverlib.TransactionID, error) {
var transactionID gomatrixserverlib.TransactionID
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUNextTransactionIDStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&transactionID)
if err == sql.ErrNoRows {
return "", nil
}
return transactionID, err
}
func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount( func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount(
ctx context.Context, txn *sql.Tx, jsonNID int64, ctx context.Context, txn *sql.Tx, jsonNID int64,
) (int64, error) { ) (int64, error) {
@ -182,11 +160,10 @@ func (s *queuePDUsStatements) SelectQueuePDUCount(
func (s *queuePDUsStatements) SelectQueuePDUs( func (s *queuePDUsStatements) SelectQueuePDUs(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
transactionID gomatrixserverlib.TransactionID,
limit int, limit int,
) ([]int64, error) { ) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsByTransactionStmt) stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsStmt)
rows, err := stmt.QueryContext(ctx, serverName, transactionID, limit) rows, err := stmt.QueryContext(ctx, serverName, limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -19,6 +19,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/federationsender/storage/shared" "github.com/matrix-org/dendrite/federationsender/storage/shared"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
) )
@ -32,7 +33,7 @@ type Database struct {
} }
// NewDatabase opens a new database // NewDatabase opens a new database
func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (*Database, error) {
var d Database var d Database
var err error var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil { if d.db, err = sqlutil.Open(dbProperties); err != nil {
@ -73,6 +74,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Cache: cache,
Writer: d.writer, Writer: d.writer,
FederationSenderJoinedHosts: joinedHosts, FederationSenderJoinedHosts: joinedHosts,
FederationSenderQueuePDUs: queuePDUs, FederationSenderQueuePDUs: queuePDUs,

View file

@ -17,17 +17,18 @@ package shared
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"fmt" "fmt"
"github.com/matrix-org/dendrite/federationsender/storage/tables" "github.com/matrix-org/dendrite/federationsender/storage/tables"
"github.com/matrix-org/dendrite/federationsender/types" "github.com/matrix-org/dendrite/federationsender/types"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
type Database struct { type Database struct {
DB *sql.DB DB *sql.DB
Cache caching.FederationSenderCache
Writer sqlutil.Writer Writer sqlutil.Writer
FederationSenderQueuePDUs tables.FederationSenderQueuePDUs FederationSenderQueuePDUs tables.FederationSenderQueuePDUs
FederationSenderQueueEDUs tables.FederationSenderQueueEDUs FederationSenderQueueEDUs tables.FederationSenderQueueEDUs
@ -44,16 +45,11 @@ type Database struct {
// to pass them back so that we can clean up if the transaction sends // to pass them back so that we can clean up if the transaction sends
// successfully. // successfully.
type Receipt struct { type Receipt struct {
nids []int64 nid int64
} }
func (e *Receipt) Empty() bool { func (r *Receipt) String() string {
return len(e.nids) == 0 return fmt.Sprintf("%d", r.nid)
}
func (e *Receipt) String() string {
j, _ := json.Marshal(e.nids)
return string(j)
} }
// UpdateRoom updates the joined hosts for a room and returns what the joined // UpdateRoom updates the joined hosts for a room and returns what the joined
@ -146,7 +142,7 @@ func (d *Database) StoreJSON(
return nil, fmt.Errorf("d.insertQueueJSON: %w", err) return nil, fmt.Errorf("d.insertQueueJSON: %w", err)
} }
return &Receipt{ return &Receipt{
nids: []int64{nid}, nid: nid,
}, nil }, nil
} }

View file

@ -33,53 +33,56 @@ func (d *Database) AssociateEDUWithDestination(
receipt *Receipt, receipt *Receipt,
) error { ) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for _, nid := range receipt.nids {
if err := d.FederationSenderQueueEDUs.InsertQueueEDU( if err := d.FederationSenderQueueEDUs.InsertQueueEDU(
ctx, // context ctx, // context
txn, // SQL transaction txn, // SQL transaction
"", // TODO: EDU type for coalescing "", // TODO: EDU type for coalescing
serverName, // destination server name serverName, // destination server name
nid, // NID from the federationsender_queue_json table receipt.nid, // NID from the federationsender_queue_json table
); err != nil { ); err != nil {
return fmt.Errorf("InsertQueueEDU: %w", err) return fmt.Errorf("InsertQueueEDU: %w", err)
} }
}
return nil return nil
}) })
} }
// GetNextTransactionEDUs retrieves events from the database for // GetNextTransactionEDUs retrieves events from the database for
// the next pending transaction, up to the limit specified. // the next pending transaction, up to the limit specified.
func (d *Database) GetNextTransactionEDUs( func (d *Database) GetPendingEDUs(
ctx context.Context, ctx context.Context,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
limit int, limit int,
) ( ) (
edus []*gomatrixserverlib.EDU, edus map[*Receipt]*gomatrixserverlib.EDU,
receipt *Receipt,
err error, err error,
) { ) {
edus = make(map[*Receipt]*gomatrixserverlib.EDU)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
nids, err := d.FederationSenderQueueEDUs.SelectQueueEDUs(ctx, txn, serverName, limit) nids, err := d.FederationSenderQueueEDUs.SelectQueueEDUs(ctx, txn, serverName, limit)
if err != nil { if err != nil {
return fmt.Errorf("SelectQueueEDUs: %w", err) return fmt.Errorf("SelectQueueEDUs: %w", err)
} }
receipt = &Receipt{ retrieve := make([]int64, 0, len(nids))
nids: nids, for _, nid := range nids {
if edu, ok := d.Cache.GetFederationSenderQueuedEDU(nid); ok {
edus[&Receipt{nid}] = edu
} else {
retrieve = append(retrieve, nid)
}
} }
blobs, err := d.FederationSenderQueueJSON.SelectQueueJSON(ctx, txn, nids) blobs, err := d.FederationSenderQueueJSON.SelectQueueJSON(ctx, txn, retrieve)
if err != nil { if err != nil {
return fmt.Errorf("SelectQueueJSON: %w", err) return fmt.Errorf("SelectQueueJSON: %w", err)
} }
for _, blob := range blobs { for nid, blob := range blobs {
var event gomatrixserverlib.EDU var event gomatrixserverlib.EDU
if err := json.Unmarshal(blob, &event); err != nil { if err := json.Unmarshal(blob, &event); err != nil {
return fmt.Errorf("json.Unmarshal: %w", err) return fmt.Errorf("json.Unmarshal: %w", err)
} }
edus = append(edus, &event) edus[&Receipt{nid}] = &event
} }
return nil return nil
@ -92,25 +95,31 @@ func (d *Database) GetNextTransactionEDUs(
func (d *Database) CleanEDUs( func (d *Database) CleanEDUs(
ctx context.Context, ctx context.Context,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
receipt *Receipt, receipts []*Receipt,
) error { ) error {
if receipt == nil { if len(receipts) == 0 {
return errors.New("expected receipt") return errors.New("expected receipt")
} }
nids := make([]int64, len(receipts))
for i := range receipts {
nids[i] = receipts[i].nid
}
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.FederationSenderQueueEDUs.DeleteQueueEDUs(ctx, txn, serverName, receipt.nids); err != nil { if err := d.FederationSenderQueueEDUs.DeleteQueueEDUs(ctx, txn, serverName, nids); err != nil {
return err return err
} }
var deleteNIDs []int64 var deleteNIDs []int64
for _, nid := range receipt.nids { for _, nid := range nids {
count, err := d.FederationSenderQueueEDUs.SelectQueueEDUReferenceJSONCount(ctx, txn, nid) count, err := d.FederationSenderQueueEDUs.SelectQueueEDUReferenceJSONCount(ctx, txn, nid)
if err != nil { if err != nil {
return fmt.Errorf("SelectQueueEDUReferenceJSONCount: %w", err) return fmt.Errorf("SelectQueueEDUReferenceJSONCount: %w", err)
} }
if count == 0 { if count == 0 {
deleteNIDs = append(deleteNIDs, nid) deleteNIDs = append(deleteNIDs, nid)
d.Cache.EvictFederationSenderQueuedEDU(nid)
} }
} }

View file

@ -34,31 +34,27 @@ func (d *Database) AssociatePDUWithDestination(
receipt *Receipt, receipt *Receipt,
) error { ) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for _, nid := range receipt.nids {
if err := d.FederationSenderQueuePDUs.InsertQueuePDU( if err := d.FederationSenderQueuePDUs.InsertQueuePDU(
ctx, // context ctx, // context
txn, // SQL transaction txn, // SQL transaction
transactionID, // transaction ID transactionID, // transaction ID
serverName, // destination server name serverName, // destination server name
nid, // NID from the federationsender_queue_json table receipt.nid, // NID from the federationsender_queue_json table
); err != nil { ); err != nil {
return fmt.Errorf("InsertQueuePDU: %w", err) return fmt.Errorf("InsertQueuePDU: %w", err)
} }
}
return nil return nil
}) })
} }
// GetNextTransactionPDUs retrieves events from the database for // GetNextTransactionPDUs retrieves events from the database for
// the next pending transaction, up to the limit specified. // the next pending transaction, up to the limit specified.
func (d *Database) GetNextTransactionPDUs( func (d *Database) GetPendingPDUs(
ctx context.Context, ctx context.Context,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
limit int, limit int,
) ( ) (
transactionID gomatrixserverlib.TransactionID, events map[*Receipt]*gomatrixserverlib.HeaderedEvent,
events []*gomatrixserverlib.HeaderedEvent,
receipt *Receipt,
err error, err error,
) { ) {
// Strictly speaking this doesn't need to be using the writer // Strictly speaking this doesn't need to be using the writer
@ -66,36 +62,34 @@ func (d *Database) GetNextTransactionPDUs(
// a guarantee of transactional isolation, it's actually useful // a guarantee of transactional isolation, it's actually useful
// to know in SQLite mode that nothing else is trying to modify // to know in SQLite mode that nothing else is trying to modify
// the database. // the database.
events = make(map[*Receipt]*gomatrixserverlib.HeaderedEvent)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
transactionID, err = d.FederationSenderQueuePDUs.SelectQueuePDUNextTransactionID(ctx, txn, serverName) nids, err := d.FederationSenderQueuePDUs.SelectQueuePDUs(ctx, txn, serverName, limit)
if err != nil {
return fmt.Errorf("SelectQueuePDUNextTransactionID: %w", err)
}
if transactionID == "" {
return nil
}
nids, err := d.FederationSenderQueuePDUs.SelectQueuePDUs(ctx, txn, serverName, transactionID, limit)
if err != nil { if err != nil {
return fmt.Errorf("SelectQueuePDUs: %w", err) return fmt.Errorf("SelectQueuePDUs: %w", err)
} }
receipt = &Receipt{ retrieve := make([]int64, 0, len(nids))
nids: nids, for _, nid := range nids {
if event, ok := d.Cache.GetFederationSenderQueuedPDU(nid); ok {
events[&Receipt{nid}] = event
} else {
retrieve = append(retrieve, nid)
}
} }
blobs, err := d.FederationSenderQueueJSON.SelectQueueJSON(ctx, txn, nids) blobs, err := d.FederationSenderQueueJSON.SelectQueueJSON(ctx, txn, retrieve)
if err != nil { if err != nil {
return fmt.Errorf("SelectQueueJSON: %w", err) return fmt.Errorf("SelectQueueJSON: %w", err)
} }
for _, blob := range blobs { for nid, blob := range blobs {
var event gomatrixserverlib.HeaderedEvent var event gomatrixserverlib.HeaderedEvent
if err := json.Unmarshal(blob, &event); err != nil { if err := json.Unmarshal(blob, &event); err != nil {
return fmt.Errorf("json.Unmarshal: %w", err) return fmt.Errorf("json.Unmarshal: %w", err)
} }
events = append(events, &event) events[&Receipt{nid}] = &event
d.Cache.StoreFederationSenderQueuedPDU(nid, &event)
} }
return nil return nil
@ -109,25 +103,31 @@ func (d *Database) GetNextTransactionPDUs(
func (d *Database) CleanPDUs( func (d *Database) CleanPDUs(
ctx context.Context, ctx context.Context,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
receipt *Receipt, receipts []*Receipt,
) error { ) error {
if receipt == nil { if len(receipts) == 0 {
return errors.New("expected receipt") return errors.New("expected receipt")
} }
nids := make([]int64, len(receipts))
for i := range receipts {
nids[i] = receipts[i].nid
}
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.FederationSenderQueuePDUs.DeleteQueuePDUs(ctx, txn, serverName, receipt.nids); err != nil { if err := d.FederationSenderQueuePDUs.DeleteQueuePDUs(ctx, txn, serverName, nids); err != nil {
return err return err
} }
var deleteNIDs []int64 var deleteNIDs []int64
for _, nid := range receipt.nids { for _, nid := range nids {
count, err := d.FederationSenderQueuePDUs.SelectQueuePDUReferenceJSONCount(ctx, txn, nid) count, err := d.FederationSenderQueuePDUs.SelectQueuePDUReferenceJSONCount(ctx, txn, nid)
if err != nil { if err != nil {
return fmt.Errorf("SelectQueuePDUReferenceJSONCount: %w", err) return fmt.Errorf("SelectQueuePDUReferenceJSONCount: %w", err)
} }
if count == 0 { if count == 0 {
deleteNIDs = append(deleteNIDs, nid) deleteNIDs = append(deleteNIDs, nid)
d.Cache.EvictFederationSenderQueuedPDU(nid)
} }
} }

View file

@ -53,10 +53,10 @@ const selectQueueNextTransactionIDSQL = "" +
" ORDER BY transaction_id ASC" + " ORDER BY transaction_id ASC" +
" LIMIT 1" " LIMIT 1"
const selectQueuePDUsByTransactionSQL = "" + const selectQueuePDUsSQL = "" +
"SELECT json_nid FROM federationsender_queue_pdus" + "SELECT json_nid FROM federationsender_queue_pdus" +
" WHERE server_name = $1 AND transaction_id = $2" + " WHERE server_name = $1" +
" LIMIT $3" " LIMIT $2"
const selectQueuePDUsReferenceJSONCountSQL = "" + const selectQueuePDUsReferenceJSONCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_pdus" + "SELECT COUNT(*) FROM federationsender_queue_pdus" +
@ -73,7 +73,7 @@ type queuePDUsStatements struct {
db *sql.DB db *sql.DB
insertQueuePDUStmt *sql.Stmt insertQueuePDUStmt *sql.Stmt
selectQueueNextTransactionIDStmt *sql.Stmt selectQueueNextTransactionIDStmt *sql.Stmt
selectQueuePDUsByTransactionStmt *sql.Stmt selectQueuePDUsStmt *sql.Stmt
selectQueueReferenceJSONCountStmt *sql.Stmt selectQueueReferenceJSONCountStmt *sql.Stmt
selectQueuePDUsCountStmt *sql.Stmt selectQueuePDUsCountStmt *sql.Stmt
selectQueueServerNamesStmt *sql.Stmt selectQueueServerNamesStmt *sql.Stmt
@ -97,7 +97,7 @@ func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
if s.selectQueueNextTransactionIDStmt, err = db.Prepare(selectQueueNextTransactionIDSQL); err != nil { if s.selectQueueNextTransactionIDStmt, err = db.Prepare(selectQueueNextTransactionIDSQL); err != nil {
return return
} }
if s.selectQueuePDUsByTransactionStmt, err = db.Prepare(selectQueuePDUsByTransactionSQL); err != nil { if s.selectQueuePDUsStmt, err = db.Prepare(selectQueuePDUsSQL); err != nil {
return return
} }
if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil { if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil {
@ -193,11 +193,10 @@ func (s *queuePDUsStatements) SelectQueuePDUCount(
func (s *queuePDUsStatements) SelectQueuePDUs( func (s *queuePDUsStatements) SelectQueuePDUs(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
transactionID gomatrixserverlib.TransactionID,
limit int, limit int,
) ([]int64, error) { ) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsByTransactionStmt) stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsStmt)
rows, err := stmt.QueryContext(ctx, serverName, transactionID, limit) rows, err := stmt.QueryContext(ctx, serverName, limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -21,6 +21,7 @@ import (
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"github.com/matrix-org/dendrite/federationsender/storage/shared" "github.com/matrix-org/dendrite/federationsender/storage/shared"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
) )
@ -34,7 +35,7 @@ type Database struct {
} }
// NewDatabase opens a new database // NewDatabase opens a new database
func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (*Database, error) {
var d Database var d Database
var err error var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil { if d.db, err = sqlutil.Open(dbProperties); err != nil {
@ -75,6 +76,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Cache: cache,
Writer: d.writer, Writer: d.writer,
FederationSenderJoinedHosts: joinedHosts, FederationSenderJoinedHosts: joinedHosts,
FederationSenderQueuePDUs: queuePDUs, FederationSenderQueuePDUs: queuePDUs,

View file

@ -21,16 +21,17 @@ import (
"github.com/matrix-org/dendrite/federationsender/storage/postgres" "github.com/matrix-org/dendrite/federationsender/storage/postgres"
"github.com/matrix-org/dendrite/federationsender/storage/sqlite3" "github.com/matrix-org/dendrite/federationsender/storage/sqlite3"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
) )
// NewDatabase opens a new database // NewDatabase opens a new database
func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties) return sqlite3.NewDatabase(dbProperties, cache)
case dbProperties.ConnectionString.IsPostgres(): case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(dbProperties) return postgres.NewDatabase(dbProperties, cache)
default: default:
return nil, fmt.Errorf("unexpected database type") return nil, fmt.Errorf("unexpected database type")
} }

View file

@ -18,14 +18,15 @@ import (
"fmt" "fmt"
"github.com/matrix-org/dendrite/federationsender/storage/sqlite3" "github.com/matrix-org/dendrite/federationsender/storage/sqlite3"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
) )
// NewDatabase opens a new database // NewDatabase opens a new database
func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties) return sqlite3.NewDatabase(dbProperties, cache)
case dbProperties.ConnectionString.IsPostgres(): case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation") return nil, fmt.Errorf("can't use Postgres implementation")
default: default:

View file

@ -25,10 +25,9 @@ import (
type FederationSenderQueuePDUs interface { type FederationSenderQueuePDUs interface {
InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error
DeleteQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error DeleteQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
SelectQueuePDUNextTransactionID(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (gomatrixserverlib.TransactionID, error)
SelectQueuePDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error) SelectQueuePDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error)
SelectQueuePDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) SelectQueuePDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID, limit int) ([]int64, error) SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error) SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error)
} }

10
go.mod
View file

@ -22,7 +22,7 @@ require (
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd
github.com/matrix-org/gomatrixserverlib v0.0.0-20201202134418-2ba106a5bca3 github.com/matrix-org/gomatrixserverlib v0.0.0-20201209172200-eb6a8903f9fb
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.2 github.com/mattn/go-sqlite3 v1.14.2
@ -32,15 +32,17 @@ require (
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/pressly/goose v2.7.0-rc5+incompatible github.com/pressly/goose v2.7.0-rc5+incompatible
github.com/prometheus/client_golang v1.7.1 github.com/prometheus/client_golang v1.7.1
github.com/sirupsen/logrus v1.6.0 github.com/sirupsen/logrus v1.7.0
github.com/tidwall/gjson v1.6.3 github.com/tidwall/gjson v1.6.3
github.com/tidwall/sjson v1.1.1 github.com/tidwall/match v1.0.2 // indirect
github.com/tidwall/sjson v1.1.2
github.com/uber/jaeger-client-go v2.25.0+incompatible github.com/uber/jaeger-client-go v2.25.0+incompatible
github.com/uber/jaeger-lib v2.2.0+incompatible github.com/uber/jaeger-lib v2.2.0+incompatible
github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20201006093556-760d9a7fd5ee github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20201006093556-760d9a7fd5ee
go.uber.org/atomic v1.6.0 go.uber.org/atomic v1.6.0
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9
golang.org/x/net v0.0.0-20200528225125-3c3fba18258b golang.org/x/net v0.0.0-20200528225125-3c3fba18258b
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 // indirect
gopkg.in/h2non/bimg.v1 v1.1.4 gopkg.in/h2non/bimg.v1 v1.1.4
gopkg.in/yaml.v2 v2.3.0 gopkg.in/yaml.v2 v2.3.0
) )

25
go.sum
View file

@ -301,8 +301,6 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3
github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.3 h1:CE8S1cTafDpPvMhIxNJKvHsGVBgn1xWYf1NbHQhywc8=
github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/koron/go-ssdp v0.0.0-20191105050749-2e1c40ed0b5d h1:68u9r4wEvL3gYg2jvAOgROwZ3H+Y3hIDk4tbbmIjcYQ= github.com/koron/go-ssdp v0.0.0-20191105050749-2e1c40ed0b5d h1:68u9r4wEvL3gYg2jvAOgROwZ3H+Y3hIDk4tbbmIjcYQ=
github.com/koron/go-ssdp v0.0.0-20191105050749-2e1c40ed0b5d/go.mod h1:5Ky9EC2xfoUKUor0Hjgi2BJhCSXJfMOFlmyYrVKGQMk= github.com/koron/go-ssdp v0.0.0-20191105050749-2e1c40ed0b5d/go.mod h1:5Ky9EC2xfoUKUor0Hjgi2BJhCSXJfMOFlmyYrVKGQMk=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
@ -569,8 +567,8 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg=
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20201202134418-2ba106a5bca3 h1:+45Q/5FybBhHPMr10YdzJNFYO/6RRgkBcZbMzIRq5Ck= github.com/matrix-org/gomatrixserverlib v0.0.0-20201209172200-eb6a8903f9fb h1:UlhiSebJupQ+qAM93cdVGg4nAJ6bnxwAA5/EBygtYoo=
github.com/matrix-org/gomatrixserverlib v0.0.0-20201202134418-2ba106a5bca3/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/gomatrixserverlib v0.0.0-20201209172200-eb6a8903f9fb/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4=
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo=
@ -779,8 +777,8 @@ github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5k
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I= github.com/sirupsen/logrus v1.7.0 h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/smola/gocompat v0.2.0/go.mod h1:1B0MlxbmoZNo3h8guHp8HztB3BSYR5itql9qtVc0ypY= github.com/smola/gocompat v0.2.0/go.mod h1:1B0MlxbmoZNo3h8guHp8HztB3BSYR5itql9qtVc0ypY=
github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE=
github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA=
@ -812,10 +810,13 @@ github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpP
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA=
github.com/tidwall/gjson v1.6.0 h1:9VEQWz6LLMUsUl6PueE49ir4Ka6CzLymOAZDxpFsTDc= github.com/tidwall/gjson v1.6.0 h1:9VEQWz6LLMUsUl6PueE49ir4Ka6CzLymOAZDxpFsTDc=
github.com/tidwall/gjson v1.6.0/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls= github.com/tidwall/gjson v1.6.0/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls=
github.com/tidwall/gjson v1.6.1/go.mod h1:BaHyNc5bjzYkPqgLq7mdVzeiRtULKULXLgZFKsxEHI0=
github.com/tidwall/gjson v1.6.3 h1:aHoiiem0dr7GHkW001T1SMTJ7X5PvyekH5WX0whWGnI= github.com/tidwall/gjson v1.6.3 h1:aHoiiem0dr7GHkW001T1SMTJ7X5PvyekH5WX0whWGnI=
github.com/tidwall/gjson v1.6.3/go.mod h1:BaHyNc5bjzYkPqgLq7mdVzeiRtULKULXLgZFKsxEHI0= github.com/tidwall/gjson v1.6.3/go.mod h1:BaHyNc5bjzYkPqgLq7mdVzeiRtULKULXLgZFKsxEHI0=
github.com/tidwall/match v1.0.1 h1:PnKP62LPNxHKTwvHHZZzdOAOCtsJTjo6dZLCwpKm5xc= github.com/tidwall/match v1.0.1 h1:PnKP62LPNxHKTwvHHZZzdOAOCtsJTjo6dZLCwpKm5xc=
github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E= github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E=
github.com/tidwall/match v1.0.2 h1:uuqvHuBGSedK7awZ2YoAtpnimfwBGFjHuWLuLqQj+bU=
github.com/tidwall/match v1.0.2/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E=
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/tidwall/pretty v1.0.1 h1:WE4RBSZ1x6McVVC8S/Md+Qse8YUv6HRObAx6ke00NY8= github.com/tidwall/pretty v1.0.1 h1:WE4RBSZ1x6McVVC8S/Md+Qse8YUv6HRObAx6ke00NY8=
github.com/tidwall/pretty v1.0.1/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/tidwall/pretty v1.0.1/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
@ -823,8 +824,8 @@ github.com/tidwall/pretty v1.0.2 h1:Z7S3cePv9Jwm1KwS0513MRaoUe3S01WPbLNV40pwWZU=
github.com/tidwall/pretty v1.0.2/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/tidwall/pretty v1.0.2/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/tidwall/sjson v1.0.3 h1:DeF+0LZqvIt4fKYw41aPB29ZGlvwVkHKktoXJ1YW9Y8= github.com/tidwall/sjson v1.0.3 h1:DeF+0LZqvIt4fKYw41aPB29ZGlvwVkHKktoXJ1YW9Y8=
github.com/tidwall/sjson v1.0.3/go.mod h1:bURseu1nuBkFpIES5cz6zBtjmYeOQmEESshn7VpF15Y= github.com/tidwall/sjson v1.0.3/go.mod h1:bURseu1nuBkFpIES5cz6zBtjmYeOQmEESshn7VpF15Y=
github.com/tidwall/sjson v1.1.1 h1:7h1vk049Jnd5EH9NyzNiEuwYW4b5qgreBbqRC19AS3U= github.com/tidwall/sjson v1.1.2 h1:NC5okI+tQ8OG/oyzchvwXXxRxCV/FVdhODbPKkQ25jQ=
github.com/tidwall/sjson v1.1.1/go.mod h1:yvVuSnpEQv5cYIrO+AT6kw4QVfd5SDZoGIS7/5+fZFs= github.com/tidwall/sjson v1.1.2/go.mod h1:SEzaDwxiPzKzNfUEO4HbYF/m4UCSJDsGgNqsS1LvdoY=
github.com/uber/jaeger-client-go v2.25.0+incompatible h1:IxcNZ7WRY1Y3G4poYlx24szfsn/3LvK9QHCq9oQw8+U= github.com/uber/jaeger-client-go v2.25.0+incompatible h1:IxcNZ7WRY1Y3G4poYlx24szfsn/3LvK9QHCq9oQw8+U=
github.com/uber/jaeger-client-go v2.25.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk= github.com/uber/jaeger-client-go v2.25.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk=
github.com/uber/jaeger-lib v2.2.0+incompatible h1:MxZXOiR2JuoANZ3J6DE/U0kSFv/eJ/GfSYVCjK7dyaw= github.com/uber/jaeger-lib v2.2.0+incompatible h1:MxZXOiR2JuoANZ3J6DE/U0kSFv/eJ/GfSYVCjK7dyaw=
@ -905,8 +906,8 @@ golang.org/x/crypto v0.0.0-20200423211502-4bdfaf469ed5 h1:Q7tZBpemrlsc2I7IyODzht
golang.org/x/crypto v0.0.0-20200423211502-4bdfaf469ed5/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200423211502-4bdfaf469ed5/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37 h1:cg5LA/zNPRzIXIWSCxQW10Rvpy94aQh3LT/ShoCpkHw= golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37 h1:cg5LA/zNPRzIXIWSCxQW10Rvpy94aQh3LT/ShoCpkHw=
golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a h1:vclmkQCjlDX5OydZ9wv8rBCcS0QyQY66Mpf/7BZbInM= golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9 h1:phUcVbl53swtrUN8kQEXFhUxPlIlWyBfKmidCu7P95o=
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
@ -982,6 +983,7 @@ golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191003212358-c178f38b412c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191003212358-c178f38b412c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -994,6 +996,9 @@ golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1 h1:ogLJMz+qpzav7lGMh10LMvAkM/fAoGlaiiHYiFYdm80= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1 h1:ogLJMz+qpzav7lGMh10LMvAkM/fAoGlaiiHYiFYdm80=
golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View file

@ -0,0 +1,67 @@
package caching
import (
"fmt"
"github.com/matrix-org/gomatrixserverlib"
)
const (
FederationEventCacheName = "federation_event"
FederationEventCacheMaxEntries = 256
FederationEventCacheMutable = true // to allow use of Unset only
)
// FederationSenderCache contains the subset of functions needed for
// a federation event cache.
type FederationSenderCache interface {
GetFederationSenderQueuedPDU(eventNID int64) (event *gomatrixserverlib.HeaderedEvent, ok bool)
StoreFederationSenderQueuedPDU(eventNID int64, event *gomatrixserverlib.HeaderedEvent)
EvictFederationSenderQueuedPDU(eventNID int64)
GetFederationSenderQueuedEDU(eventNID int64) (event *gomatrixserverlib.EDU, ok bool)
StoreFederationSenderQueuedEDU(eventNID int64, event *gomatrixserverlib.EDU)
EvictFederationSenderQueuedEDU(eventNID int64)
}
func (c Caches) GetFederationSenderQueuedPDU(eventNID int64) (*gomatrixserverlib.HeaderedEvent, bool) {
key := fmt.Sprintf("%d", eventNID)
val, found := c.FederationEvents.Get(key)
if found && val != nil {
if event, ok := val.(*gomatrixserverlib.HeaderedEvent); ok {
return event, true
}
}
return nil, false
}
func (c Caches) StoreFederationSenderQueuedPDU(eventNID int64, event *gomatrixserverlib.HeaderedEvent) {
key := fmt.Sprintf("%d", eventNID)
c.FederationEvents.Set(key, event)
}
func (c Caches) EvictFederationSenderQueuedPDU(eventNID int64) {
key := fmt.Sprintf("%d", eventNID)
c.FederationEvents.Unset(key)
}
func (c Caches) GetFederationSenderQueuedEDU(eventNID int64) (*gomatrixserverlib.EDU, bool) {
key := fmt.Sprintf("%d", eventNID)
val, found := c.FederationEvents.Get(key)
if found && val != nil {
if event, ok := val.(*gomatrixserverlib.EDU); ok {
return event, true
}
}
return nil, false
}
func (c Caches) StoreFederationSenderQueuedEDU(eventNID int64, event *gomatrixserverlib.EDU) {
key := fmt.Sprintf("%d", eventNID)
c.FederationEvents.Set(key, event)
}
func (c Caches) EvictFederationSenderQueuedEDU(eventNID int64) {
key := fmt.Sprintf("%d", eventNID)
c.FederationEvents.Unset(key)
}

View file

@ -10,6 +10,7 @@ type Caches struct {
RoomServerEventTypeNIDs Cache // RoomServerNIDsCache RoomServerEventTypeNIDs Cache // RoomServerNIDsCache
RoomServerRoomNIDs Cache // RoomServerNIDsCache RoomServerRoomNIDs Cache // RoomServerNIDsCache
RoomServerRoomIDs Cache // RoomServerNIDsCache RoomServerRoomIDs Cache // RoomServerNIDsCache
FederationEvents Cache // FederationEventsCache
} }
// Cache is the interface that an implementation must satisfy. // Cache is the interface that an implementation must satisfy.

View file

@ -63,6 +63,15 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
federationEvents, err := NewInMemoryLRUCachePartition(
FederationEventCacheName,
FederationEventCacheMutable,
FederationEventCacheMaxEntries,
enablePrometheus,
)
if err != nil {
return nil, err
}
return &Caches{ return &Caches{
RoomVersions: roomVersions, RoomVersions: roomVersions,
ServerKeys: serverKeys, ServerKeys: serverKeys,
@ -70,6 +79,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
RoomServerEventTypeNIDs: roomServerEventTypeNIDs, RoomServerEventTypeNIDs: roomServerEventTypeNIDs,
RoomServerRoomNIDs: roomServerRoomNIDs, RoomServerRoomNIDs: roomServerRoomNIDs,
RoomServerRoomIDs: roomServerRoomIDs, RoomServerRoomIDs: roomServerRoomIDs,
FederationEvents: federationEvents,
}, nil }, nil
} }

View file

@ -17,7 +17,7 @@ var build string
const ( const (
VersionMajor = 0 VersionMajor = 0
VersionMinor = 3 VersionMinor = 3
VersionPatch = 2 VersionPatch = 3
VersionTag = "" // example: "rc1" VersionTag = "" // example: "rc1"
) )

View file

@ -138,6 +138,15 @@ type RoomserverInternalAPI interface {
response *QueryStateAndAuthChainResponse, response *QueryStateAndAuthChainResponse,
) error ) error
// QueryAuthChain returns the entire auth chain for the event IDs given.
// The response includes the events in the request.
// Omits without error for any missing auth events. There will be no duplicates.
QueryAuthChain(
ctx context.Context,
request *QueryAuthChainRequest,
response *QueryAuthChainResponse,
) error
// QueryCurrentState retrieves the requested state events. If state events are not found, they will be missing from // QueryCurrentState retrieves the requested state events. If state events are not found, they will be missing from
// the response. // the response.
QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error

View file

@ -334,6 +334,16 @@ func (t *RoomserverInternalAPITrace) QueryServerBannedFromRoom(ctx context.Conte
return err return err
} }
func (t *RoomserverInternalAPITrace) QueryAuthChain(
ctx context.Context,
request *QueryAuthChainRequest,
response *QueryAuthChainResponse,
) error {
err := t.Impl.QueryAuthChain(ctx, request, response)
util.GetLogger(ctx).WithError(err).Infof("QueryAuthChain req=%+v res=%+v", js(request), js(response))
return err
}
func js(thing interface{}) string { func js(thing interface{}) string {
b, err := json.Marshal(thing) b, err := json.Marshal(thing)
if err != nil { if err != nil {

View file

@ -275,6 +275,14 @@ type QueryPublishedRoomsResponse struct {
RoomIDs []string RoomIDs []string
} }
type QueryAuthChainRequest struct {
EventIDs []string
}
type QueryAuthChainResponse struct {
AuthChain []*gomatrixserverlib.HeaderedEvent
}
type QuerySharedUsersRequest struct { type QuerySharedUsersRequest struct {
UserID string UserID string
ExcludeRoomIDs []string ExcludeRoomIDs []string

View file

@ -62,10 +62,10 @@ func (w *inputWorker) start() {
for { for {
select { select {
case task := <-w.input: case task := <-w.input:
hooks.Run(hooks.KindNewEventReceived, &task.event.Event) hooks.Run(hooks.KindNewEventReceived, task.event.Event)
_, task.err = w.r.processRoomEvent(task.ctx, task.event) _, task.err = w.r.processRoomEvent(task.ctx, task.event)
if task.err == nil { if task.err == nil {
hooks.Run(hooks.KindNewEventPersisted, &task.event.Event) hooks.Run(hooks.KindNewEventPersisted, task.event.Event)
} }
task.wg.Done() task.wg.Done()
case <-time.After(time.Second * 5): case <-time.After(time.Second * 5):

View file

@ -285,16 +285,16 @@ func (u *latestEventsUpdater) calculateLatest(
// then do nothing - it's not a candidate to be a new extremity if // then do nothing - it's not a candidate to be a new extremity if
// it has been referenced. // it has been referenced.
if _, ok := existingPrevs[newEvent.EventID()]; ok { if _, ok := existingPrevs[newEvent.EventID()]; ok {
u.latest = oldLatest
return false, nil return false, nil
} }
// If the "new" event is already a forward extremity then stop, as // If the "new" event is already a forward extremity then stop, as
// nothing changes. // nothing changes.
for _, event := range events { if _, ok := existingRefs[newEvent.EventID()]; ok {
if event.EventID() == newEvent.EventID() { u.latest = oldLatest
return false, nil return false, nil
} }
}
// Include our new event in the extremities. // Include our new event in the extremities.
newLatest := []types.StateAtEventAndReference{newStateAndRef} newLatest := []types.StateAtEventAndReference{newStateAndRef}

View file

@ -718,3 +718,16 @@ func (r *Queryer) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryS
res.Banned = r.ServerACLs.IsServerBannedFromRoom(req.ServerName, req.RoomID) res.Banned = r.ServerACLs.IsServerBannedFromRoom(req.ServerName, req.RoomID)
return nil return nil
} }
func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse) error {
chain, err := getAuthChain(ctx, r.DB.EventsFromIDs, req.EventIDs)
if err != nil {
return err
}
hchain := make([]*gomatrixserverlib.HeaderedEvent, len(chain))
for i := range chain {
hchain[i] = chain[i].Headered(chain[i].Version())
}
res.AuthChain = hchain
return nil
}

View file

@ -55,6 +55,7 @@ const (
RoomserverQuerySharedUsersPath = "/roomserver/querySharedUsers" RoomserverQuerySharedUsersPath = "/roomserver/querySharedUsers"
RoomserverQueryKnownUsersPath = "/roomserver/queryKnownUsers" RoomserverQueryKnownUsersPath = "/roomserver/queryKnownUsers"
RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom" RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom"
RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain"
) )
type httpRoomserverInternalAPI struct { type httpRoomserverInternalAPI struct {
@ -515,6 +516,16 @@ func (h *httpRoomserverInternalAPI) QueryKnownUsers(
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }
func (h *httpRoomserverInternalAPI) QueryAuthChain(
ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAuthChain")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryAuthChainPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom( func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom(
ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse, ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse,
) error { ) error {

View file

@ -465,4 +465,17 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(RoomserverQueryAuthChainPath,
httputil.MakeInternalAPI("queryAuthChain", func(req *http.Request) util.JSONResponse {
request := api.QueryAuthChainRequest{}
response := api.QueryAuthChainResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryAuthChain(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
} }

View file

@ -309,6 +309,10 @@ func (d *Database) Events(
if err != nil { if err != nil {
return nil, err return nil, err
} }
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
if err != nil {
eventIDs = map[types.EventNID]string{}
}
results := make([]types.Event, len(eventJSONs)) results := make([]types.Event, len(eventJSONs))
for i, eventJSON := range eventJSONs { for i, eventJSON := range eventJSONs {
var roomNID types.RoomNID var roomNID types.RoomNID
@ -328,8 +332,8 @@ func (d *Database) Events(
return nil, err return nil, err
} }
} }
result.Event, err = gomatrixserverlib.NewEventFromTrustedJSON( result.Event, err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID(
eventJSON.EventJSON, false, roomVersion, eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomVersion,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -779,6 +783,7 @@ func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event {
// GetStateEvent returns the current state event of a given type for a given room with a given state key // GetStateEvent returns the current state event of a given type for a given room with a given state key
// If no event could be found, returns nil // If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error // If there was an issue during the retrieval, returns an error
// nolint:gocyclo
func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) { func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) {
roomInfo, err := d.RoomInfo(ctx, roomID) roomInfo, err := d.RoomInfo(ctx, roomID)
if err != nil { if err != nil {
@ -800,6 +805,16 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
if err != nil { if err != nil {
return nil, err return nil, err
} }
var eventNIDs []types.EventNID
for _, e := range entries {
if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID {
eventNIDs = append(eventNIDs, e.EventNID)
}
}
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
if err != nil {
eventIDs = map[types.EventNID]string{}
}
// return the event requested // return the event requested
for _, e := range entries { for _, e := range entries {
if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID { if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID {
@ -810,7 +825,7 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
if len(data) == 0 { if len(data) == 0 {
return nil, fmt.Errorf("GetStateEvent: no json for event nid %d", e.EventNID) return nil, fmt.Errorf("GetStateEvent: no json for event nid %d", e.EventNID)
} }
ev, err := gomatrixserverlib.NewEventFromTrustedJSON(data[0].EventJSON, false, roomInfo.RoomVersion) ev, err := gomatrixserverlib.NewEventFromTrustedJSONWithEventID(eventIDs[e.EventNID], data[0].EventJSON, false, roomInfo.RoomVersion)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -921,7 +936,10 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
} }
} }
} }
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
if err != nil {
eventIDs = map[types.EventNID]string{}
}
events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs) events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event nids: %w", err) return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event nids: %w", err)
@ -929,7 +947,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
result := make([]tables.StrippedEvent, len(events)) result := make([]tables.StrippedEvent, len(events))
for i := range events { for i := range events {
roomVer := eventNIDToVer[events[i].EventNID] roomVer := eventNIDToVer[events[i].EventNID]
ev, err := gomatrixserverlib.NewEventFromTrustedJSON(events[i].EventJSON, false, roomVer) ev, err := gomatrixserverlib.NewEventFromTrustedJSONWithEventID(eventIDs[events[i].EventNID], events[i].EventJSON, false, roomVer)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event NID %v : %w", events[i].EventNID, err) return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event NID %v : %w", events[i].EventNID, err)
} }

View file

@ -18,10 +18,13 @@ package msc2836
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/sha256"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"sort"
"strings"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
@ -37,12 +40,11 @@ import (
const ( const (
constRelType = "m.reference" constRelType = "m.reference"
constRoomIDKey = "relationship_room_id"
constRoomServers = "relationship_servers"
) )
type EventRelationshipRequest struct { type EventRelationshipRequest struct {
EventID string `json:"event_id"` EventID string `json:"event_id"`
RoomID string `json:"room_id"`
MaxDepth int `json:"max_depth"` MaxDepth int `json:"max_depth"`
MaxBreadth int `json:"max_breadth"` MaxBreadth int `json:"max_breadth"`
Limit int `json:"limit"` Limit int `json:"limit"`
@ -52,7 +54,6 @@ type EventRelationshipRequest struct {
IncludeChildren bool `json:"include_children"` IncludeChildren bool `json:"include_children"`
Direction string `json:"direction"` Direction string `json:"direction"`
Batch string `json:"batch"` Batch string `json:"batch"`
AutoJoin bool `json:"auto_join"`
} }
func NewEventRelationshipRequest(body io.Reader) (*EventRelationshipRequest, error) { func NewEventRelationshipRequest(body io.Reader) (*EventRelationshipRequest, error) {
@ -81,8 +82,16 @@ type EventRelationshipResponse struct {
Limited bool `json:"limited"` Limited bool `json:"limited"`
} }
func toClientResponse(res *gomatrixserverlib.MSC2836EventRelationshipsResponse) *EventRelationshipResponse {
out := &EventRelationshipResponse{
Events: gomatrixserverlib.ToClientEvents(res.Events, gomatrixserverlib.FormatAll),
Limited: res.Limited,
NextBatch: res.NextBatch,
}
return out
}
// Enable this MSC // Enable this MSC
// nolint:gocyclo
func Enable( func Enable(
base *setup.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationSenderInternalAPI, base *setup.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationSenderInternalAPI,
userAPI userapi.UserInternalAPI, keyRing gomatrixserverlib.JSONVerifier, userAPI userapi.UserInternalAPI, keyRing gomatrixserverlib.JSONVerifier,
@ -96,63 +105,22 @@ func Enable(
he := headeredEvent.(*gomatrixserverlib.HeaderedEvent) he := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
hookErr := db.StoreRelation(context.Background(), he) hookErr := db.StoreRelation(context.Background(), he)
if hookErr != nil { if hookErr != nil {
util.GetLogger(context.Background()).WithError(hookErr).Error( util.GetLogger(context.Background()).WithError(hookErr).WithField("event_id", he.EventID()).Error(
"failed to StoreRelation", "failed to StoreRelation",
) )
} }
}) // we need to update child metadata here as well as after doing remote /event_relationships requests
hooks.Attach(hooks.KindNewEventReceived, func(headeredEvent interface{}) { // so we catch child metadata originating from /send transactions
he := headeredEvent.(*gomatrixserverlib.HeaderedEvent) hookErr = db.UpdateChildMetadata(context.Background(), he)
ctx := context.Background() if hookErr != nil {
// we only inject metadata for events our server sends util.GetLogger(context.Background()).WithError(err).WithField("event_id", he.EventID()).Warn(
userID := he.Sender() "failed to update child metadata for event",
_, domain, err := gomatrixserverlib.SplitID('@', userID) )
if err != nil {
return
}
if domain != base.Cfg.Global.ServerName {
return
}
// if this event has an m.relationship, add on the room_id and servers to unsigned
parent, child, relType := parentChildEventIDs(he)
if parent == "" || child == "" || relType == "" {
return
}
event, joinedToRoom := getEventIfVisible(ctx, rsAPI, parent, userID)
if !joinedToRoom {
return
}
err = he.SetUnsignedField(constRoomIDKey, event.RoomID())
if err != nil {
util.GetLogger(context.Background()).WithError(err).Warn("Failed to SetUnsignedField")
return
}
var servers []gomatrixserverlib.ServerName
if fsAPI != nil {
var res fs.QueryJoinedHostServerNamesInRoomResponse
err = fsAPI.QueryJoinedHostServerNamesInRoom(ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{
RoomID: event.RoomID(),
}, &res)
if err != nil {
util.GetLogger(context.Background()).WithError(err).Warn("Failed to QueryJoinedHostServerNamesInRoom")
return
}
servers = res.ServerNames
} else {
servers = []gomatrixserverlib.ServerName{
base.Cfg.Global.ServerName,
}
}
err = he.SetUnsignedField(constRoomServers, servers)
if err != nil {
util.GetLogger(context.Background()).WithError(err).Warn("Failed to SetUnsignedField")
return
} }
}) })
base.PublicClientAPIMux.Handle("/unstable/event_relationships", base.PublicClientAPIMux.Handle("/unstable/event_relationships",
httputil.MakeAuthAPI("eventRelationships", userAPI, eventRelationshipHandler(db, rsAPI)), httputil.MakeAuthAPI("eventRelationships", userAPI, eventRelationshipHandler(db, rsAPI, fsAPI)),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
base.PublicFederationAPIMux.Handle("/unstable/event_relationships", httputil.MakeExternalAPI( base.PublicFederationAPIMux.Handle("/unstable/event_relationships", httputil.MakeExternalAPI(
@ -163,7 +131,7 @@ func Enable(
if fedReq == nil { if fedReq == nil {
return errResp return errResp
} }
return federatedEventRelationship(req.Context(), fedReq, db, rsAPI) return federatedEventRelationship(req.Context(), fedReq, db, rsAPI, fsAPI)
}, },
)).Methods(http.MethodPost, http.MethodOptions) )).Methods(http.MethodPost, http.MethodOptions)
return nil return nil
@ -175,10 +143,15 @@ type reqCtx struct {
db Database db Database
req *EventRelationshipRequest req *EventRelationshipRequest
userID string userID string
roomVersion gomatrixserverlib.RoomVersion
// federated request args
isFederatedRequest bool isFederatedRequest bool
serverName gomatrixserverlib.ServerName
fsAPI fs.FederationSenderInternalAPI
} }
func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAPI) func(*http.Request, *userapi.Device) util.JSONResponse { func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationSenderInternalAPI) func(*http.Request, *userapi.Device) util.JSONResponse {
return func(req *http.Request, device *userapi.Device) util.JSONResponse { return func(req *http.Request, device *userapi.Device) util.JSONResponse {
relation, err := NewEventRelationshipRequest(req.Body) relation, err := NewEventRelationshipRequest(req.Body)
if err != nil { if err != nil {
@ -193,6 +166,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
req: relation, req: relation,
userID: device.UserID, userID: device.UserID,
rsAPI: rsAPI, rsAPI: rsAPI,
fsAPI: fsAPI,
isFederatedRequest: false, isFederatedRequest: false,
db: db, db: db,
} }
@ -203,12 +177,14 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
return util.JSONResponse{ return util.JSONResponse{
Code: 200, Code: 200,
JSON: res, JSON: toClientResponse(res),
} }
} }
} }
func federatedEventRelationship(ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, db Database, rsAPI roomserver.RoomserverInternalAPI) util.JSONResponse { func federatedEventRelationship(
ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, db Database, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationSenderInternalAPI,
) util.JSONResponse {
relation, err := NewEventRelationshipRequest(bytes.NewBuffer(fedReq.Content())) relation, err := NewEventRelationshipRequest(bytes.NewBuffer(fedReq.Content()))
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to decode HTTP request as JSON") util.GetLogger(ctx).WithError(err).Error("failed to decode HTTP request as JSON")
@ -220,15 +196,41 @@ func federatedEventRelationship(ctx context.Context, fedReq *gomatrixserverlib.F
rc := reqCtx{ rc := reqCtx{
ctx: ctx, ctx: ctx,
req: relation, req: relation,
userID: "",
rsAPI: rsAPI, rsAPI: rsAPI,
isFederatedRequest: true,
db: db, db: db,
// federation args
isFederatedRequest: true,
fsAPI: fsAPI,
serverName: fedReq.Origin(),
} }
res, resErr := rc.process() res, resErr := rc.process()
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }
// add auth chain information
requiredAuthEventsSet := make(map[string]bool)
var requiredAuthEvents []string
for _, ev := range res.Events {
for _, a := range ev.AuthEventIDs() {
if requiredAuthEventsSet[a] {
continue
}
requiredAuthEvents = append(requiredAuthEvents, a)
requiredAuthEventsSet[a] = true
}
}
var queryRes roomserver.QueryAuthChainResponse
err = rsAPI.QueryAuthChain(ctx, &roomserver.QueryAuthChainRequest{
EventIDs: requiredAuthEvents,
}, &queryRes)
if err != nil {
// they may already have the auth events so don't fail this request
util.GetLogger(ctx).WithError(err).Error("Failed to QueryAuthChain")
}
res.AuthChain = make([]*gomatrixserverlib.Event, len(queryRes.AuthChain))
for i := range queryRes.AuthChain {
res.AuthChain[i] = queryRes.AuthChain[i].Unwrap()
}
return util.JSONResponse{ return util.JSONResponse{
Code: 200, Code: 200,
@ -236,18 +238,25 @@ func federatedEventRelationship(ctx context.Context, fedReq *gomatrixserverlib.F
} }
} }
func (rc *reqCtx) process() (*EventRelationshipResponse, *util.JSONResponse) { // nolint:gocyclo
var res EventRelationshipResponse func (rc *reqCtx) process() (*gomatrixserverlib.MSC2836EventRelationshipsResponse, *util.JSONResponse) {
var res gomatrixserverlib.MSC2836EventRelationshipsResponse
var returnEvents []*gomatrixserverlib.HeaderedEvent var returnEvents []*gomatrixserverlib.HeaderedEvent
// Can the user see (according to history visibility) event_id? If no, reject the request, else continue. // Can the user see (according to history visibility) event_id? If no, reject the request, else continue.
// We should have the event being referenced so don't give any claimed room ID / servers event := rc.getLocalEvent(rc.req.EventID)
event := rc.getEventIfVisible(rc.req.EventID, "", nil)
if event == nil { if event == nil {
event = rc.fetchUnknownEvent(rc.req.EventID, rc.req.RoomID)
}
if rc.req.RoomID == "" && event != nil {
rc.req.RoomID = event.RoomID()
}
if event == nil || !rc.authorisedToSeeEvent(event) {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: 403, Code: 403,
JSON: jsonerror.Forbidden("Event does not exist or you are not authorised to see it"), JSON: jsonerror.Forbidden("Event does not exist or you are not authorised to see it"),
} }
} }
rc.roomVersion = event.Version()
// Retrieve the event. Add it to response array. // Retrieve the event. Add it to response array.
returnEvents = append(returnEvents, event) returnEvents = append(returnEvents, event)
@ -282,29 +291,122 @@ func (rc *reqCtx) process() (*EventRelationshipResponse, *util.JSONResponse) {
) )
returnEvents = append(returnEvents, events...) returnEvents = append(returnEvents, events...)
} }
res.Events = make([]gomatrixserverlib.ClientEvent, len(returnEvents)) res.Events = make([]*gomatrixserverlib.Event, len(returnEvents))
for i, ev := range returnEvents { for i, ev := range returnEvents {
res.Events[i] = gomatrixserverlib.HeaderedToClientEvent(ev, gomatrixserverlib.FormatAll) // for each event, extract the children_count | hash and add it as unsigned data.
rc.addChildMetadata(ev)
res.Events[i] = ev.Unwrap()
} }
res.Limited = remaining == 0 || walkLimited res.Limited = remaining == 0 || walkLimited
return &res, nil return &res, nil
} }
// fetchUnknownEvent retrieves an unknown event from the room specified. This server must
// be joined to the room in question. This has the side effect of injecting surround threaded
// events into the roomserver.
func (rc *reqCtx) fetchUnknownEvent(eventID, roomID string) *gomatrixserverlib.HeaderedEvent {
if rc.isFederatedRequest || roomID == "" {
// we don't do fed hits for fed requests, and we can't ask servers without a room ID!
return nil
}
logger := util.GetLogger(rc.ctx).WithField("room_id", roomID)
// if they supplied a room_id, check the room exists.
var queryVerRes roomserver.QueryRoomVersionForRoomResponse
err := rc.rsAPI.QueryRoomVersionForRoom(rc.ctx, &roomserver.QueryRoomVersionForRoomRequest{
RoomID: roomID,
}, &queryVerRes)
if err != nil {
logger.WithError(err).Warn("failed to query room version for room, does this room exist?")
return nil
}
// check the user is joined to that room
var queryMemRes roomserver.QueryMembershipForUserResponse
err = rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{
RoomID: roomID,
UserID: rc.userID,
}, &queryMemRes)
if err != nil {
logger.WithError(err).Warn("failed to query membership for user in room")
return nil
}
if !queryMemRes.IsInRoom {
return nil
}
// ask one of the servers in the room for the event
var queryRes fs.QueryJoinedHostServerNamesInRoomResponse
err = rc.fsAPI.QueryJoinedHostServerNamesInRoom(rc.ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{
RoomID: roomID,
}, &queryRes)
if err != nil {
logger.WithError(err).Error("failed to QueryJoinedHostServerNamesInRoom")
return nil
}
// query up to 5 servers
serversToQuery := queryRes.ServerNames
if len(serversToQuery) > 5 {
serversToQuery = serversToQuery[:5]
}
// fetch the event, along with some of the surrounding thread (if it's threaded) and the auth chain.
// Inject the response into the roomserver to remember the event across multiple calls and to set
// unexplored flags correctly.
for _, srv := range serversToQuery {
res, err := rc.MSC2836EventRelationships(eventID, srv, queryVerRes.RoomVersion)
if err != nil {
continue
}
rc.injectResponseToRoomserver(res)
for _, ev := range res.Events {
if ev.EventID() == eventID {
return ev.Headered(ev.Version())
}
}
}
logger.WithField("servers", serversToQuery).Warn("failed to query event relationships")
return nil
}
// If include_parent: true and there is a valid m.relationship field in the event, // If include_parent: true and there is a valid m.relationship field in the event,
// retrieve the referenced event. Apply history visibility check to that event and if it passes, add it to the response array. // retrieve the referenced event. Apply history visibility check to that event and if it passes, add it to the response array.
func (rc *reqCtx) includeParent(event *gomatrixserverlib.HeaderedEvent) (parent *gomatrixserverlib.HeaderedEvent) { func (rc *reqCtx) includeParent(childEvent *gomatrixserverlib.HeaderedEvent) (parent *gomatrixserverlib.HeaderedEvent) {
parentID, _, _ := parentChildEventIDs(event) parentID, _, _ := parentChildEventIDs(childEvent)
if parentID == "" { if parentID == "" {
return nil return nil
} }
claimedRoomID, claimedServers := roomIDAndServers(event) return rc.lookForEvent(parentID)
return rc.getEventIfVisible(parentID, claimedRoomID, claimedServers)
} }
// If include_children: true, lookup all events which have event_id as an m.relationship // If include_children: true, lookup all events which have event_id as an m.relationship
// Apply history visibility checks to all these events and add the ones which pass into the response array, // Apply history visibility checks to all these events and add the ones which pass into the response array,
// honouring the recent_first flag and the limit. // honouring the recent_first flag and the limit.
func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recentFirst bool) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) { func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recentFirst bool) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) {
if rc.hasUnexploredChildren(parentID) {
// we need to do a remote request to pull in the children as we are missing them locally.
serversToQuery := rc.getServersForEventID(parentID)
var result *gomatrixserverlib.MSC2836EventRelationshipsResponse
for _, srv := range serversToQuery {
res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{
EventID: parentID,
Direction: "down",
Limit: 100,
MaxBreadth: -1,
MaxDepth: 1, // we just want the children from this parent
RecentFirst: true,
}, rc.roomVersion)
if err != nil {
util.GetLogger(rc.ctx).WithError(err).WithField("server", srv).Error("includeChildren: failed to call MSC2836EventRelationships")
} else {
result = &res
break
}
}
if result != nil {
rc.injectResponseToRoomserver(result)
}
// fallthrough to pull these new events from the DB
}
children, err := db.ChildrenForParent(rc.ctx, parentID, constRelType, recentFirst) children, err := db.ChildrenForParent(rc.ctx, parentID, constRelType, recentFirst)
if err != nil { if err != nil {
util.GetLogger(rc.ctx).WithError(err).Error("failed to get ChildrenForParent") util.GetLogger(rc.ctx).WithError(err).Error("failed to get ChildrenForParent")
@ -313,8 +415,7 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen
} }
var childEvents []*gomatrixserverlib.HeaderedEvent var childEvents []*gomatrixserverlib.HeaderedEvent
for _, child := range children { for _, child := range children {
// in order for us to even know about the children the server must be joined to those rooms, hence pass no claimed room ID or servers. childEvent := rc.lookForEvent(child.EventID)
childEvent := rc.getEventIfVisible(child.EventID, "", nil)
if childEvent != nil { if childEvent != nil {
childEvents = append(childEvents, childEvent) childEvents = append(childEvents, childEvent)
} }
@ -327,14 +428,9 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen
// Begin to walk the thread DAG in the direction specified, either depth or breadth first according to the depth_first flag, // Begin to walk the thread DAG in the direction specified, either depth or breadth first according to the depth_first flag,
// honouring the limit, max_depth and max_breadth values according to the following rules // honouring the limit, max_depth and max_breadth values according to the following rules
// nolint: unparam
func walkThread( func walkThread(
ctx context.Context, db Database, rc *reqCtx, included map[string]bool, limit int, ctx context.Context, db Database, rc *reqCtx, included map[string]bool, limit int,
) ([]*gomatrixserverlib.HeaderedEvent, bool) { ) ([]*gomatrixserverlib.HeaderedEvent, bool) {
if rc.req.Direction != "down" {
util.GetLogger(ctx).Error("not implemented: direction=up")
return nil, false
}
var result []*gomatrixserverlib.HeaderedEvent var result []*gomatrixserverlib.HeaderedEvent
eventWalker := walker{ eventWalker := walker{
ctx: ctx, ctx: ctx,
@ -352,8 +448,11 @@ func walkThread(
} }
// Process the event. // Process the event.
// TODO: Include edge information: room ID and servers // if event is not found, use remoteEventRelationships to explore that part of the thread remotely.
event := rc.getEventIfVisible(wi.EventID, "", nil) // This will probably be easiest if the event relationships response is directly pumped into the database
// so the next walk will do the right thing. This requires those events to be authed and likely injected as
// outliers into the roomserver DB, which will de-dupe appropriately.
event := rc.lookForEvent(wi.EventID)
if event != nil { if event != nil {
result = append(result, event) result = append(result, event)
} }
@ -368,74 +467,280 @@ func walkThread(
return result, limited return result, limited
} }
func (rc *reqCtx) getEventIfVisible(eventID string, claimedRoomID string, claimedServers []string) *gomatrixserverlib.HeaderedEvent { // MSC2836EventRelationships performs an /event_relationships request to a remote server
event, joinedToRoom := getEventIfVisible(rc.ctx, rc.rsAPI, eventID, rc.userID) func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverlib.ServerName, ver gomatrixserverlib.RoomVersion) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) {
if event != nil && joinedToRoom { res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{
return event EventID: eventID,
DepthFirst: rc.req.DepthFirst,
Direction: rc.req.Direction,
Limit: rc.req.Limit,
MaxBreadth: rc.req.MaxBreadth,
MaxDepth: rc.req.MaxDepth,
RecentFirst: rc.req.RecentFirst,
}, ver)
if err != nil {
util.GetLogger(rc.ctx).WithError(err).Error("Failed to call MSC2836EventRelationships")
return nil, err
} }
// either we don't have the event or we aren't joined to the room, regardless we should try joining if auto join is enabled return &res, nil
if !rc.req.AutoJoin {
return nil
}
// if we're doing this on behalf of a random server don't auto-join rooms regardless of what the request says
if rc.isFederatedRequest {
return nil
}
roomID := claimedRoomID
var servers []gomatrixserverlib.ServerName
if event != nil {
roomID = event.RoomID()
}
for _, s := range claimedServers {
servers = append(servers, gomatrixserverlib.ServerName(s))
}
var joinRes roomserver.PerformJoinResponse
rc.rsAPI.PerformJoin(rc.ctx, &roomserver.PerformJoinRequest{
UserID: rc.userID,
Content: map[string]interface{}{},
RoomIDOrAlias: roomID,
ServerNames: servers,
}, &joinRes)
if joinRes.Error != nil {
util.GetLogger(rc.ctx).WithError(joinRes.Error).WithField("room_id", roomID).Error("Failed to auto-join room")
return nil
}
if event != nil {
return event
}
// TODO: hit /event_relationships on the server we joined via
util.GetLogger(rc.ctx).Infof("joined room but need to fetch event TODO")
return nil
} }
func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, eventID, userID string) (*gomatrixserverlib.HeaderedEvent, bool) { // authorisedToSeeEvent checks that the user or server is allowed to see this event. Returns true if allowed to
var queryEventsRes roomserver.QueryEventsByIDResponse // see this request. This only needs to be done once per room at present as we just check for joined status.
err := rsAPI.QueryEventsByID(ctx, &roomserver.QueryEventsByIDRequest{ func (rc *reqCtx) authorisedToSeeEvent(event *gomatrixserverlib.HeaderedEvent) bool {
EventIDs: []string{eventID}, if rc.isFederatedRequest {
}, &queryEventsRes) // make sure the server is in this room
var res fs.QueryJoinedHostServerNamesInRoomResponse
err := rc.fsAPI.QueryJoinedHostServerNamesInRoom(rc.ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{
RoomID: event.RoomID(),
}, &res)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryEventsByID") util.GetLogger(rc.ctx).WithError(err).Error("authorisedToSeeEvent: failed to QueryJoinedHostServerNamesInRoom")
return nil, false return false
} }
if len(queryEventsRes.Events) == 0 { for _, srv := range res.ServerNames {
util.GetLogger(ctx).Infof("event does not exist") if srv == rc.serverName {
return nil, false // event does not exist return true
} }
event := queryEventsRes.Events[0] }
return false
}
// make sure the user is in this room
// Allow events if the member is in the room // Allow events if the member is in the room
// TODO: This does not honour history_visibility // TODO: This does not honour history_visibility
// TODO: This does not honour m.room.create content // TODO: This does not honour m.room.create content
var queryMembershipRes roomserver.QueryMembershipForUserResponse var queryMembershipRes roomserver.QueryMembershipForUserResponse
err = rsAPI.QueryMembershipForUser(ctx, &roomserver.QueryMembershipForUserRequest{ err := rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{
RoomID: event.RoomID(), RoomID: event.RoomID(),
UserID: userID, UserID: rc.userID,
}, &queryMembershipRes) }, &queryMembershipRes)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryMembershipForUser") util.GetLogger(rc.ctx).WithError(err).Error("authorisedToSeeEvent: failed to QueryMembershipForUser")
return nil, false return false
} }
return event, queryMembershipRes.IsInRoom return queryMembershipRes.IsInRoom
}
func (rc *reqCtx) getServersForEventID(eventID string) []gomatrixserverlib.ServerName {
if rc.req.RoomID == "" {
util.GetLogger(rc.ctx).WithField("event_id", eventID).Error(
"getServersForEventID: event exists in unknown room",
)
return nil
}
if rc.roomVersion == "" {
util.GetLogger(rc.ctx).WithField("event_id", eventID).Errorf(
"getServersForEventID: event exists in %s with unknown room version", rc.req.RoomID,
)
return nil
}
var queryRes fs.QueryJoinedHostServerNamesInRoomResponse
err := rc.fsAPI.QueryJoinedHostServerNamesInRoom(rc.ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{
RoomID: rc.req.RoomID,
}, &queryRes)
if err != nil {
util.GetLogger(rc.ctx).WithError(err).Error("getServersForEventID: failed to QueryJoinedHostServerNamesInRoom")
return nil
}
// query up to 5 servers
serversToQuery := queryRes.ServerNames
if len(serversToQuery) > 5 {
serversToQuery = serversToQuery[:5]
}
return serversToQuery
}
func (rc *reqCtx) remoteEventRelationships(eventID string) *gomatrixserverlib.MSC2836EventRelationshipsResponse {
if rc.isFederatedRequest {
return nil // we don't query remote servers for remote requests
}
serversToQuery := rc.getServersForEventID(eventID)
var res *gomatrixserverlib.MSC2836EventRelationshipsResponse
var err error
for _, srv := range serversToQuery {
res, err = rc.MSC2836EventRelationships(eventID, srv, rc.roomVersion)
if err != nil {
util.GetLogger(rc.ctx).WithError(err).WithField("server", srv).Error("remoteEventRelationships: failed to call MSC2836EventRelationships")
} else {
break
}
}
return res
}
// lookForEvent returns the event for the event ID given, by trying to query remote servers
// if the event ID is unknown via /event_relationships.
func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent {
event := rc.getLocalEvent(eventID)
if event == nil {
queryRes := rc.remoteEventRelationships(eventID)
if queryRes != nil {
// inject all the events into the roomserver then return the event in question
rc.injectResponseToRoomserver(queryRes)
for _, ev := range queryRes.Events {
if ev.EventID() == eventID && rc.req.RoomID == ev.RoomID() {
return ev.Headered(ev.Version())
}
}
}
} else if rc.hasUnexploredChildren(eventID) {
// we have the local event but we may need to do a remote hit anyway if we are exploring the thread and have unknown children.
// If we don't do this then we risk never fetching the children.
queryRes := rc.remoteEventRelationships(eventID)
if queryRes != nil {
rc.injectResponseToRoomserver(queryRes)
err := rc.db.MarkChildrenExplored(context.Background(), eventID)
if err != nil {
util.GetLogger(rc.ctx).WithError(err).Warnf("failed to mark children of %s as explored", eventID)
}
}
}
if rc.req.RoomID == event.RoomID() {
return event
}
return nil
}
func (rc *reqCtx) getLocalEvent(eventID string) *gomatrixserverlib.HeaderedEvent {
var queryEventsRes roomserver.QueryEventsByIDResponse
err := rc.rsAPI.QueryEventsByID(rc.ctx, &roomserver.QueryEventsByIDRequest{
EventIDs: []string{eventID},
}, &queryEventsRes)
if err != nil {
util.GetLogger(rc.ctx).WithError(err).Error("getLocalEvent: failed to QueryEventsByID")
return nil
}
if len(queryEventsRes.Events) == 0 {
util.GetLogger(rc.ctx).WithField("event_id", eventID).Infof("getLocalEvent: event does not exist")
return nil // event does not exist
}
return queryEventsRes.Events[0]
}
// injectResponseToRoomserver injects the events
// into the roomserver as KindOutlier, with auth chains.
func (rc *reqCtx) injectResponseToRoomserver(res *gomatrixserverlib.MSC2836EventRelationshipsResponse) {
var stateEvents []*gomatrixserverlib.Event
var messageEvents []*gomatrixserverlib.Event
for _, ev := range res.Events {
if ev.StateKey() != nil {
stateEvents = append(stateEvents, ev)
} else {
messageEvents = append(messageEvents, ev)
}
}
respState := gomatrixserverlib.RespState{
AuthEvents: res.AuthChain,
StateEvents: stateEvents,
}
eventsInOrder, err := respState.Events()
if err != nil {
util.GetLogger(rc.ctx).WithError(err).Error("failed to calculate order to send events in MSC2836EventRelationshipsResponse")
return
}
// everything gets sent as an outlier because auth chain events may be disjoint from the DAG
// as may the threaded events.
var ires []roomserver.InputRoomEvent
for _, outlier := range append(eventsInOrder, messageEvents...) {
ires = append(ires, roomserver.InputRoomEvent{
Kind: roomserver.KindOutlier,
Event: outlier.Headered(outlier.Version()),
AuthEventIDs: outlier.AuthEventIDs(),
})
}
// we've got the data by this point so use a background context
err = roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires)
if err != nil {
util.GetLogger(rc.ctx).WithError(err).Error("failed to inject MSC2836EventRelationshipsResponse into the roomserver")
}
// update the child count / hash columns for these nodes. We need to do this here because not all events will make it
// through to the KindNewEventPersisted hook because the roomserver will ignore duplicates. Duplicates have meaning though
// as the `unsigned` field may differ (if the number of children changes).
for _, ev := range ires {
err = rc.db.UpdateChildMetadata(context.Background(), ev.Event)
if err != nil {
util.GetLogger(rc.ctx).WithError(err).WithField("event_id", ev.Event.EventID()).Warn("failed to update child metadata for event")
}
}
}
func (rc *reqCtx) addChildMetadata(ev *gomatrixserverlib.HeaderedEvent) {
count, hash := rc.getChildMetadata(ev.EventID())
if count == 0 {
return
}
err := ev.SetUnsignedField("children_hash", gomatrixserverlib.Base64Bytes(hash))
if err != nil {
util.GetLogger(rc.ctx).WithError(err).Warn("Failed to set children_hash")
}
err = ev.SetUnsignedField("children", map[string]int{
constRelType: count,
})
if err != nil {
util.GetLogger(rc.ctx).WithError(err).Warn("Failed to set children count")
}
}
func (rc *reqCtx) getChildMetadata(eventID string) (count int, hash []byte) {
children, err := rc.db.ChildrenForParent(rc.ctx, eventID, constRelType, false)
if err != nil {
util.GetLogger(rc.ctx).WithError(err).Warn("Failed to get ChildrenForParent for getting child metadata")
return
}
if len(children) == 0 {
return
}
// sort it lexiographically
sort.Slice(children, func(i, j int) bool {
return children[i].EventID < children[j].EventID
})
// hash it
var eventIDs strings.Builder
for _, c := range children {
_, _ = eventIDs.WriteString(c.EventID)
}
hashValBytes := sha256.Sum256([]byte(eventIDs.String()))
count = len(children)
hash = hashValBytes[:]
return
}
// hasUnexploredChildren returns true if this event has unexplored children.
// "An event has unexplored children if the `unsigned` child count on the parent does not match
// how many children the server believes the parent to have. In addition, if the counts match but
// the hashes do not match, then the event is unexplored."
func (rc *reqCtx) hasUnexploredChildren(eventID string) bool {
if rc.isFederatedRequest {
return false // we only explore children for clients, not servers.
}
// extract largest child count from event
eventCount, eventHash, explored, err := rc.db.ChildMetadata(rc.ctx, eventID)
if err != nil {
util.GetLogger(rc.ctx).WithError(err).WithField("event_id", eventID).Warn(
"failed to get ChildMetadata from db",
)
return false
}
// if there are no recorded children then we know we have >= children.
// if the event has already been explored (read: we hit /event_relationships successfully)
// then don't do it again. We'll only re-do this if we get an even bigger children count,
// see Database.UpdateChildMetadata
if eventCount == 0 || explored {
return false // short-circuit
}
// calculate child count for event
calcCount, calcHash := rc.getChildMetadata(eventID)
if eventCount < calcCount {
return false // we have more children
} else if eventCount > calcCount {
return true // the event has more children than we know about
}
// we have the same count, so a mismatched hash means some children are different
return !bytes.Equal(eventHash, calcHash)
} }
type walkInfo struct { type walkInfo struct {
@ -453,9 +758,9 @@ type walker struct {
// WalkFrom the event ID given // WalkFrom the event ID given
func (w *walker) WalkFrom(eventID string) (limited bool, err error) { func (w *walker) WalkFrom(eventID string) (limited bool, err error) {
children, err := w.db.ChildrenForParent(w.ctx, eventID, constRelType, w.req.RecentFirst) children, err := w.childrenForParent(eventID)
if err != nil { if err != nil {
util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk") util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() childrenForParent failed, cannot walk")
return false, err return false, err
} }
var next *walkInfo var next *walkInfo
@ -467,9 +772,9 @@ func (w *walker) WalkFrom(eventID string) (limited bool, err error) {
return true, nil return true, nil
} }
// find the children's children // find the children's children
children, err = w.db.ChildrenForParent(w.ctx, next.EventID, constRelType, w.req.RecentFirst) children, err = w.childrenForParent(next.EventID)
if err != nil { if err != nil {
util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk") util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() childrenForParent failed, cannot walk")
return false, err return false, err
} }
toWalk = w.addChildren(toWalk, children, next.Depth+1) toWalk = w.addChildren(toWalk, children, next.Depth+1)
@ -528,3 +833,20 @@ func (w *walker) nextChild(toWalk []walkInfo) (*walkInfo, []walkInfo) {
child, toWalk = toWalk[0], toWalk[1:] child, toWalk = toWalk[0], toWalk[1:]
return &child, toWalk return &child, toWalk
} }
// childrenForParent returns the children events for this event ID, honouring the direction: up|down flags
// meaning this can actually be returning the parent for the event instead of the children.
func (w *walker) childrenForParent(eventID string) ([]eventInfo, error) {
if w.req.Direction == "down" {
return w.db.ChildrenForParent(w.ctx, eventID, constRelType, w.req.RecentFirst)
}
// find the event to pull out the parent
ei, err := w.db.ParentForChild(w.ctx, eventID, constRelType)
if err != nil {
return nil, err
}
if ei != nil {
return []eventInfo{*ei}, nil
}
return nil, nil
}

View file

@ -4,10 +4,14 @@ import (
"bytes" "bytes"
"context" "context"
"crypto/ed25519" "crypto/ed25519"
"crypto/sha256"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"sort"
"strings"
"testing" "testing"
"time" "time"
@ -43,9 +47,7 @@ func TestMSC2836(t *testing.T) {
alice := "@alice:localhost" alice := "@alice:localhost"
bob := "@bob:localhost" bob := "@bob:localhost"
charlie := "@charlie:localhost" charlie := "@charlie:localhost"
roomIDA := "!alice:localhost" roomID := "!alice:localhost"
roomIDB := "!bob:localhost"
roomIDC := "!charlie:localhost"
// give access tokens to all three users // give access tokens to all three users
nopUserAPI := &testUserAPI{ nopUserAPI := &testUserAPI{
accessTokens: make(map[string]userapi.Device), accessTokens: make(map[string]userapi.Device),
@ -66,7 +68,7 @@ func TestMSC2836(t *testing.T) {
UserID: charlie, UserID: charlie,
} }
eventA := mustCreateEvent(t, fledglingEvent{ eventA := mustCreateEvent(t, fledglingEvent{
RoomID: roomIDA, RoomID: roomID,
Sender: alice, Sender: alice,
Type: "m.room.message", Type: "m.room.message",
Content: map[string]interface{}{ Content: map[string]interface{}{
@ -74,7 +76,7 @@ func TestMSC2836(t *testing.T) {
}, },
}) })
eventB := mustCreateEvent(t, fledglingEvent{ eventB := mustCreateEvent(t, fledglingEvent{
RoomID: roomIDB, RoomID: roomID,
Sender: bob, Sender: bob,
Type: "m.room.message", Type: "m.room.message",
Content: map[string]interface{}{ Content: map[string]interface{}{
@ -86,7 +88,7 @@ func TestMSC2836(t *testing.T) {
}, },
}) })
eventC := mustCreateEvent(t, fledglingEvent{ eventC := mustCreateEvent(t, fledglingEvent{
RoomID: roomIDB, RoomID: roomID,
Sender: bob, Sender: bob,
Type: "m.room.message", Type: "m.room.message",
Content: map[string]interface{}{ Content: map[string]interface{}{
@ -98,7 +100,7 @@ func TestMSC2836(t *testing.T) {
}, },
}) })
eventD := mustCreateEvent(t, fledglingEvent{ eventD := mustCreateEvent(t, fledglingEvent{
RoomID: roomIDA, RoomID: roomID,
Sender: alice, Sender: alice,
Type: "m.room.message", Type: "m.room.message",
Content: map[string]interface{}{ Content: map[string]interface{}{
@ -110,7 +112,7 @@ func TestMSC2836(t *testing.T) {
}, },
}) })
eventE := mustCreateEvent(t, fledglingEvent{ eventE := mustCreateEvent(t, fledglingEvent{
RoomID: roomIDB, RoomID: roomID,
Sender: bob, Sender: bob,
Type: "m.room.message", Type: "m.room.message",
Content: map[string]interface{}{ Content: map[string]interface{}{
@ -122,7 +124,7 @@ func TestMSC2836(t *testing.T) {
}, },
}) })
eventF := mustCreateEvent(t, fledglingEvent{ eventF := mustCreateEvent(t, fledglingEvent{
RoomID: roomIDC, RoomID: roomID,
Sender: charlie, Sender: charlie,
Type: "m.room.message", Type: "m.room.message",
Content: map[string]interface{}{ Content: map[string]interface{}{
@ -134,7 +136,7 @@ func TestMSC2836(t *testing.T) {
}, },
}) })
eventG := mustCreateEvent(t, fledglingEvent{ eventG := mustCreateEvent(t, fledglingEvent{
RoomID: roomIDA, RoomID: roomID,
Sender: alice, Sender: alice,
Type: "m.room.message", Type: "m.room.message",
Content: map[string]interface{}{ Content: map[string]interface{}{
@ -146,7 +148,7 @@ func TestMSC2836(t *testing.T) {
}, },
}) })
eventH := mustCreateEvent(t, fledglingEvent{ eventH := mustCreateEvent(t, fledglingEvent{
RoomID: roomIDB, RoomID: roomID,
Sender: bob, Sender: bob,
Type: "m.room.message", Type: "m.room.message",
Content: map[string]interface{}{ Content: map[string]interface{}{
@ -160,9 +162,9 @@ func TestMSC2836(t *testing.T) {
// make everyone joined to each other's rooms // make everyone joined to each other's rooms
nopRsAPI := &testRoomserverAPI{ nopRsAPI := &testRoomserverAPI{
userToJoinedRooms: map[string][]string{ userToJoinedRooms: map[string][]string{
alice: []string{roomIDA, roomIDB, roomIDC}, alice: []string{roomID},
bob: []string{roomIDA, roomIDB, roomIDC}, bob: []string{roomID},
charlie: []string{roomIDA, roomIDB, roomIDC}, charlie: []string{roomID},
}, },
events: map[string]*gomatrixserverlib.HeaderedEvent{ events: map[string]*gomatrixserverlib.HeaderedEvent{
eventA.EventID(): eventA, eventA.EventID(): eventA,
@ -198,21 +200,6 @@ func TestMSC2836(t *testing.T) {
"include_parent": true, "include_parent": true,
})) }))
}) })
t.Run("omits parent if not joined to the room of parent of event", func(t *testing.T) {
nopUserAPI.accessTokens["frank2"] = userapi.Device{
AccessToken: "frank2",
DisplayName: "Frank2 Not In Room",
UserID: "@frank2:localhost",
}
// Event B is in roomB, Event A is in roomA, so make frank2 joined to roomB
nopRsAPI.userToJoinedRooms["@frank2:localhost"] = []string{roomIDB}
body := postRelationships(t, 200, "frank2", newReq(t, map[string]interface{}{
"event_id": eventB.EventID(),
"limit": 1,
"include_parent": true,
}))
assertContains(t, body, []string{eventB.EventID()})
})
t.Run("returns the parent if include_parent is true", func(t *testing.T) { t.Run("returns the parent if include_parent is true", func(t *testing.T) {
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
"event_id": eventB.EventID(), "event_id": eventB.EventID(),
@ -349,6 +336,39 @@ func TestMSC2836(t *testing.T) {
})) }))
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID(), eventH.EventID()}) assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID(), eventH.EventID()})
}) })
t.Run("can navigate up the graph with direction: up", func(t *testing.T) {
// A4
// |
// B3
// / \
// C D2
// /| \
// E F1 G
// |
// H
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
"event_id": eventF.EventID(),
"recent_first": false,
"depth_first": true,
"direction": "up",
}))
assertContains(t, body, []string{eventF.EventID(), eventD.EventID(), eventB.EventID(), eventA.EventID()})
})
t.Run("includes children and children_hash in unsigned", func(t *testing.T) {
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
"event_id": eventB.EventID(),
"recent_first": false,
"depth_first": false,
"limit": 3,
}))
// event B has C,D as children
// event C has no children
// event D has 3 children (not included in response)
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID()})
assertUnsignedChildren(t, body.Events[0], "m.reference", 2, []string{eventC.EventID(), eventD.EventID()})
assertUnsignedChildren(t, body.Events[1], "", 0, nil)
assertUnsignedChildren(t, body.Events[2], "m.reference", 3, []string{eventE.EventID(), eventF.EventID(), eventG.EventID()})
})
} }
// TODO: TestMSC2836TerminatesLoops (short and long) // TODO: TestMSC2836TerminatesLoops (short and long)
@ -411,8 +431,12 @@ func postRelationships(t *testing.T, expectCode int, accessToken string, req *ms
} }
if res.StatusCode == 200 { if res.StatusCode == 200 {
var result msc2836.EventRelationshipResponse var result msc2836.EventRelationshipResponse
if err := json.NewDecoder(res.Body).Decode(&result); err != nil { body, err := ioutil.ReadAll(res.Body)
t.Fatalf("response 200 OK but failed to deserialise JSON : %s", err) if err != nil {
t.Fatalf("response 200 OK but failed to read response body: %s", err)
}
if err := json.Unmarshal(body, &result); err != nil {
t.Fatalf("response 200 OK but failed to deserialise JSON : %s\nbody: %s", err, string(body))
} }
return &result return &result
} }
@ -435,6 +459,43 @@ func assertContains(t *testing.T, result *msc2836.EventRelationshipResponse, wan
} }
} }
func assertUnsignedChildren(t *testing.T, ev gomatrixserverlib.ClientEvent, relType string, wantCount int, childrenEventIDs []string) {
t.Helper()
unsigned := struct {
Children map[string]int `json:"children"`
Hash string `json:"children_hash"`
}{}
if err := json.Unmarshal(ev.Unsigned, &unsigned); err != nil {
if wantCount == 0 {
return // no children so possible there is no unsigned field at all
}
t.Fatalf("Failed to unmarshal unsigned field: %s", err)
}
// zero checks
if wantCount == 0 {
if len(unsigned.Children) != 0 || unsigned.Hash != "" {
t.Fatalf("want 0 children but got unsigned fields %+v", unsigned)
}
return
}
gotCount := unsigned.Children[relType]
if gotCount != wantCount {
t.Errorf("Got %d count, want %d count for rel_type %s", gotCount, wantCount, relType)
}
// work out the hash
sort.Strings(childrenEventIDs)
var b strings.Builder
for _, s := range childrenEventIDs {
b.WriteString(s)
}
t.Logf("hashing %s", b.String())
hashValBytes := sha256.Sum256([]byte(b.String()))
wantHash := base64.RawStdEncoding.EncodeToString(hashValBytes[:])
if wantHash != unsigned.Hash {
t.Errorf("Got unsigned hash %s want hash %s", unsigned.Hash, wantHash)
}
}
type testUserAPI struct { type testUserAPI struct {
accessTokens map[string]userapi.Device accessTokens map[string]userapi.Device
} }

View file

@ -1,20 +1,22 @@
package msc2836 package msc2836
import ( import (
"bytes"
"context" "context"
"database/sql" "database/sql"
"encoding/base64"
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
) )
type eventInfo struct { type eventInfo struct {
EventID string EventID string
OriginServerTS gomatrixserverlib.Timestamp OriginServerTS gomatrixserverlib.Timestamp
RoomID string RoomID string
Servers []string
} }
type Database interface { type Database interface {
@ -25,6 +27,21 @@ type Database interface {
// provided `relType`. The returned slice is sorted by origin_server_ts according to whether // provided `relType`. The returned slice is sorted by origin_server_ts according to whether
// `recentFirst` is true or false. // `recentFirst` is true or false.
ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error) ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error)
// ParentForChild returns the parent event for the given child `eventID`. The eventInfo should be nil if
// there is no parent for this child event, with no error. The parent eventInfo can be missing the
// timestamp if the event is not known to the server.
ParentForChild(ctx context.Context, eventID, relType string) (*eventInfo, error)
// UpdateChildMetadata persists the children_count and children_hash from this event if and only if
// the count is greater than what was previously there. If the count is updated, the event will be
// updated to be unexplored.
UpdateChildMetadata(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error
// ChildMetadata returns the children_count and children_hash for the event ID in question.
// Also returns the `explored` flag, which is set to true when MarkChildrenExplored is called and is set
// back to `false` when a larger count is inserted via UpdateChildMetadata.
// Returns nil error if the event ID does not exist.
ChildMetadata(ctx context.Context, eventID string) (count int, hash []byte, explored bool, err error)
// MarkChildrenExplored sets the 'explored' flag on this event to `true`.
MarkChildrenExplored(ctx context.Context, eventID string) error
} }
type DB struct { type DB struct {
@ -34,6 +51,10 @@ type DB struct {
insertNodeStmt *sql.Stmt insertNodeStmt *sql.Stmt
selectChildrenForParentOldestFirstStmt *sql.Stmt selectChildrenForParentOldestFirstStmt *sql.Stmt
selectChildrenForParentRecentFirstStmt *sql.Stmt selectChildrenForParentRecentFirstStmt *sql.Stmt
selectParentForChildStmt *sql.Stmt
updateChildMetadataStmt *sql.Stmt
selectChildMetadataStmt *sql.Stmt
updateChildMetadataExploredStmt *sql.Stmt
} }
// NewDatabase loads the database for msc2836 // NewDatabase loads the database for msc2836
@ -65,19 +86,26 @@ func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
CREATE TABLE IF NOT EXISTS msc2836_nodes ( CREATE TABLE IF NOT EXISTS msc2836_nodes (
event_id TEXT PRIMARY KEY NOT NULL, event_id TEXT PRIMARY KEY NOT NULL,
origin_server_ts BIGINT NOT NULL, origin_server_ts BIGINT NOT NULL,
room_id TEXT NOT NULL room_id TEXT NOT NULL,
unsigned_children_count BIGINT NOT NULL,
unsigned_children_hash TEXT NOT NULL,
explored SMALLINT NOT NULL
); );
`) `)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if d.insertEdgeStmt, err = d.db.Prepare(` if d.insertEdgeStmt, err = d.db.Prepare(`
INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers) VALUES($1, $2, $3, $4, $5) ON CONFLICT DO NOTHING INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers)
VALUES($1, $2, $3, $4, $5)
ON CONFLICT DO NOTHING
`); err != nil { `); err != nil {
return nil, err return nil, err
} }
if d.insertNodeStmt, err = d.db.Prepare(` if d.insertNodeStmt, err = d.db.Prepare(`
INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id) VALUES($1, $2, $3) ON CONFLICT DO NOTHING INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id, unsigned_children_count, unsigned_children_hash, explored)
VALUES($1, $2, $3, $4, $5, $6)
ON CONFLICT DO NOTHING
`); err != nil { `); err != nil {
return nil, err return nil, err
} }
@ -93,6 +121,27 @@ func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil { if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil {
return nil, err return nil, err
} }
if d.selectParentForChildStmt, err = d.db.Prepare(`
SELECT parent_event_id, parent_room_id FROM msc2836_edges
WHERE child_event_id = $1 AND rel_type = $2
`); err != nil {
return nil, err
}
if d.updateChildMetadataStmt, err = d.db.Prepare(`
UPDATE msc2836_nodes SET unsigned_children_count=$1, unsigned_children_hash=$2, explored=$3 WHERE event_id=$4
`); err != nil {
return nil, err
}
if d.selectChildMetadataStmt, err = d.db.Prepare(`
SELECT unsigned_children_count, unsigned_children_hash, explored FROM msc2836_nodes WHERE event_id=$1
`); err != nil {
return nil, err
}
if d.updateChildMetadataExploredStmt, err = d.db.Prepare(`
UPDATE msc2836_nodes SET explored=$1 WHERE event_id=$2
`); err != nil {
return nil, err
}
return &d, err return &d, err
} }
@ -117,19 +166,26 @@ func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
CREATE TABLE IF NOT EXISTS msc2836_nodes ( CREATE TABLE IF NOT EXISTS msc2836_nodes (
event_id TEXT PRIMARY KEY NOT NULL, event_id TEXT PRIMARY KEY NOT NULL,
origin_server_ts BIGINT NOT NULL, origin_server_ts BIGINT NOT NULL,
room_id TEXT NOT NULL room_id TEXT NOT NULL,
unsigned_children_count BIGINT NOT NULL,
unsigned_children_hash TEXT NOT NULL,
explored SMALLINT NOT NULL
); );
`) `)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if d.insertEdgeStmt, err = d.db.Prepare(` if d.insertEdgeStmt, err = d.db.Prepare(`
INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers) VALUES($1, $2, $3, $4, $5) ON CONFLICT (parent_event_id, child_event_id, rel_type) DO NOTHING INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers)
VALUES($1, $2, $3, $4, $5)
ON CONFLICT (parent_event_id, child_event_id, rel_type) DO NOTHING
`); err != nil { `); err != nil {
return nil, err return nil, err
} }
if d.insertNodeStmt, err = d.db.Prepare(` if d.insertNodeStmt, err = d.db.Prepare(`
INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id) VALUES($1, $2, $3) ON CONFLICT DO NOTHING INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id, unsigned_children_count, unsigned_children_hash, explored)
VALUES($1, $2, $3, $4, $5, $6)
ON CONFLICT DO NOTHING
`); err != nil { `); err != nil {
return nil, err return nil, err
} }
@ -145,6 +201,27 @@ func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil { if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil {
return nil, err return nil, err
} }
if d.selectParentForChildStmt, err = d.db.Prepare(`
SELECT parent_event_id, parent_room_id FROM msc2836_edges
WHERE child_event_id = $1 AND rel_type = $2
`); err != nil {
return nil, err
}
if d.updateChildMetadataStmt, err = d.db.Prepare(`
UPDATE msc2836_nodes SET unsigned_children_count=$1, unsigned_children_hash=$2, explored=$3 WHERE event_id=$4
`); err != nil {
return nil, err
}
if d.selectChildMetadataStmt, err = d.db.Prepare(`
SELECT unsigned_children_count, unsigned_children_hash, explored FROM msc2836_nodes WHERE event_id=$1
`); err != nil {
return nil, err
}
if d.updateChildMetadataExploredStmt, err = d.db.Prepare(`
UPDATE msc2836_nodes SET explored=$1 WHERE event_id=$2
`); err != nil {
return nil, err
}
return &d, nil return &d, nil
} }
@ -158,16 +235,55 @@ func (p *DB) StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEv
if err != nil { if err != nil {
return err return err
} }
count, hash := extractChildMetadata(ev)
return p.writer.Do(p.db, nil, func(txn *sql.Tx) error { return p.writer.Do(p.db, nil, func(txn *sql.Tx) error {
_, err := txn.Stmt(p.insertEdgeStmt).ExecContext(ctx, parent, child, relType, relationRoomID, string(relationServersJSON)) _, err := txn.Stmt(p.insertEdgeStmt).ExecContext(ctx, parent, child, relType, relationRoomID, string(relationServersJSON))
if err != nil { if err != nil {
return err return err
} }
_, err = txn.Stmt(p.insertNodeStmt).ExecContext(ctx, ev.EventID(), ev.OriginServerTS(), ev.RoomID()) util.GetLogger(ctx).Infof("StoreRelation child=%s parent=%s rel_type=%s", child, parent, relType)
_, err = txn.Stmt(p.insertNodeStmt).ExecContext(ctx, ev.EventID(), ev.OriginServerTS(), ev.RoomID(), count, base64.RawStdEncoding.EncodeToString(hash), 0)
return err return err
}) })
} }
func (p *DB) UpdateChildMetadata(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error {
eventCount, eventHash := extractChildMetadata(ev)
if eventCount == 0 {
return nil // nothing to update with
}
// extract current children count/hash, if they are less than the current event then update the columns and set to unexplored
count, hash, _, err := p.ChildMetadata(ctx, ev.EventID())
if err != nil {
return err
}
if eventCount > count || (eventCount == count && !bytes.Equal(hash, eventHash)) {
_, err = p.updateChildMetadataStmt.ExecContext(ctx, eventCount, base64.RawStdEncoding.EncodeToString(eventHash), 0, ev.EventID())
return err
}
return nil
}
func (p *DB) ChildMetadata(ctx context.Context, eventID string) (count int, hash []byte, explored bool, err error) {
var b64hash string
var exploredInt int
if err = p.selectChildMetadataStmt.QueryRowContext(ctx, eventID).Scan(&count, &b64hash, &exploredInt); err != nil {
if err == sql.ErrNoRows {
err = nil
}
return
}
hash, err = base64.RawStdEncoding.DecodeString(b64hash)
explored = exploredInt > 0
return
}
func (p *DB) MarkChildrenExplored(ctx context.Context, eventID string) error {
_, err := p.updateChildMetadataExploredStmt.ExecContext(ctx, 1, eventID)
return err
}
func (p *DB) ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error) { func (p *DB) ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error) {
var rows *sql.Rows var rows *sql.Rows
var err error var err error
@ -191,6 +307,17 @@ func (p *DB) ChildrenForParent(ctx context.Context, eventID, relType string, rec
return children, nil return children, nil
} }
func (p *DB) ParentForChild(ctx context.Context, eventID, relType string) (*eventInfo, error) {
var ei eventInfo
err := p.selectParentForChildStmt.QueryRowContext(ctx, eventID, relType).Scan(&ei.EventID, &ei.RoomID)
if err == sql.ErrNoRows {
return nil, nil
} else if err != nil {
return nil, err
}
return &ei, nil
}
func parentChildEventIDs(ev *gomatrixserverlib.HeaderedEvent) (parent, child, relType string) { func parentChildEventIDs(ev *gomatrixserverlib.HeaderedEvent) (parent, child, relType string) {
if ev == nil { if ev == nil {
return return
@ -224,3 +351,19 @@ func roomIDAndServers(ev *gomatrixserverlib.HeaderedEvent) (roomID string, serve
} }
return body.RoomID, body.Servers return body.RoomID, body.Servers
} }
func extractChildMetadata(ev *gomatrixserverlib.HeaderedEvent) (count int, hash []byte) {
unsigned := struct {
Counts map[string]int `json:"children"`
Hash gomatrixserverlib.Base64Bytes `json:"children_hash"`
}{}
if err := json.Unmarshal(ev.Unsigned(), &unsigned); err != nil {
// expected if there is no unsigned field at all
return
}
for _, c := range unsigned.Counts {
count += c
}
hash = unsigned.Hash
return
}

View file

@ -16,15 +16,18 @@
package mscs package mscs
import ( import (
"context"
"fmt" "fmt"
"github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/mscs/msc2836" "github.com/matrix-org/dendrite/setup/mscs/msc2836"
"github.com/matrix-org/util"
) )
// Enable MSCs - returns an error on unknown MSCs // Enable MSCs - returns an error on unknown MSCs
func Enable(base *setup.BaseDendrite, monolith *setup.Monolith) error { func Enable(base *setup.BaseDendrite, monolith *setup.Monolith) error {
for _, msc := range base.Cfg.MSCs.MSCs { for _, msc := range base.Cfg.MSCs.MSCs {
util.GetLogger(context.Background()).WithField("msc", msc).Info("Enabling MSC")
if err := EnableMSC(base, monolith, msc); err != nil { if err := EnableMSC(base, monolith, msc); err != nil {
return err return err
} }

View file

@ -76,7 +76,7 @@ const selectRoomIDsWithMembershipSQL = "" +
"SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
const selectCurrentStateSQL = "" + const selectCurrentStateSQL = "" +
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" + "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" +
" AND ( $2::text[] IS NULL OR sender = ANY($2) )" + " AND ( $2::text[] IS NULL OR sender = ANY($2) )" +
" AND ( $3::text[] IS NULL OR NOT(sender = ANY($3)) )" + " AND ( $3::text[] IS NULL OR NOT(sender = ANY($3)) )" +
" AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" + " AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" +
@ -92,10 +92,10 @@ const selectStateEventSQL = "" +
const selectEventsWithEventIDsSQL = "" + const selectEventsWithEventIDsSQL = "" +
// TODO: The session_id and transaction_id blanks are here because otherwise // TODO: The session_id and transaction_id blanks are here because otherwise
// the rowsToStreamEvents expects there to be exactly five columns. We need to // the rowsToStreamEvents expects there to be exactly six columns. We need to
// figure out if these really need to be in the DB, and if so, we need a // figure out if these really need to be in the DB, and if so, we need a
// better permanent fix for this. - neilalexander, 2 Jan 2020 // better permanent fix for this. - neilalexander, 2 Jan 2020
"SELECT added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" + "SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" +
" FROM syncapi_current_room_state WHERE event_id = ANY($1)" " FROM syncapi_current_room_state WHERE event_id = ANY($1)"
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
@ -278,13 +278,14 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) { func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) {
result := []*gomatrixserverlib.HeaderedEvent{} result := []*gomatrixserverlib.HeaderedEvent{}
for rows.Next() { for rows.Next() {
var eventID string
var eventBytes []byte var eventBytes []byte
if err := rows.Scan(&eventBytes); err != nil { if err := rows.Scan(&eventID, &eventBytes); err != nil {
return nil, err return nil, err
} }
// TODO: Handle redacted events // TODO: Handle redacted events
var ev gomatrixserverlib.HeaderedEvent var ev gomatrixserverlib.HeaderedEvent
if err := json.Unmarshal(eventBytes, &ev); err != nil { if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil {
return nil, err return nil, err
} }
result = append(result, &ev) result = append(result, &ev)

View file

@ -79,20 +79,20 @@ const insertEventSQL = "" +
"RETURNING id" "RETURNING id"
const selectEventsSQL = "" + const selectEventsSQL = "" +
"SELECT id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)"
const selectRecentEventsSQL = "" + const selectRecentEventsSQL = "" +
"SELECT id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" + " WHERE room_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id DESC LIMIT $4" " ORDER BY id DESC LIMIT $4"
const selectRecentEventsForSyncSQL = "" + const selectRecentEventsForSyncSQL = "" +
"SELECT id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" + " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" +
" ORDER BY id DESC LIMIT $4" " ORDER BY id DESC LIMIT $4"
const selectEarlyEventsSQL = "" + const selectEarlyEventsSQL = "" +
"SELECT id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" + " WHERE room_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id ASC LIMIT $4" " ORDER BY id ASC LIMIT $4"
@ -413,6 +413,7 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
var result []types.StreamEvent var result []types.StreamEvent
for rows.Next() { for rows.Next() {
var ( var (
eventID string
streamPos types.StreamPosition streamPos types.StreamPosition
eventBytes []byte eventBytes []byte
excludeFromSync bool excludeFromSync bool
@ -420,12 +421,12 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
txnID *string txnID *string
transactionID *api.TransactionID transactionID *api.TransactionID
) )
if err := rows.Scan(&streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil { if err := rows.Scan(&eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil {
return nil, err return nil, err
} }
// TODO: Handle redacted events // TODO: Handle redacted events
var ev gomatrixserverlib.HeaderedEvent var ev gomatrixserverlib.HeaderedEvent
if err := json.Unmarshal(eventBytes, &ev); err != nil { if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil {
return nil, err return nil, err
} }

View file

@ -31,6 +31,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -514,25 +515,28 @@ func (d *Database) addPDUDeltaToResponse(
deltas, joinedRoomIDs, err = d.getStateDeltas( deltas, joinedRoomIDs, err = d.getStateDeltas(
ctx, &device, txn, r, device.UserID, &stateFilter, ctx, &device, txn, r, device.UserID, &stateFilter,
) )
if err != nil {
return nil, fmt.Errorf("d.getStateDeltas: %w", err)
}
} else { } else {
deltas, joinedRoomIDs, err = d.getStateDeltasForFullStateSync( deltas, joinedRoomIDs, err = d.getStateDeltasForFullStateSync(
ctx, &device, txn, r, device.UserID, &stateFilter, ctx, &device, txn, r, device.UserID, &stateFilter,
) )
}
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("d.getStateDeltasForFullStateSync: %w", err)
}
} }
for _, delta := range deltas { for _, delta := range deltas {
err = d.addRoomDeltaToResponse(ctx, &device, txn, r, delta, numRecentEventsPerRoom, res) err = d.addRoomDeltaToResponse(ctx, &device, txn, r, delta, numRecentEventsPerRoom, res)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("d.addRoomDeltaToResponse: %w", err)
} }
} }
// TODO: This should be done in getStateDeltas // TODO: This should be done in getStateDeltas
if err = d.addInvitesToResponse(ctx, txn, device.UserID, r, res); err != nil { if err = d.addInvitesToResponse(ctx, txn, device.UserID, r, res); err != nil {
return nil, err return nil, fmt.Errorf("d.addInvitesToResponse: %w", err)
} }
succeeded = true succeeded = true
@ -1126,7 +1130,13 @@ func (d *Database) fetchMissingStateEvents(
return nil, err return nil, err
} }
if len(stateEvents) != len(missing) { if len(stateEvents) != len(missing) {
return nil, fmt.Errorf("failed to map all event IDs to events: (got %d, wanted %d)", len(stateEvents), len(missing)) logrus.WithContext(ctx).Warnf("Failed to map all event IDs to events (got %d, wanted %d)", len(stateEvents), len(missing))
// TODO: Why is this happening? It's probably the roomserver. Uncomment
// this error again when we work out what it is and fix it, otherwise we
// just end up returning lots of 500s to the client and that breaks
// pretty much everything, rather than just sending what we have.
//return nil, fmt.Errorf("failed to map all event IDs to events: (got %d, wanted %d)", len(stateEvents), len(missing))
} }
events = append(events, stateEvents...) events = append(events, stateEvents...)
return events, nil return events, nil

View file

@ -64,7 +64,7 @@ const selectRoomIDsWithMembershipSQL = "" +
"SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
const selectCurrentStateSQL = "" + const selectCurrentStateSQL = "" +
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" + "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" +
" AND ( $2 IS NULL OR sender IN ($2) )" + " AND ( $2 IS NULL OR sender IN ($2) )" +
" AND ( $3 IS NULL OR NOT(sender IN ($3)) )" + " AND ( $3 IS NULL OR NOT(sender IN ($3)) )" +
" AND ( $4 IS NULL OR type IN ($4) )" + " AND ( $4 IS NULL OR type IN ($4) )" +
@ -80,10 +80,10 @@ const selectStateEventSQL = "" +
const selectEventsWithEventIDsSQL = "" + const selectEventsWithEventIDsSQL = "" +
// TODO: The session_id and transaction_id blanks are here because otherwise // TODO: The session_id and transaction_id blanks are here because otherwise
// the rowsToStreamEvents expects there to be exactly five columns. We need to // the rowsToStreamEvents expects there to be exactly six columns. We need to
// figure out if these really need to be in the DB, and if so, we need a // figure out if these really need to be in the DB, and if so, we need a
// better permanent fix for this. - neilalexander, 2 Jan 2020 // better permanent fix for this. - neilalexander, 2 Jan 2020
"SELECT added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" + "SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" +
" FROM syncapi_current_room_state WHERE event_id IN ($1)" " FROM syncapi_current_room_state WHERE event_id IN ($1)"
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
@ -289,13 +289,14 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) { func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) {
result := []*gomatrixserverlib.HeaderedEvent{} result := []*gomatrixserverlib.HeaderedEvent{}
for rows.Next() { for rows.Next() {
var eventID string
var eventBytes []byte var eventBytes []byte
if err := rows.Scan(&eventBytes); err != nil { if err := rows.Scan(&eventID, &eventBytes); err != nil {
return nil, err return nil, err
} }
// TODO: Handle redacted events // TODO: Handle redacted events
var ev gomatrixserverlib.HeaderedEvent var ev gomatrixserverlib.HeaderedEvent
if err := json.Unmarshal(eventBytes, &ev); err != nil { if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil {
return nil, err return nil, err
} }
result = append(result, &ev) result = append(result, &ev)

View file

@ -56,20 +56,20 @@ const insertEventSQL = "" +
"ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = $13" "ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = $13"
const selectEventsSQL = "" + const selectEventsSQL = "" +
"SELECT id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1" "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1"
const selectRecentEventsSQL = "" + const selectRecentEventsSQL = "" +
"SELECT id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" + " WHERE room_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id DESC LIMIT $4" " ORDER BY id DESC LIMIT $4"
const selectRecentEventsForSyncSQL = "" + const selectRecentEventsForSyncSQL = "" +
"SELECT id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" + " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" +
" ORDER BY id DESC LIMIT $4" " ORDER BY id DESC LIMIT $4"
const selectEarlyEventsSQL = "" + const selectEarlyEventsSQL = "" +
"SELECT id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" + " WHERE room_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id ASC LIMIT $4" " ORDER BY id ASC LIMIT $4"
@ -428,6 +428,7 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
var result []types.StreamEvent var result []types.StreamEvent
for rows.Next() { for rows.Next() {
var ( var (
eventID string
streamPos types.StreamPosition streamPos types.StreamPosition
eventBytes []byte eventBytes []byte
excludeFromSync bool excludeFromSync bool
@ -435,12 +436,12 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
txnID *string txnID *string
transactionID *api.TransactionID transactionID *api.TransactionID
) )
if err := rows.Scan(&streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil { if err := rows.Scan(&eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil {
return nil, err return nil, err
} }
// TODO: Handle redacted events // TODO: Handle redacted events
var ev gomatrixserverlib.HeaderedEvent var ev gomatrixserverlib.HeaderedEvent
if err := json.Unmarshal(eventBytes, &ev); err != nil { if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil {
return nil, err return nil, err
} }

View file

@ -64,3 +64,6 @@ A prev_batch token from incremental sync can be used in the v1 messages API
# Blacklisted due to flakiness # Blacklisted due to flakiness
Forgotten room messages cannot be paginated Forgotten room messages cannot be paginated
# Blacklisted due to flakiness
Can re-join room if re-invited

View file

@ -501,6 +501,5 @@ Can get rooms/{roomId}/state for a departed room (SPEC-216)
Users cannot set notifications powerlevel higher than their own Users cannot set notifications powerlevel higher than their own
Forgetting room does not show up in v2 /sync Forgetting room does not show up in v2 /sync
Can forget room you've been kicked from Can forget room you've been kicked from
Can re-join room if re-invited
/whois /whois
/joined_members return joined members /joined_members return joined members