mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-03-03 17:03:10 -06:00
Merge branch 'main' into patch-1
This commit is contained in:
commit
0624d4a643
3
.github/workflows/dendrite.yml
vendored
3
.github/workflows/dendrite.yml
vendored
|
|
@ -7,6 +7,7 @@ on:
|
||||||
pull_request:
|
pull_request:
|
||||||
release:
|
release:
|
||||||
types: [published]
|
types: [published]
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}
|
group: ${{ github.workflow }}-${{ github.ref }}
|
||||||
|
|
@ -375,6 +376,8 @@ jobs:
|
||||||
# Build initial Dendrite image
|
# Build initial Dendrite image
|
||||||
- run: docker build -t complement-dendrite -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile .
|
- run: docker build -t complement-dendrite -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile .
|
||||||
working-directory: dendrite
|
working-directory: dendrite
|
||||||
|
env:
|
||||||
|
DOCKER_BUILDKIT: 1
|
||||||
|
|
||||||
# Run Complement
|
# Run Complement
|
||||||
- run: |
|
- run: |
|
||||||
|
|
|
||||||
25
CHANGES.md
25
CHANGES.md
|
|
@ -1,5 +1,30 @@
|
||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
|
## Dendrite 0.9.5 (2022-08-25)
|
||||||
|
|
||||||
|
### Fixes
|
||||||
|
|
||||||
|
* The roomserver will now correctly unreject previously rejected events if necessary when reprocessing
|
||||||
|
* The handling of event soft-failure has been improved on the roomserver input by no longer applying rejection rules and still calculating state before the event if possible
|
||||||
|
* The federation `/state` and `/state_ids` endpoints should now return the correct error code when the state isn't known instead of returning a HTTP 500
|
||||||
|
* The federation `/event` should now return outlier events correctly instead of returning a HTTP 500
|
||||||
|
* A bug in the federation backoff allowing zero intervals has been corrected
|
||||||
|
* The `create-account` utility will no longer error if the homeserver URL ends in a trailing slash
|
||||||
|
* A regression in `/sync` introduced in 0.9.4 should be fixed
|
||||||
|
|
||||||
|
## Dendrite 0.9.4 (2022-08-19)
|
||||||
|
|
||||||
|
### Fixes
|
||||||
|
|
||||||
|
* A bug in the roomserver around handling rejected outliers has been fixed
|
||||||
|
* Backfilled events will now use the correct history visibility where possible
|
||||||
|
* The device list updater backoff has been fixed, which should reduce the number of outbound HTTP requests and `Failed to query device keys for some users` log entries for dead servers
|
||||||
|
* The `/sync` endpoint will no longer incorrectly return room entries for retired invites which could cause some rooms to show up in the client "Historical" section
|
||||||
|
* The `/createRoom` endpoint will now correctly populate `is_direct` in invite membership events, which may help clients to classify direct messages correctly
|
||||||
|
* The `create-account` tool will now log an error if the shared secret is not set in the Dendrite config
|
||||||
|
* A couple of minor bugs have been fixed in the membership lazy-loading
|
||||||
|
* Queued EDUs in the federation API are now cached properly
|
||||||
|
|
||||||
## Dendrite 0.9.3 (2022-08-15)
|
## Dendrite 0.9.3 (2022-08-15)
|
||||||
|
|
||||||
### Important
|
### Important
|
||||||
|
|
|
||||||
|
|
@ -1,10 +0,0 @@
|
||||||
# Application Service
|
|
||||||
|
|
||||||
This component interfaces with external [Application
|
|
||||||
Services](https://matrix.org/docs/spec/application_service/unstable.html).
|
|
||||||
This includes any HTTP endpoints that application services call, as well as talking
|
|
||||||
to any HTTP endpoints that application services provide themselves.
|
|
||||||
|
|
||||||
## Consumers
|
|
||||||
|
|
||||||
This component consumes and filters events from the Roomserver Kafka stream, passing on any necessary events to subscribing application services.
|
|
||||||
|
|
@ -18,7 +18,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
@ -28,9 +27,6 @@ import (
|
||||||
"github.com/matrix-org/dendrite/appservice/consumers"
|
"github.com/matrix-org/dendrite/appservice/consumers"
|
||||||
"github.com/matrix-org/dendrite/appservice/inthttp"
|
"github.com/matrix-org/dendrite/appservice/inthttp"
|
||||||
"github.com/matrix-org/dendrite/appservice/query"
|
"github.com/matrix-org/dendrite/appservice/query"
|
||||||
"github.com/matrix-org/dendrite/appservice/storage"
|
|
||||||
"github.com/matrix-org/dendrite/appservice/types"
|
|
||||||
"github.com/matrix-org/dendrite/appservice/workers"
|
|
||||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/setup/base"
|
"github.com/matrix-org/dendrite/setup/base"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
|
@ -59,57 +55,40 @@ func NewInternalAPI(
|
||||||
Proxy: http.ProxyFromEnvironment,
|
Proxy: http.ProxyFromEnvironment,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
|
// Create appserivce query API with an HTTP client that will be used for all
|
||||||
|
// outbound and inbound requests (inbound only for the internal API)
|
||||||
|
appserviceQueryAPI := &query.AppServiceQueryAPI{
|
||||||
|
HTTPClient: client,
|
||||||
|
Cfg: &base.Cfg.AppServiceAPI,
|
||||||
|
}
|
||||||
|
|
||||||
// Create a connection to the appservice postgres DB
|
if len(base.Cfg.Derived.ApplicationServices) == 0 {
|
||||||
appserviceDB, err := storage.NewDatabase(base, &base.Cfg.AppServiceAPI.Database)
|
return appserviceQueryAPI
|
||||||
if err != nil {
|
|
||||||
logrus.WithError(err).Panicf("failed to connect to appservice db")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wrap application services in a type that relates the application service and
|
// Wrap application services in a type that relates the application service and
|
||||||
// a sync.Cond object that can be used to notify workers when there are new
|
// a sync.Cond object that can be used to notify workers when there are new
|
||||||
// events to be sent out.
|
// events to be sent out.
|
||||||
workerStates := make([]types.ApplicationServiceWorkerState, len(base.Cfg.Derived.ApplicationServices))
|
for _, appservice := range base.Cfg.Derived.ApplicationServices {
|
||||||
for i, appservice := range base.Cfg.Derived.ApplicationServices {
|
|
||||||
m := sync.Mutex{}
|
|
||||||
ws := types.ApplicationServiceWorkerState{
|
|
||||||
AppService: appservice,
|
|
||||||
Cond: sync.NewCond(&m),
|
|
||||||
}
|
|
||||||
workerStates[i] = ws
|
|
||||||
|
|
||||||
// Create bot account for this AS if it doesn't already exist
|
// Create bot account for this AS if it doesn't already exist
|
||||||
if err = generateAppServiceAccount(userAPI, appservice); err != nil {
|
if err := generateAppServiceAccount(userAPI, appservice); err != nil {
|
||||||
logrus.WithFields(logrus.Fields{
|
logrus.WithFields(logrus.Fields{
|
||||||
"appservice": appservice.ID,
|
"appservice": appservice.ID,
|
||||||
}).WithError(err).Panicf("failed to generate bot account for appservice")
|
}).WithError(err).Panicf("failed to generate bot account for appservice")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create appserivce query API with an HTTP client that will be used for all
|
|
||||||
// outbound and inbound requests (inbound only for the internal API)
|
|
||||||
appserviceQueryAPI := &query.AppServiceQueryAPI{
|
|
||||||
HTTPClient: client,
|
|
||||||
Cfg: base.Cfg,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only consume if we actually have ASes to track, else we'll just chew cycles needlessly.
|
// Only consume if we actually have ASes to track, else we'll just chew cycles needlessly.
|
||||||
// We can't add ASes at runtime so this is safe to do.
|
// We can't add ASes at runtime so this is safe to do.
|
||||||
if len(workerStates) > 0 {
|
js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
|
||||||
consumer := consumers.NewOutputRoomEventConsumer(
|
consumer := consumers.NewOutputRoomEventConsumer(
|
||||||
base.ProcessContext, base.Cfg, js, appserviceDB,
|
base.ProcessContext, &base.Cfg.AppServiceAPI,
|
||||||
rsAPI, workerStates,
|
client, js, rsAPI,
|
||||||
)
|
)
|
||||||
if err := consumer.Start(); err != nil {
|
if err := consumer.Start(); err != nil {
|
||||||
logrus.WithError(err).Panicf("failed to start appservice roomserver consumer")
|
logrus.WithError(err).Panicf("failed to start appservice roomserver consumer")
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create application service transaction workers
|
|
||||||
if err := workers.SetupTransactionWorkers(client, appserviceDB, workerStates); err != nil {
|
|
||||||
logrus.WithError(err).Panicf("failed to start app service transaction workers")
|
|
||||||
}
|
|
||||||
return appserviceQueryAPI
|
return appserviceQueryAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,193 +15,214 @@
|
||||||
package consumers
|
package consumers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/nats-io/nats.go"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/appservice/storage"
|
|
||||||
"github.com/matrix-org/dendrite/appservice/types"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
"github.com/matrix-org/dendrite/setup/process"
|
"github.com/matrix-org/dendrite/setup/process"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
"github.com/nats-io/nats.go"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OutputRoomEventConsumer consumes events that originated in the room server.
|
// OutputRoomEventConsumer consumes events that originated in the room server.
|
||||||
type OutputRoomEventConsumer struct {
|
type OutputRoomEventConsumer struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
jetstream nats.JetStreamContext
|
cfg *config.AppServiceAPI
|
||||||
durable string
|
client *http.Client
|
||||||
topic string
|
jetstream nats.JetStreamContext
|
||||||
asDB storage.Database
|
topic string
|
||||||
rsAPI api.AppserviceRoomserverAPI
|
rsAPI api.AppserviceRoomserverAPI
|
||||||
serverName string
|
}
|
||||||
workerStates []types.ApplicationServiceWorkerState
|
|
||||||
|
type appserviceState struct {
|
||||||
|
*config.ApplicationService
|
||||||
|
backoff int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call
|
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call
|
||||||
// Start() to begin consuming from room servers.
|
// Start() to begin consuming from room servers.
|
||||||
func NewOutputRoomEventConsumer(
|
func NewOutputRoomEventConsumer(
|
||||||
process *process.ProcessContext,
|
process *process.ProcessContext,
|
||||||
cfg *config.Dendrite,
|
cfg *config.AppServiceAPI,
|
||||||
|
client *http.Client,
|
||||||
js nats.JetStreamContext,
|
js nats.JetStreamContext,
|
||||||
appserviceDB storage.Database,
|
|
||||||
rsAPI api.AppserviceRoomserverAPI,
|
rsAPI api.AppserviceRoomserverAPI,
|
||||||
workerStates []types.ApplicationServiceWorkerState,
|
|
||||||
) *OutputRoomEventConsumer {
|
) *OutputRoomEventConsumer {
|
||||||
return &OutputRoomEventConsumer{
|
return &OutputRoomEventConsumer{
|
||||||
ctx: process.Context(),
|
ctx: process.Context(),
|
||||||
jetstream: js,
|
cfg: cfg,
|
||||||
durable: cfg.Global.JetStream.Durable("AppserviceRoomserverConsumer"),
|
client: client,
|
||||||
topic: cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent),
|
jetstream: js,
|
||||||
asDB: appserviceDB,
|
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent),
|
||||||
rsAPI: rsAPI,
|
rsAPI: rsAPI,
|
||||||
serverName: string(cfg.Global.ServerName),
|
|
||||||
workerStates: workerStates,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start consuming from room servers
|
// Start consuming from room servers
|
||||||
func (s *OutputRoomEventConsumer) Start() error {
|
func (s *OutputRoomEventConsumer) Start() error {
|
||||||
return jetstream.JetStreamConsumer(
|
for _, as := range s.cfg.Derived.ApplicationServices {
|
||||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
appsvc := as
|
||||||
nats.DeliverAll(), nats.ManualAck(),
|
state := &appserviceState{
|
||||||
)
|
ApplicationService: &appsvc,
|
||||||
|
}
|
||||||
|
token := jetstream.Tokenise(as.ID)
|
||||||
|
if err := jetstream.JetStreamConsumer(
|
||||||
|
s.ctx, s.jetstream, s.topic,
|
||||||
|
s.cfg.Matrix.JetStream.Durable("Appservice_"+token),
|
||||||
|
50, // maximum number of events to send in a single transaction
|
||||||
|
func(ctx context.Context, msgs []*nats.Msg) bool {
|
||||||
|
return s.onMessage(ctx, state, msgs)
|
||||||
|
},
|
||||||
|
nats.DeliverNew(), nats.ManualAck(),
|
||||||
|
); err != nil {
|
||||||
|
return fmt.Errorf("failed to create %q consumer: %w", token, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// onMessage is called when the appservice component receives a new event from
|
// onMessage is called when the appservice component receives a new event from
|
||||||
// the room server output log.
|
// the room server output log.
|
||||||
func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
func (s *OutputRoomEventConsumer) onMessage(
|
||||||
// Parse out the event JSON
|
ctx context.Context, state *appserviceState, msgs []*nats.Msg,
|
||||||
var output api.OutputEvent
|
) bool {
|
||||||
if err := json.Unmarshal(msg.Data, &output); err != nil {
|
log.WithField("appservice", state.ID).Tracef("Appservice worker received %d message(s) from roomserver", len(msgs))
|
||||||
// If the message was invalid, log it and move on to the next message in the stream
|
events := make([]*gomatrixserverlib.HeaderedEvent, 0, len(msgs))
|
||||||
log.WithError(err).Errorf("roomserver output log: message parse failure")
|
for _, msg := range msgs {
|
||||||
return true
|
// Parse out the event JSON
|
||||||
}
|
var output api.OutputEvent
|
||||||
|
if err := json.Unmarshal(msg.Data, &output); err != nil {
|
||||||
log.WithFields(log.Fields{
|
// If the message was invalid, log it and move on to the next message in the stream
|
||||||
"type": output.Type,
|
log.WithField("appservice", state.ID).WithError(err).Errorf("Appservice failed to parse message, ignoring")
|
||||||
}).Debug("Got a message in OutputRoomEventConsumer")
|
continue
|
||||||
|
|
||||||
events := []*gomatrixserverlib.HeaderedEvent{}
|
|
||||||
if output.Type == api.OutputTypeNewRoomEvent && output.NewRoomEvent != nil {
|
|
||||||
newEventID := output.NewRoomEvent.Event.EventID()
|
|
||||||
events = append(events, output.NewRoomEvent.Event)
|
|
||||||
if len(output.NewRoomEvent.AddsStateEventIDs) > 0 {
|
|
||||||
eventsReq := &api.QueryEventsByIDRequest{
|
|
||||||
EventIDs: make([]string, 0, len(output.NewRoomEvent.AddsStateEventIDs)),
|
|
||||||
}
|
|
||||||
eventsRes := &api.QueryEventsByIDResponse{}
|
|
||||||
for _, eventID := range output.NewRoomEvent.AddsStateEventIDs {
|
|
||||||
if eventID != newEventID {
|
|
||||||
eventsReq.EventIDs = append(eventsReq.EventIDs, eventID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(eventsReq.EventIDs) > 0 {
|
|
||||||
if err := s.rsAPI.QueryEventsByID(s.ctx, eventsReq, eventsRes); err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
events = append(events, eventsRes.Events...)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else if output.Type == api.OutputTypeNewInviteEvent && output.NewInviteEvent != nil {
|
switch output.Type {
|
||||||
events = append(events, output.NewInviteEvent.Event)
|
case api.OutputTypeNewRoomEvent:
|
||||||
} else {
|
if output.NewRoomEvent == nil || !s.appserviceIsInterestedInEvent(ctx, output.NewRoomEvent.Event, state.ApplicationService) {
|
||||||
log.WithFields(log.Fields{
|
continue
|
||||||
"type": output.Type,
|
}
|
||||||
}).Debug("appservice OutputRoomEventConsumer ignoring event", string(msg.Data))
|
events = append(events, output.NewRoomEvent.Event)
|
||||||
|
if len(output.NewRoomEvent.AddsStateEventIDs) > 0 {
|
||||||
|
newEventID := output.NewRoomEvent.Event.EventID()
|
||||||
|
eventsReq := &api.QueryEventsByIDRequest{
|
||||||
|
EventIDs: make([]string, 0, len(output.NewRoomEvent.AddsStateEventIDs)),
|
||||||
|
}
|
||||||
|
eventsRes := &api.QueryEventsByIDResponse{}
|
||||||
|
for _, eventID := range output.NewRoomEvent.AddsStateEventIDs {
|
||||||
|
if eventID != newEventID {
|
||||||
|
eventsReq.EventIDs = append(eventsReq.EventIDs, eventID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(eventsReq.EventIDs) > 0 {
|
||||||
|
if err := s.rsAPI.QueryEventsByID(s.ctx, eventsReq, eventsRes); err != nil {
|
||||||
|
log.WithError(err).Errorf("s.rsAPI.QueryEventsByID failed")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
events = append(events, eventsRes.Events...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case api.OutputTypeNewInviteEvent:
|
||||||
|
if output.NewInviteEvent == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
events = append(events, output.NewInviteEvent.Event)
|
||||||
|
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there are no events selected for sending then we should
|
||||||
|
// ack the messages so that we don't get sent them again in the
|
||||||
|
// future.
|
||||||
|
if len(events) == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send event to any relevant application services
|
// Send event to any relevant application services. If we hit
|
||||||
if err := s.filterRoomserverEvents(context.TODO(), events); err != nil {
|
// an error here, return false, so that we negatively ack.
|
||||||
log.WithError(err).Errorf("roomserver output log: filter error")
|
log.WithField("appservice", state.ID).Debugf("Appservice worker sending %d events(s) from roomserver", len(events))
|
||||||
return true
|
return s.sendEvents(ctx, state, events) == nil
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// filterRoomserverEvents takes in events and decides whether any of them need
|
// sendEvents passes events to the appservice by using the transactions
|
||||||
// to be passed on to an external application service. It does this by checking
|
// endpoint. It will block for the backoff period if necessary.
|
||||||
// each namespace of each registered application service, and if there is a
|
func (s *OutputRoomEventConsumer) sendEvents(
|
||||||
// match, adds the event to the queue for events to be sent to a particular
|
ctx context.Context, state *appserviceState,
|
||||||
// application service.
|
|
||||||
func (s *OutputRoomEventConsumer) filterRoomserverEvents(
|
|
||||||
ctx context.Context,
|
|
||||||
events []*gomatrixserverlib.HeaderedEvent,
|
events []*gomatrixserverlib.HeaderedEvent,
|
||||||
) error {
|
) error {
|
||||||
for _, ws := range s.workerStates {
|
// Create the transaction body.
|
||||||
for _, event := range events {
|
transaction, err := json.Marshal(
|
||||||
// Check if this event is interesting to this application service
|
gomatrixserverlib.ApplicationServiceTransaction{
|
||||||
if s.appserviceIsInterestedInEvent(ctx, event, ws.AppService) {
|
Events: gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatAll),
|
||||||
// Queue this event to be sent off to the application service
|
},
|
||||||
if err := s.asDB.StoreEvent(ctx, ws.AppService.ID, event); err != nil {
|
)
|
||||||
log.WithError(err).Warn("failed to insert incoming event into appservices database")
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else {
|
|
||||||
// Tell our worker to send out new messages by updating remaining message
|
|
||||||
// count and waking them up with a broadcast
|
|
||||||
ws.NotifyNewEvents()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: We should probably be more intelligent and pick something not
|
||||||
|
// in the control of the event. A NATS timestamp header or something maybe.
|
||||||
|
txnID := events[0].Event.OriginServerTS()
|
||||||
|
|
||||||
|
// Send the transaction to the appservice.
|
||||||
|
// https://matrix.org/docs/spec/application_service/r0.1.2#put-matrix-app-v1-transactions-txnid
|
||||||
|
address := fmt.Sprintf("%s/transactions/%d?access_token=%s", state.URL, txnID, url.QueryEscape(state.HSToken))
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "PUT", address, bytes.NewBuffer(transaction))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
resp, err := s.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return state.backoffAndPause(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the response was fine then we can clear any backoffs in place and
|
||||||
|
// report that everything was OK. Otherwise, back off for a while.
|
||||||
|
switch resp.StatusCode {
|
||||||
|
case http.StatusOK:
|
||||||
|
state.backoff = 0
|
||||||
|
default:
|
||||||
|
return state.backoffAndPause(fmt.Errorf("received HTTP status code %d from appservice", resp.StatusCode))
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// appserviceJoinedAtEvent returns a boolean depending on whether a given
|
// backoff pauses the calling goroutine for a 2^some backoff exponent seconds
|
||||||
// appservice has membership at the time a given event was created.
|
func (s *appserviceState) backoffAndPause(err error) error {
|
||||||
func (s *OutputRoomEventConsumer) appserviceJoinedAtEvent(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, appservice config.ApplicationService) bool {
|
if s.backoff < 6 {
|
||||||
// TODO: This is only checking the current room state, not the state at
|
s.backoff++
|
||||||
// the event in question. Pretty sure this is what Synapse does too, but
|
|
||||||
// until we have a lighter way of checking the state before the event that
|
|
||||||
// doesn't involve state res, then this is probably OK.
|
|
||||||
membershipReq := &api.QueryMembershipsForRoomRequest{
|
|
||||||
RoomID: event.RoomID(),
|
|
||||||
JoinedOnly: true,
|
|
||||||
}
|
}
|
||||||
membershipRes := &api.QueryMembershipsForRoomResponse{}
|
duration := time.Second * time.Duration(math.Pow(2, float64(s.backoff)))
|
||||||
|
log.WithField("appservice", s.ID).WithError(err).Errorf("Unable to send transaction to appservice, backing off for %s", duration.String())
|
||||||
// XXX: This could potentially race if the state for the event is not known yet
|
time.Sleep(duration)
|
||||||
// e.g. the event came over federation but we do not have the full state persisted.
|
return err
|
||||||
if err := s.rsAPI.QueryMembershipsForRoom(ctx, membershipReq, membershipRes); err == nil {
|
|
||||||
for _, ev := range membershipRes.JoinEvents {
|
|
||||||
var membership gomatrixserverlib.MemberContent
|
|
||||||
if err = json.Unmarshal(ev.Content, &membership); err != nil || ev.StateKey == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if appservice.IsInterestedInUserID(*ev.StateKey) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"room_id": event.RoomID(),
|
|
||||||
}).WithError(err).Errorf("Unable to get membership for room")
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// appserviceIsInterestedInEvent returns a boolean depending on whether a given
|
// appserviceIsInterestedInEvent returns a boolean depending on whether a given
|
||||||
// event falls within one of a given application service's namespaces.
|
// event falls within one of a given application service's namespaces.
|
||||||
//
|
//
|
||||||
// TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682
|
// TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682
|
||||||
func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, appservice config.ApplicationService) bool {
|
func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, appservice *config.ApplicationService) bool {
|
||||||
// No reason to queue events if they'll never be sent to the application
|
switch {
|
||||||
// service
|
case appservice.URL == "":
|
||||||
if appservice.URL == "" {
|
|
||||||
return false
|
return false
|
||||||
}
|
case appservice.IsInterestedInUserID(event.Sender()):
|
||||||
|
return true
|
||||||
// Check Room ID and Sender of the event
|
case appservice.IsInterestedInRoomID(event.RoomID()):
|
||||||
if appservice.IsInterestedInUserID(event.Sender()) ||
|
|
||||||
appservice.IsInterestedInRoomID(event.RoomID()) {
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -222,10 +243,52 @@ func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Cont
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"room_id": event.RoomID(),
|
"appservice": appservice.ID,
|
||||||
|
"room_id": event.RoomID(),
|
||||||
}).WithError(err).Errorf("Unable to get aliases for room")
|
}).WithError(err).Errorf("Unable to get aliases for room")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if any of the members in the room match the appservice
|
// Check if any of the members in the room match the appservice
|
||||||
return s.appserviceJoinedAtEvent(ctx, event, appservice)
|
return s.appserviceJoinedAtEvent(ctx, event, appservice)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// appserviceJoinedAtEvent returns a boolean depending on whether a given
|
||||||
|
// appservice has membership at the time a given event was created.
|
||||||
|
func (s *OutputRoomEventConsumer) appserviceJoinedAtEvent(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, appservice *config.ApplicationService) bool {
|
||||||
|
// TODO: This is only checking the current room state, not the state at
|
||||||
|
// the event in question. Pretty sure this is what Synapse does too, but
|
||||||
|
// until we have a lighter way of checking the state before the event that
|
||||||
|
// doesn't involve state res, then this is probably OK.
|
||||||
|
membershipReq := &api.QueryMembershipsForRoomRequest{
|
||||||
|
RoomID: event.RoomID(),
|
||||||
|
JoinedOnly: true,
|
||||||
|
}
|
||||||
|
membershipRes := &api.QueryMembershipsForRoomResponse{}
|
||||||
|
|
||||||
|
// XXX: This could potentially race if the state for the event is not known yet
|
||||||
|
// e.g. the event came over federation but we do not have the full state persisted.
|
||||||
|
if err := s.rsAPI.QueryMembershipsForRoom(ctx, membershipReq, membershipRes); err == nil {
|
||||||
|
for _, ev := range membershipRes.JoinEvents {
|
||||||
|
switch {
|
||||||
|
case ev.StateKey == nil:
|
||||||
|
continue
|
||||||
|
case ev.Type != gomatrixserverlib.MRoomMember:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var membership gomatrixserverlib.MemberContent
|
||||||
|
err = json.Unmarshal(ev.Content, &membership)
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
continue
|
||||||
|
case membership.Membership == gomatrixserverlib.Join:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"appservice": appservice.ID,
|
||||||
|
"room_id": event.RoomID(),
|
||||||
|
}).WithError(err).Errorf("Unable to get membership for room")
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ const userIDExistsPath = "/users/"
|
||||||
// AppServiceQueryAPI is an implementation of api.AppServiceQueryAPI
|
// AppServiceQueryAPI is an implementation of api.AppServiceQueryAPI
|
||||||
type AppServiceQueryAPI struct {
|
type AppServiceQueryAPI struct {
|
||||||
HTTPClient *http.Client
|
HTTPClient *http.Client
|
||||||
Cfg *config.Dendrite
|
Cfg *config.AppServiceAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoomAliasExists performs a request to '/room/{roomAlias}' on all known
|
// RoomAliasExists performs a request to '/room/{roomAlias}' on all known
|
||||||
|
|
|
||||||
|
|
@ -1,30 +0,0 @@
|
||||||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package storage
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Database interface {
|
|
||||||
StoreEvent(ctx context.Context, appServiceID string, event *gomatrixserverlib.HeaderedEvent) error
|
|
||||||
GetEventsWithAppServiceID(ctx context.Context, appServiceID string, limit int) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error)
|
|
||||||
CountEventsWithAppServiceID(ctx context.Context, appServiceID string) (int, error)
|
|
||||||
UpdateTxnIDForEvents(ctx context.Context, appserviceID string, maxID, txnID int) error
|
|
||||||
RemoveEventsBeforeAndIncludingID(ctx context.Context, appserviceID string, eventTableID int) error
|
|
||||||
GetLatestTxnID(ctx context.Context) (int, error)
|
|
||||||
}
|
|
||||||
|
|
@ -1,256 +0,0 @@
|
||||||
// Copyright 2018 New Vector Ltd
|
|
||||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package postgres
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
// appserviceEventsSchema creates the table (and as_id index) used to queue
// events awaiting delivery to application services.
const appserviceEventsSchema = `
-- Stores events to be sent to application services
CREATE TABLE IF NOT EXISTS appservice_events (
	-- An auto-incrementing id unique to each event in the table
	id BIGSERIAL NOT NULL PRIMARY KEY,
	-- The ID of the application service the event will be sent to
	as_id TEXT NOT NULL,
	-- JSON representation of the event
	headered_event_json TEXT NOT NULL,
	-- The ID of the transaction that this event is a part of
	txn_id BIGINT NOT NULL
);

CREATE INDEX IF NOT EXISTS appservice_events_as_id ON appservice_events(as_id);
`

// Events are ordered so that rows already assigned to a transaction
// (txn_id > -1) come before unassigned rows (txn_id = -1).
const selectEventsByApplicationServiceIDSQL = "" +
	"SELECT id, headered_event_json, txn_id " +
	"FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC"

const countEventsByApplicationServiceIDSQL = "" +
	"SELECT COUNT(id) FROM appservice_events WHERE as_id = $1"

const insertEventSQL = "" +
	"INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " +
	"VALUES ($1, $2, $3)"

const updateTxnIDForEventsSQL = "" +
	"UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3"

const deleteEventsBeforeAndIncludingIDSQL = "" +
	"DELETE FROM appservice_events WHERE as_id = $1 AND id <= $2"

const (
	// A transaction ID number that no transaction should ever have. Used for
	// checking against the default value.
	invalidTxnID = -2
)
|
|
||||||
|
|
||||||
// eventsStatements holds the prepared statements for the appservice_events
// table.
type eventsStatements struct {
	selectEventsByApplicationServiceIDStmt *sql.Stmt
	countEventsByApplicationServiceIDStmt  *sql.Stmt
	insertEventStmt                        *sql.Stmt
	updateTxnIDForEventsStmt               *sql.Stmt
	deleteEventsBeforeAndIncludingIDStmt   *sql.Stmt
}
|
|
||||||
|
|
||||||
func (s *eventsStatements) prepare(db *sql.DB) (err error) {
|
|
||||||
_, err = db.Exec(appserviceEventsSchema)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.selectEventsByApplicationServiceIDStmt, err = db.Prepare(selectEventsByApplicationServiceIDSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.countEventsByApplicationServiceIDStmt, err = db.Prepare(countEventsByApplicationServiceIDSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.updateTxnIDForEventsStmt, err = db.Prepare(updateTxnIDForEventsSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteEventsBeforeAndIncludingIDStmt, err = db.Prepare(deleteEventsBeforeAndIncludingIDSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// selectEventsByApplicationServiceID takes in an application service ID and
|
|
||||||
// returns a slice of events that need to be sent to that application service,
|
|
||||||
// as well as an int later used to remove these same events from the database
|
|
||||||
// once successfully sent to an application service.
|
|
||||||
func (s *eventsStatements) selectEventsByApplicationServiceID(
|
|
||||||
ctx context.Context,
|
|
||||||
applicationServiceID string,
|
|
||||||
limit int,
|
|
||||||
) (
|
|
||||||
txnID, maxID int,
|
|
||||||
events []gomatrixserverlib.HeaderedEvent,
|
|
||||||
eventsRemaining bool,
|
|
||||||
err error,
|
|
||||||
) {
|
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"appservice": applicationServiceID,
|
|
||||||
}).WithError(err).Fatalf("appservice unable to select new events to send")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
// Retrieve events from the database. Unsuccessfully sent events first
|
|
||||||
eventRows, err := s.selectEventsByApplicationServiceIDStmt.QueryContext(ctx, applicationServiceID)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer checkNamedErr(eventRows.Close, &err)
|
|
||||||
events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil
|
|
||||||
func checkNamedErr(fn func() error, err *error) {
|
|
||||||
if e := fn(); e != nil && *err == nil {
|
|
||||||
*err = e
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// retrieveEvents scans rows (ordered by txn_id DESC, id ASC) into events.
// It returns either the batch of events already belonging to a transaction,
// or up to `limit` events not yet assigned one (txn_id == -1).
// eventsRemaining reports whether unread rows were left behind. maxID is the
// highest row ID seen, later used to delete the batch after a successful send.
func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) {
	// Get current time for use in calculating event age
	nowMilli := time.Now().UnixNano() / int64(time.Millisecond)

	// Iterate through each row and store event contents
	// If txn_id changes dramatically, we've switched from collecting old events to
	// new ones. Send back those events first.
	lastTxnID := invalidTxnID
	for eventsProcessed := 0; eventRows.Next(); {
		var event gomatrixserverlib.HeaderedEvent
		var eventJSON []byte
		var id int
		err = eventRows.Scan(
			&id,
			&eventJSON,
			&txnID,
		)
		if err != nil {
			return nil, 0, 0, false, err
		}

		// Unmarshal eventJSON
		if err = json.Unmarshal(eventJSON, &event); err != nil {
			return nil, 0, 0, false, err
		}

		// If txnID has changed on this event from the previous event, then we've
		// reached the end of a transaction's events. Return only those events.
		if lastTxnID > invalidTxnID && lastTxnID != txnID {
			return events, maxID, lastTxnID, true, nil
		}
		lastTxnID = txnID

		// Limit events that aren't part of an old transaction
		if txnID == -1 {
			// Return if we've hit the limit
			if eventsProcessed++; eventsProcessed > limit {
				return events, maxID, lastTxnID, true, nil
			}
		}

		// Track the highest row ID so the caller can delete up to it later.
		if id > maxID {
			maxID = id
		}

		// Portion of the event that is unsigned due to rapid change
		// TODO: Consider removing age as not many app services use it
		if err = event.SetUnsignedField("age", nowMilli-int64(event.OriginServerTS())); err != nil {
			return nil, 0, 0, false, err
		}

		events = append(events, event)
	}

	return
}
|
|
||||||
|
|
||||||
// countEventsByApplicationServiceID returns the number of events currently
// queued in the database for the given application service.
func (s *eventsStatements) countEventsByApplicationServiceID(
	ctx context.Context,
	appServiceID string,
) (int, error) {
	var count int
	err := s.countEventsByApplicationServiceIDStmt.QueryRowContext(ctx, appServiceID).Scan(&count)
	// sql.ErrNoRows is treated as a zero count rather than a failure.
	if err != nil && err != sql.ErrNoRows {
		return 0, err
	}

	return count, nil
}
|
|
||||||
|
|
||||||
// insertEvent inserts an event mapped to its corresponding application service
|
|
||||||
// IDs into the db.
|
|
||||||
func (s *eventsStatements) insertEvent(
|
|
||||||
ctx context.Context,
|
|
||||||
appServiceID string,
|
|
||||||
event *gomatrixserverlib.HeaderedEvent,
|
|
||||||
) (err error) {
|
|
||||||
// Convert event to JSON before inserting
|
|
||||||
eventJSON, err := json.Marshal(event)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = s.insertEventStmt.ExecContext(
|
|
||||||
ctx,
|
|
||||||
appServiceID,
|
|
||||||
eventJSON,
|
|
||||||
-1, // No transaction ID yet
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateTxnIDForEvents sets the transactionID for a collection of events. Done
|
|
||||||
// before sending them to an AppService. Referenced before sending to make sure
|
|
||||||
// we aren't constructing multiple transactions with the same events.
|
|
||||||
func (s *eventsStatements) updateTxnIDForEvents(
|
|
||||||
ctx context.Context,
|
|
||||||
appserviceID string,
|
|
||||||
maxID, txnID int,
|
|
||||||
) (err error) {
|
|
||||||
_, err = s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// deleteEventsBeforeAndIncludingID removes events matching given IDs from the database.
|
|
||||||
func (s *eventsStatements) deleteEventsBeforeAndIncludingID(
|
|
||||||
ctx context.Context,
|
|
||||||
appserviceID string,
|
|
||||||
eventTableID int,
|
|
||||||
) (err error) {
|
|
||||||
_, err = s.deleteEventsBeforeAndIncludingIDStmt.ExecContext(ctx, appserviceID, eventTableID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
@ -1,115 +0,0 @@
|
||||||
// Copyright 2018 New Vector Ltd
|
|
||||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package postgres
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
|
|
||||||
// Import postgres database driver
|
|
||||||
_ "github.com/lib/pq"
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/dendrite/setup/base"
|
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Database stores events intended to be later sent to application services.
// events and txnID hold the prepared statements for their respective tables;
// writer serialises writes through the configured sqlutil.Writer.
type Database struct {
	events eventsStatements
	txnID  txnStatements
	db     *sql.DB
	writer sqlutil.Writer
}
|
|
||||||
|
|
||||||
// NewDatabase opens a new database
|
|
||||||
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*Database, error) {
|
|
||||||
var result Database
|
|
||||||
var err error
|
|
||||||
if result.db, result.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = result.prepare(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) prepare() error {
|
|
||||||
if err := d.events.prepare(d.db); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return d.txnID.prepare(d.db)
|
|
||||||
}
|
|
||||||
|
|
||||||
// StoreEvent takes in a gomatrixserverlib.HeaderedEvent and stores it in the database
// for a transaction worker to pull and later send to an application service.
func (d *Database) StoreEvent(
	ctx context.Context,
	appServiceID string,
	event *gomatrixserverlib.HeaderedEvent,
) error {
	return d.events.insertEvent(ctx, appServiceID, event)
}
|
|
||||||
|
|
||||||
// GetEventsWithAppServiceID returns a slice of events and their IDs intended to
// be sent to an application service given its ID.
func (d *Database) GetEventsWithAppServiceID(
	ctx context.Context,
	appServiceID string,
	limit int,
) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) {
	return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit)
}
|
|
||||||
|
|
||||||
// CountEventsWithAppServiceID returns the number of events destined for an
// application service given its ID.
func (d *Database) CountEventsWithAppServiceID(
	ctx context.Context,
	appServiceID string,
) (int, error) {
	return d.events.countEventsByApplicationServiceID(ctx, appServiceID)
}
|
|
||||||
|
|
||||||
// UpdateTxnIDForEvents assigns the given transaction ID to all queued events
// for the application service whose event table IDs are <= maxID, marking
// them as belonging to a single in-flight transaction.
func (d *Database) UpdateTxnIDForEvents(
	ctx context.Context,
	appserviceID string,
	maxID, txnID int,
) error {
	return d.events.updateTxnIDForEvents(ctx, appserviceID, maxID, txnID)
}
|
|
||||||
|
|
||||||
// RemoveEventsBeforeAndIncludingID removes all events from the database that
// are less than or equal to a given maximum ID. IDs here are implemented as a
// serial, thus this should always delete events in chronological order.
func (d *Database) RemoveEventsBeforeAndIncludingID(
	ctx context.Context,
	appserviceID string,
	eventTableID int,
) error {
	return d.events.deleteEventsBeforeAndIncludingID(ctx, appserviceID, eventTableID)
}
|
|
||||||
|
|
||||||
// GetLatestTxnID returns the latest available transaction id.
func (d *Database) GetLatestTxnID(
	ctx context.Context,
) (int, error) {
	return d.txnID.selectTxnID(ctx)
}
|
|
||||||
|
|
@ -1,53 +0,0 @@
|
||||||
// Copyright 2018 New Vector Ltd
|
|
||||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package postgres
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// txnIDSchema creates the sequence used to hand out monotonically increasing
// transaction IDs.
const txnIDSchema = `
-- Keeps a count of the current transaction ID
CREATE SEQUENCE IF NOT EXISTS txn_id_counter START 1;
`

// selectTxnIDSQL advances and returns the next value of the sequence.
const selectTxnIDSQL = "SELECT nextval('txn_id_counter')"
|
|
||||||
|
|
||||||
// txnStatements holds the prepared statement for reading the transaction ID
// counter.
type txnStatements struct {
	selectTxnIDStmt *sql.Stmt
}
|
|
||||||
|
|
||||||
func (s *txnStatements) prepare(db *sql.DB) (err error) {
|
|
||||||
_, err = db.Exec(txnIDSchema)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.selectTxnIDStmt, err = db.Prepare(selectTxnIDSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// selectTxnID selects the latest ascending transaction ID by advancing the
// txn_id_counter sequence.
func (s *txnStatements) selectTxnID(
	ctx context.Context,
) (txnID int, err error) {
	err = s.selectTxnIDStmt.QueryRowContext(ctx).Scan(&txnID)
	return
}
|
|
||||||
|
|
@ -1,267 +0,0 @@
|
||||||
// Copyright 2018 New Vector Ltd
|
|
||||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package sqlite3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
// appserviceEventsSchema creates the table (and as_id index) used to queue
// events awaiting delivery to application services.
const appserviceEventsSchema = `
-- Stores events to be sent to application services
CREATE TABLE IF NOT EXISTS appservice_events (
	-- An auto-incrementing id unique to each event in the table
	id INTEGER PRIMARY KEY AUTOINCREMENT,
	-- The ID of the application service the event will be sent to
	as_id TEXT NOT NULL,
	-- JSON representation of the event
	headered_event_json TEXT NOT NULL,
	-- The ID of the transaction that this event is a part of
	txn_id INTEGER NOT NULL
);

CREATE INDEX IF NOT EXISTS appservice_events_as_id ON appservice_events(as_id);
`

// Events are ordered so that rows already assigned to a transaction
// (txn_id > -1) come before unassigned rows (txn_id = -1).
const selectEventsByApplicationServiceIDSQL = "" +
	"SELECT id, headered_event_json, txn_id " +
	"FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC"

const countEventsByApplicationServiceIDSQL = "" +
	"SELECT COUNT(id) FROM appservice_events WHERE as_id = $1"

const insertEventSQL = "" +
	"INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " +
	"VALUES ($1, $2, $3)"

const updateTxnIDForEventsSQL = "" +
	"UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3"

const deleteEventsBeforeAndIncludingIDSQL = "" +
	"DELETE FROM appservice_events WHERE as_id = $1 AND id <= $2"

const (
	// A transaction ID number that no transaction should ever have. Used for
	// checking against the default value.
	invalidTxnID = -2
)
|
|
||||||
|
|
||||||
// eventsStatements holds the prepared statements for the appservice_events
// table, plus the database handle and the exclusive writer that serialises
// SQLite writes.
type eventsStatements struct {
	db                                     *sql.DB
	writer                                 sqlutil.Writer
	selectEventsByApplicationServiceIDStmt *sql.Stmt
	countEventsByApplicationServiceIDStmt  *sql.Stmt
	insertEventStmt                        *sql.Stmt
	updateTxnIDForEventsStmt               *sql.Stmt
	deleteEventsBeforeAndIncludingIDStmt   *sql.Stmt
}
|
|
||||||
|
|
||||||
func (s *eventsStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
|
|
||||||
s.db = db
|
|
||||||
s.writer = writer
|
|
||||||
_, err = db.Exec(appserviceEventsSchema)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.selectEventsByApplicationServiceIDStmt, err = db.Prepare(selectEventsByApplicationServiceIDSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.countEventsByApplicationServiceIDStmt, err = db.Prepare(countEventsByApplicationServiceIDSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.updateTxnIDForEventsStmt, err = db.Prepare(updateTxnIDForEventsSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteEventsBeforeAndIncludingIDStmt, err = db.Prepare(deleteEventsBeforeAndIncludingIDSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// selectEventsByApplicationServiceID takes in an application service ID and
|
|
||||||
// returns a slice of events that need to be sent to that application service,
|
|
||||||
// as well as an int later used to remove these same events from the database
|
|
||||||
// once successfully sent to an application service.
|
|
||||||
func (s *eventsStatements) selectEventsByApplicationServiceID(
|
|
||||||
ctx context.Context,
|
|
||||||
applicationServiceID string,
|
|
||||||
limit int,
|
|
||||||
) (
|
|
||||||
txnID, maxID int,
|
|
||||||
events []gomatrixserverlib.HeaderedEvent,
|
|
||||||
eventsRemaining bool,
|
|
||||||
err error,
|
|
||||||
) {
|
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"appservice": applicationServiceID,
|
|
||||||
}).WithError(err).Fatalf("appservice unable to select new events to send")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
// Retrieve events from the database. Unsuccessfully sent events first
|
|
||||||
eventRows, err := s.selectEventsByApplicationServiceIDStmt.QueryContext(ctx, applicationServiceID)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer checkNamedErr(eventRows.Close, &err)
|
|
||||||
events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil
|
|
||||||
func checkNamedErr(fn func() error, err *error) {
|
|
||||||
if e := fn(); e != nil && *err == nil {
|
|
||||||
*err = e
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// retrieveEvents scans rows (ordered by txn_id DESC, id ASC) into events.
// It returns either the batch of events already belonging to a transaction,
// or up to `limit` events not yet assigned one (txn_id == -1).
// eventsRemaining reports whether unread rows were left behind. maxID is the
// highest row ID seen, later used to delete the batch after a successful send.
func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) {
	// Get current time for use in calculating event age
	nowMilli := time.Now().UnixNano() / int64(time.Millisecond)

	// Iterate through each row and store event contents
	// If txn_id changes dramatically, we've switched from collecting old events to
	// new ones. Send back those events first.
	lastTxnID := invalidTxnID
	for eventsProcessed := 0; eventRows.Next(); {
		var event gomatrixserverlib.HeaderedEvent
		var eventJSON []byte
		var id int
		err = eventRows.Scan(
			&id,
			&eventJSON,
			&txnID,
		)
		if err != nil {
			return nil, 0, 0, false, err
		}

		// Unmarshal eventJSON
		if err = json.Unmarshal(eventJSON, &event); err != nil {
			return nil, 0, 0, false, err
		}

		// If txnID has changed on this event from the previous event, then we've
		// reached the end of a transaction's events. Return only those events.
		if lastTxnID > invalidTxnID && lastTxnID != txnID {
			return events, maxID, lastTxnID, true, nil
		}
		lastTxnID = txnID

		// Limit events that aren't part of an old transaction
		if txnID == -1 {
			// Return if we've hit the limit
			if eventsProcessed++; eventsProcessed > limit {
				return events, maxID, lastTxnID, true, nil
			}
		}

		// Track the highest row ID so the caller can delete up to it later.
		if id > maxID {
			maxID = id
		}

		// Portion of the event that is unsigned due to rapid change
		// TODO: Consider removing age as not many app services use it
		if err = event.SetUnsignedField("age", nowMilli-int64(event.OriginServerTS())); err != nil {
			return nil, 0, 0, false, err
		}

		events = append(events, event)
	}

	return
}
|
|
||||||
|
|
||||||
// countEventsByApplicationServiceID returns the number of events currently
// queued in the database for the given application service.
func (s *eventsStatements) countEventsByApplicationServiceID(
	ctx context.Context,
	appServiceID string,
) (int, error) {
	var count int
	err := s.countEventsByApplicationServiceIDStmt.QueryRowContext(ctx, appServiceID).Scan(&count)
	// sql.ErrNoRows is treated as a zero count rather than a failure.
	if err != nil && err != sql.ErrNoRows {
		return 0, err
	}

	return count, nil
}
|
|
||||||
|
|
||||||
// insertEvent inserts an event mapped to its corresponding application service
|
|
||||||
// IDs into the db.
|
|
||||||
func (s *eventsStatements) insertEvent(
|
|
||||||
ctx context.Context,
|
|
||||||
appServiceID string,
|
|
||||||
event *gomatrixserverlib.HeaderedEvent,
|
|
||||||
) (err error) {
|
|
||||||
// Convert event to JSON before inserting
|
|
||||||
eventJSON, err := json.Marshal(event)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
|
||||||
_, err := s.insertEventStmt.ExecContext(
|
|
||||||
ctx,
|
|
||||||
appServiceID,
|
|
||||||
eventJSON,
|
|
||||||
-1, // No transaction ID yet
|
|
||||||
)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateTxnIDForEvents sets the transactionID for a collection of events. Done
|
|
||||||
// before sending them to an AppService. Referenced before sending to make sure
|
|
||||||
// we aren't constructing multiple transactions with the same events.
|
|
||||||
func (s *eventsStatements) updateTxnIDForEvents(
|
|
||||||
ctx context.Context,
|
|
||||||
appserviceID string,
|
|
||||||
maxID, txnID int,
|
|
||||||
) (err error) {
|
|
||||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
|
||||||
_, err := s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// deleteEventsBeforeAndIncludingID removes events matching given IDs from the database.
|
|
||||||
func (s *eventsStatements) deleteEventsBeforeAndIncludingID(
|
|
||||||
ctx context.Context,
|
|
||||||
appserviceID string,
|
|
||||||
eventTableID int,
|
|
||||||
) (err error) {
|
|
||||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
|
||||||
_, err := s.deleteEventsBeforeAndIncludingIDStmt.ExecContext(ctx, appserviceID, eventTableID)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
@ -1,114 +0,0 @@
|
||||||
// Copyright 2018 New Vector Ltd
|
|
||||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package sqlite3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
|
|
||||||
// Import SQLite database driver
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/dendrite/setup/base"
|
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Database stores events intended to be later sent to application services.
// events and txnID hold the prepared statements for their respective tables;
// writer serialises writes, which SQLite requires to be exclusive.
type Database struct {
	events eventsStatements
	txnID  txnStatements
	db     *sql.DB
	writer sqlutil.Writer
}
|
|
||||||
|
|
||||||
// NewDatabase opens a new database
|
|
||||||
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*Database, error) {
|
|
||||||
var result Database
|
|
||||||
var err error
|
|
||||||
if result.db, result.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = result.prepare(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) prepare() error {
|
|
||||||
if err := d.events.prepare(d.db, d.writer); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return d.txnID.prepare(d.db, d.writer)
|
|
||||||
}
|
|
||||||
|
|
||||||
// StoreEvent takes in a gomatrixserverlib.HeaderedEvent and stores it in the database
// for a transaction worker to pull and later send to an application service.
func (d *Database) StoreEvent(
	ctx context.Context,
	appServiceID string,
	event *gomatrixserverlib.HeaderedEvent,
) error {
	return d.events.insertEvent(ctx, appServiceID, event)
}
|
|
||||||
|
|
||||||
// GetEventsWithAppServiceID returns a slice of events and their IDs intended to
// be sent to an application service given its ID.
func (d *Database) GetEventsWithAppServiceID(
	ctx context.Context,
	appServiceID string,
	limit int,
) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) {
	return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit)
}
|
|
||||||
|
|
||||||
// CountEventsWithAppServiceID returns the number of events destined for an
// application service given its ID.
func (d *Database) CountEventsWithAppServiceID(
	ctx context.Context,
	appServiceID string,
) (int, error) {
	return d.events.countEventsByApplicationServiceID(ctx, appServiceID)
}
|
|
||||||
|
|
||||||
// UpdateTxnIDForEvents assigns the given transaction ID to all queued events
// for the application service whose event table IDs are <= maxID, marking
// them as belonging to a single in-flight transaction.
func (d *Database) UpdateTxnIDForEvents(
	ctx context.Context,
	appserviceID string,
	maxID, txnID int,
) error {
	return d.events.updateTxnIDForEvents(ctx, appserviceID, maxID, txnID)
}
|
|
||||||
|
|
||||||
// RemoveEventsBeforeAndIncludingID removes all events from the database that
// are less than or equal to a given maximum ID. IDs here are implemented as a
// serial, thus this should always delete events in chronological order.
func (d *Database) RemoveEventsBeforeAndIncludingID(
	ctx context.Context,
	appserviceID string,
	eventTableID int,
) error {
	return d.events.deleteEventsBeforeAndIncludingID(ctx, appserviceID, eventTableID)
}
|
|
||||||
|
|
||||||
// GetLatestTxnID returns the latest available transaction id.
func (d *Database) GetLatestTxnID(
	ctx context.Context,
) (int, error) {
	return d.txnID.selectTxnID(ctx)
}
|
|
||||||
|
|
@ -1,82 +0,0 @@
|
||||||
// Copyright 2018 New Vector Ltd
|
|
||||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package sqlite3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
const txnIDSchema = `
|
|
||||||
-- Keeps a count of the current transaction ID
|
|
||||||
CREATE TABLE IF NOT EXISTS appservice_counters (
|
|
||||||
name TEXT PRIMARY KEY NOT NULL,
|
|
||||||
last_id INTEGER DEFAULT 1
|
|
||||||
);
|
|
||||||
INSERT OR IGNORE INTO appservice_counters (name, last_id) VALUES('txn_id', 1);
|
|
||||||
`
|
|
||||||
|
|
||||||
const selectTxnIDSQL = `
|
|
||||||
SELECT last_id FROM appservice_counters WHERE name='txn_id'
|
|
||||||
`
|
|
||||||
|
|
||||||
const updateTxnIDSQL = `
|
|
||||||
UPDATE appservice_counters SET last_id=last_id+1 WHERE name='txn_id'
|
|
||||||
`
|
|
||||||
|
|
||||||
type txnStatements struct {
|
|
||||||
db *sql.DB
|
|
||||||
writer sqlutil.Writer
|
|
||||||
selectTxnIDStmt *sql.Stmt
|
|
||||||
updateTxnIDStmt *sql.Stmt
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *txnStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
|
|
||||||
s.db = db
|
|
||||||
s.writer = writer
|
|
||||||
_, err = db.Exec(txnIDSchema)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.selectTxnIDStmt, err = db.Prepare(selectTxnIDSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.updateTxnIDStmt, err = db.Prepare(updateTxnIDSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// selectTxnID selects the latest ascending transaction ID
|
|
||||||
func (s *txnStatements) selectTxnID(
|
|
||||||
ctx context.Context,
|
|
||||||
) (txnID int, err error) {
|
|
||||||
err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
|
||||||
err := s.selectTxnIDStmt.QueryRowContext(ctx).Scan(&txnID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = s.updateTxnIDStmt.ExecContext(ctx)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
@ -1,40 +0,0 @@
|
||||||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
//go:build !wasm
|
|
||||||
// +build !wasm
|
|
||||||
|
|
||||||
package storage
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/appservice/storage/postgres"
|
|
||||||
"github.com/matrix-org/dendrite/appservice/storage/sqlite3"
|
|
||||||
"github.com/matrix-org/dendrite/setup/base"
|
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
|
|
||||||
// and sets DB connection parameters
|
|
||||||
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) {
|
|
||||||
switch {
|
|
||||||
case dbProperties.ConnectionString.IsSQLite():
|
|
||||||
return sqlite3.NewDatabase(base, dbProperties)
|
|
||||||
case dbProperties.ConnectionString.IsPostgres():
|
|
||||||
return postgres.NewDatabase(base, dbProperties)
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unexpected database type")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,34 +0,0 @@
|
||||||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package storage
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/appservice/storage/sqlite3"
|
|
||||||
"github.com/matrix-org/dendrite/setup/base"
|
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) {
|
|
||||||
switch {
|
|
||||||
case dbProperties.ConnectionString.IsSQLite():
|
|
||||||
return sqlite3.NewDatabase(base, dbProperties)
|
|
||||||
case dbProperties.ConnectionString.IsPostgres():
|
|
||||||
return nil, fmt.Errorf("can't use Postgres implementation")
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unexpected database type")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,64 +0,0 @@
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package types
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// AppServiceDeviceID is the AS dummy device ID
|
|
||||||
AppServiceDeviceID = "AS_Device"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ApplicationServiceWorkerState is a type that couples an application service,
|
|
||||||
// a lockable condition as well as some other state variables, allowing the
|
|
||||||
// roomserver to notify appservice workers when there are events ready to send
|
|
||||||
// externally to application services.
|
|
||||||
type ApplicationServiceWorkerState struct {
|
|
||||||
AppService config.ApplicationService
|
|
||||||
Cond *sync.Cond
|
|
||||||
// Events ready to be sent
|
|
||||||
EventsReady bool
|
|
||||||
// Backoff exponent (2^x secs). Max 6, aka 64s.
|
|
||||||
Backoff int
|
|
||||||
}
|
|
||||||
|
|
||||||
// NotifyNewEvents wakes up all waiting goroutines, notifying that events remain
|
|
||||||
// in the event queue for this application service worker.
|
|
||||||
func (a *ApplicationServiceWorkerState) NotifyNewEvents() {
|
|
||||||
a.Cond.L.Lock()
|
|
||||||
a.EventsReady = true
|
|
||||||
a.Cond.Broadcast()
|
|
||||||
a.Cond.L.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// FinishEventProcessing marks all events of this worker as being sent to the
|
|
||||||
// application service.
|
|
||||||
func (a *ApplicationServiceWorkerState) FinishEventProcessing() {
|
|
||||||
a.Cond.L.Lock()
|
|
||||||
a.EventsReady = false
|
|
||||||
a.Cond.L.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// WaitForNewEvents causes the calling goroutine to wait on the worker state's
|
|
||||||
// condition for a broadcast or similar wakeup, if there are no events ready.
|
|
||||||
func (a *ApplicationServiceWorkerState) WaitForNewEvents() {
|
|
||||||
a.Cond.L.Lock()
|
|
||||||
if !a.EventsReady {
|
|
||||||
a.Cond.Wait()
|
|
||||||
}
|
|
||||||
a.Cond.L.Unlock()
|
|
||||||
}
|
|
||||||
|
|
@ -1,236 +0,0 @@
|
||||||
// Copyright 2018 Vector Creations Ltd
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package workers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/appservice/storage"
|
|
||||||
"github.com/matrix-org/dendrite/appservice/types"
|
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// Maximum size of events sent in each transaction.
|
|
||||||
transactionBatchSize = 50
|
|
||||||
)
|
|
||||||
|
|
||||||
// SetupTransactionWorkers spawns a separate goroutine for each application
|
|
||||||
// service. Each of these "workers" handle taking all events intended for their
|
|
||||||
// app service, batch them up into a single transaction (up to a max transaction
|
|
||||||
// size), then send that off to the AS's /transactions/{txnID} endpoint. It also
|
|
||||||
// handles exponentially backing off in case the AS isn't currently available.
|
|
||||||
func SetupTransactionWorkers(
|
|
||||||
client *http.Client,
|
|
||||||
appserviceDB storage.Database,
|
|
||||||
workerStates []types.ApplicationServiceWorkerState,
|
|
||||||
) error {
|
|
||||||
// Create a worker that handles transmitting events to a single homeserver
|
|
||||||
for _, workerState := range workerStates {
|
|
||||||
// Don't create a worker if this AS doesn't want to receive events
|
|
||||||
if workerState.AppService.URL != "" {
|
|
||||||
go worker(client, appserviceDB, workerState)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// worker is a goroutine that sends any queued events to the application service
|
|
||||||
// it is given.
|
|
||||||
func worker(client *http.Client, db storage.Database, ws types.ApplicationServiceWorkerState) {
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"appservice": ws.AppService.ID,
|
|
||||||
}).Info("Starting application service")
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// Initial check for any leftover events to send from last time
|
|
||||||
eventCount, err := db.CountEventsWithAppServiceID(ctx, ws.AppService.ID)
|
|
||||||
if err != nil {
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"appservice": ws.AppService.ID,
|
|
||||||
}).WithError(err).Fatal("appservice worker unable to read queued events from DB")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if eventCount > 0 {
|
|
||||||
ws.NotifyNewEvents()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Loop forever and keep waiting for more events to send
|
|
||||||
for {
|
|
||||||
// Wait for more events if we've sent all the events in the database
|
|
||||||
ws.WaitForNewEvents()
|
|
||||||
|
|
||||||
// Batch events up into a transaction
|
|
||||||
transactionJSON, txnID, maxEventID, eventsRemaining, err := createTransaction(ctx, db, ws.AppService.ID)
|
|
||||||
if err != nil {
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"appservice": ws.AppService.ID,
|
|
||||||
}).WithError(err).Fatal("appservice worker unable to create transaction")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send the events off to the application service
|
|
||||||
// Backoff if the application service does not respond
|
|
||||||
err = send(client, ws.AppService, txnID, transactionJSON)
|
|
||||||
if err != nil {
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"appservice": ws.AppService.ID,
|
|
||||||
}).WithError(err).Error("unable to send event")
|
|
||||||
// Backoff
|
|
||||||
backoff(&ws, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// We sent successfully, hooray!
|
|
||||||
ws.Backoff = 0
|
|
||||||
|
|
||||||
// Transactions have a maximum event size, so there may still be some events
|
|
||||||
// left over to send. Keep sending until none are left
|
|
||||||
if !eventsRemaining {
|
|
||||||
ws.FinishEventProcessing()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove sent events from the DB
|
|
||||||
err = db.RemoveEventsBeforeAndIncludingID(ctx, ws.AppService.ID, maxEventID)
|
|
||||||
if err != nil {
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"appservice": ws.AppService.ID,
|
|
||||||
}).WithError(err).Fatal("unable to remove appservice events from the database")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// backoff pauses the calling goroutine for a 2^some backoff exponent seconds
|
|
||||||
func backoff(ws *types.ApplicationServiceWorkerState, err error) {
|
|
||||||
// Calculate how long to backoff for
|
|
||||||
backoffDuration := time.Duration(math.Pow(2, float64(ws.Backoff)))
|
|
||||||
backoffSeconds := time.Second * backoffDuration
|
|
||||||
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"appservice": ws.AppService.ID,
|
|
||||||
}).WithError(err).Warnf("unable to send transactions successfully, backing off for %ds",
|
|
||||||
backoffDuration)
|
|
||||||
|
|
||||||
ws.Backoff++
|
|
||||||
if ws.Backoff > 6 {
|
|
||||||
ws.Backoff = 6
|
|
||||||
}
|
|
||||||
|
|
||||||
// Backoff
|
|
||||||
time.Sleep(backoffSeconds)
|
|
||||||
}
|
|
||||||
|
|
||||||
// createTransaction takes in a slice of AS events, stores them in an AS
|
|
||||||
// transaction, and JSON-encodes the results.
|
|
||||||
func createTransaction(
|
|
||||||
ctx context.Context,
|
|
||||||
db storage.Database,
|
|
||||||
appserviceID string,
|
|
||||||
) (
|
|
||||||
transactionJSON []byte,
|
|
||||||
txnID, maxID int,
|
|
||||||
eventsRemaining bool,
|
|
||||||
err error,
|
|
||||||
) {
|
|
||||||
// Retrieve the latest events from the DB (will return old events if they weren't successfully sent)
|
|
||||||
txnID, maxID, events, eventsRemaining, err := db.GetEventsWithAppServiceID(ctx, appserviceID, transactionBatchSize)
|
|
||||||
if err != nil {
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"appservice": appserviceID,
|
|
||||||
}).WithError(err).Fatalf("appservice worker unable to read queued events from DB")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if these events do not already have a transaction ID
|
|
||||||
if txnID == -1 {
|
|
||||||
// If not, grab next available ID from the DB
|
|
||||||
txnID, err = db.GetLatestTxnID(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, 0, false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mark new events with current transactionID
|
|
||||||
if err = db.UpdateTxnIDForEvents(ctx, appserviceID, maxID, txnID); err != nil {
|
|
||||||
return nil, 0, 0, false, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var ev []*gomatrixserverlib.HeaderedEvent
|
|
||||||
for i := range events {
|
|
||||||
ev = append(ev, &events[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a transaction and store the events inside
|
|
||||||
transaction := gomatrixserverlib.ApplicationServiceTransaction{
|
|
||||||
Events: gomatrixserverlib.HeaderedToClientEvents(ev, gomatrixserverlib.FormatAll),
|
|
||||||
}
|
|
||||||
|
|
||||||
transactionJSON, err = json.Marshal(transaction)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// send sends events to an application service. Returns an error if an OK was not
|
|
||||||
// received back from the application service or the request timed out.
|
|
||||||
func send(
|
|
||||||
client *http.Client,
|
|
||||||
appservice config.ApplicationService,
|
|
||||||
txnID int,
|
|
||||||
transaction []byte,
|
|
||||||
) (err error) {
|
|
||||||
// PUT a transaction to our AS
|
|
||||||
// https://matrix.org/docs/spec/application_service/r0.1.2#put-matrix-app-v1-transactions-txnid
|
|
||||||
address := fmt.Sprintf("%s/transactions/%d?access_token=%s", appservice.URL, txnID, url.QueryEscape(appservice.HSToken))
|
|
||||||
req, err := http.NewRequest("PUT", address, bytes.NewBuffer(transaction))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer checkNamedErr(resp.Body.Close, &err)
|
|
||||||
|
|
||||||
// Check the AS received the events correctly
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
// TODO: Handle non-200 error codes from application services
|
|
||||||
return fmt.Errorf("non-OK status code %d returned from AS", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil
|
|
||||||
func checkNamedErr(fn func() error, err *error) {
|
|
||||||
if e := fn(); e != nil && *err == nil {
|
|
||||||
*err = e
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -255,7 +255,6 @@ func (m *DendriteMonolith) Start() {
|
||||||
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix))
|
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix))
|
||||||
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix))
|
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix))
|
||||||
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-federationsender.db", m.StorageDirectory, prefix))
|
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-federationsender.db", m.StorageDirectory, prefix))
|
||||||
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-appservice.db", m.StorageDirectory, prefix))
|
|
||||||
cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory))
|
cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory))
|
||||||
cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory))
|
cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory))
|
||||||
cfg.MSCs.MSCs = []string{"msc2836", "msc2946"}
|
cfg.MSCs.MSCs = []string{"msc2836", "msc2946"}
|
||||||
|
|
|
||||||
|
|
@ -94,7 +94,6 @@ func (m *DendriteMonolith) Start() {
|
||||||
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-roomserver.db", m.StorageDirectory))
|
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-roomserver.db", m.StorageDirectory))
|
||||||
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-keyserver.db", m.StorageDirectory))
|
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-keyserver.db", m.StorageDirectory))
|
||||||
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-federationsender.db", m.StorageDirectory))
|
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-federationsender.db", m.StorageDirectory))
|
||||||
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-appservice.db", m.StorageDirectory))
|
|
||||||
cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory))
|
cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory))
|
||||||
cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory))
|
cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory))
|
||||||
cfg.ClientAPI.RegistrationDisabled = false
|
cfg.ClientAPI.RegistrationDisabled = false
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
#syntax=docker/dockerfile:1.2
|
||||||
|
|
||||||
FROM golang:1.18-stretch as build
|
FROM golang:1.18-stretch as build
|
||||||
RUN apt-get update && apt-get install -y sqlite3
|
RUN apt-get update && apt-get install -y sqlite3
|
||||||
WORKDIR /build
|
WORKDIR /build
|
||||||
|
|
@ -8,14 +10,12 @@ RUN mkdir /dendrite
|
||||||
|
|
||||||
# Utilise Docker caching when downloading dependencies, this stops us needlessly
|
# Utilise Docker caching when downloading dependencies, this stops us needlessly
|
||||||
# downloading dependencies every time.
|
# downloading dependencies every time.
|
||||||
COPY go.mod .
|
RUN --mount=target=. \
|
||||||
COPY go.sum .
|
--mount=type=cache,target=/go/pkg/mod \
|
||||||
RUN go mod download
|
--mount=type=cache,target=/root/.cache/go-build \
|
||||||
|
go build -o /dendrite ./cmd/generate-config && \
|
||||||
COPY . .
|
go build -o /dendrite ./cmd/generate-keys && \
|
||||||
RUN go build -o /dendrite ./cmd/dendrite-monolith-server
|
go build -o /dendrite ./cmd/dendrite-monolith-server
|
||||||
RUN go build -o /dendrite ./cmd/generate-keys
|
|
||||||
RUN go build -o /dendrite ./cmd/generate-config
|
|
||||||
|
|
||||||
WORKDIR /dendrite
|
WORKDIR /dendrite
|
||||||
RUN ./generate-keys --private-key matrix_key.pem
|
RUN ./generate-keys --private-key matrix_key.pem
|
||||||
|
|
@ -26,7 +26,7 @@ EXPOSE 8008 8448
|
||||||
|
|
||||||
# At runtime, generate TLS cert based on the CA now mounted at /ca
|
# At runtime, generate TLS cert based on the CA now mounted at /ca
|
||||||
# At runtime, replace the SERVER_NAME with what we are told
|
# At runtime, replace the SERVER_NAME with what we are told
|
||||||
CMD ./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \
|
CMD ./generate-keys -keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \
|
||||||
./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \
|
./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \
|
||||||
cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \
|
cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \
|
||||||
./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0}
|
exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0}
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
#syntax=docker/dockerfile:1.2
|
||||||
|
|
||||||
# A local development Complement dockerfile, to be used with host mounts
|
# A local development Complement dockerfile, to be used with host mounts
|
||||||
# /cache -> Contains the entire dendrite code at Dockerfile build time. Builds binaries but only keeps the generate-* ones. Pre-compilation saves time.
|
# /cache -> Contains the entire dendrite code at Dockerfile build time. Builds binaries but only keeps the generate-* ones. Pre-compilation saves time.
|
||||||
# /dendrite -> Host-mounted sources
|
# /dendrite -> Host-mounted sources
|
||||||
|
|
@ -9,11 +11,10 @@
|
||||||
FROM golang:1.18-stretch
|
FROM golang:1.18-stretch
|
||||||
RUN apt-get update && apt-get install -y sqlite3
|
RUN apt-get update && apt-get install -y sqlite3
|
||||||
|
|
||||||
WORKDIR /runtime
|
|
||||||
|
|
||||||
ENV SERVER_NAME=localhost
|
ENV SERVER_NAME=localhost
|
||||||
EXPOSE 8008 8448
|
EXPOSE 8008 8448
|
||||||
|
|
||||||
|
WORKDIR /runtime
|
||||||
# This script compiles Dendrite for us.
|
# This script compiles Dendrite for us.
|
||||||
RUN echo '\
|
RUN echo '\
|
||||||
#!/bin/bash -eux \n\
|
#!/bin/bash -eux \n\
|
||||||
|
|
@ -29,25 +30,23 @@ RUN echo '\
|
||||||
RUN echo '\
|
RUN echo '\
|
||||||
#!/bin/bash -eu \n\
|
#!/bin/bash -eu \n\
|
||||||
./generate-keys --private-key matrix_key.pem \n\
|
./generate-keys --private-key matrix_key.pem \n\
|
||||||
./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key \n\
|
./generate-keys -keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key \n\
|
||||||
./generate-config -server $SERVER_NAME --ci > dendrite.yaml \n\
|
./generate-config -server $SERVER_NAME --ci > dendrite.yaml \n\
|
||||||
cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates \n\
|
cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates \n\
|
||||||
./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\
|
exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\
|
||||||
' > run.sh && chmod +x run.sh
|
' > run.sh && chmod +x run.sh
|
||||||
|
|
||||||
|
|
||||||
WORKDIR /cache
|
WORKDIR /cache
|
||||||
# Pre-download deps; we don't need to do this if the GOPATH is mounted.
|
|
||||||
COPY go.mod .
|
|
||||||
COPY go.sum .
|
|
||||||
RUN go mod download
|
|
||||||
|
|
||||||
# Build the monolith in /cache - we won't actually use this but will rely on build artifacts to speed
|
# Build the monolith in /cache - we won't actually use this but will rely on build artifacts to speed
|
||||||
# up the real compilation. Build the generate-* binaries in the true /runtime locations.
|
# up the real compilation. Build the generate-* binaries in the true /runtime locations.
|
||||||
# If the generate-* source is changed, this dockerfile needs re-running.
|
# If the generate-* source is changed, this dockerfile needs re-running.
|
||||||
COPY . .
|
RUN --mount=target=. \
|
||||||
RUN go build ./cmd/dendrite-monolith-server && go build -o /runtime ./cmd/generate-keys && go build -o /runtime ./cmd/generate-config
|
--mount=type=cache,target=/go/pkg/mod \
|
||||||
|
--mount=type=cache,target=/root/.cache/go-build \
|
||||||
|
go build -o /runtime ./cmd/generate-config && \
|
||||||
|
go build -o /runtime ./cmd/generate-keys
|
||||||
|
|
||||||
|
|
||||||
WORKDIR /runtime
|
WORKDIR /runtime
|
||||||
CMD /runtime/compile.sh && /runtime/run.sh
|
CMD /runtime/compile.sh && exec /runtime/run.sh
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
#syntax=docker/dockerfile:1.2
|
||||||
|
|
||||||
FROM golang:1.18-stretch as build
|
FROM golang:1.18-stretch as build
|
||||||
RUN apt-get update && apt-get install -y postgresql
|
RUN apt-get update && apt-get install -y postgresql
|
||||||
WORKDIR /build
|
WORKDIR /build
|
||||||
|
|
@ -26,14 +28,12 @@ RUN mkdir /dendrite
|
||||||
|
|
||||||
# Utilise Docker caching when downloading dependencies, this stops us needlessly
|
# Utilise Docker caching when downloading dependencies, this stops us needlessly
|
||||||
# downloading dependencies every time.
|
# downloading dependencies every time.
|
||||||
COPY go.mod .
|
RUN --mount=target=. \
|
||||||
COPY go.sum .
|
--mount=type=cache,target=/go/pkg/mod \
|
||||||
RUN go mod download
|
--mount=type=cache,target=/root/.cache/go-build \
|
||||||
|
go build -o /dendrite ./cmd/generate-config && \
|
||||||
COPY . .
|
go build -o /dendrite ./cmd/generate-keys && \
|
||||||
RUN go build -o /dendrite ./cmd/dendrite-monolith-server
|
go build -o /dendrite ./cmd/dendrite-monolith-server
|
||||||
RUN go build -o /dendrite ./cmd/generate-keys
|
|
||||||
RUN go build -o /dendrite ./cmd/generate-config
|
|
||||||
|
|
||||||
WORKDIR /dendrite
|
WORKDIR /dendrite
|
||||||
RUN ./generate-keys --private-key matrix_key.pem
|
RUN ./generate-keys --private-key matrix_key.pem
|
||||||
|
|
@ -45,10 +45,10 @@ EXPOSE 8008 8448
|
||||||
|
|
||||||
# At runtime, generate TLS cert based on the CA now mounted at /ca
|
# At runtime, generate TLS cert based on the CA now mounted at /ca
|
||||||
# At runtime, replace the SERVER_NAME with what we are told
|
# At runtime, replace the SERVER_NAME with what we are told
|
||||||
CMD /build/run_postgres.sh && ./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \
|
CMD /build/run_postgres.sh && ./generate-keys --keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \
|
||||||
./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \
|
./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \
|
||||||
# Replace the connection string with a single postgres DB, using user/db = 'postgres' and no password, bump max_conns
|
# Replace the connection string with a single postgres DB, using user/db = 'postgres' and no password, bump max_conns
|
||||||
sed -i "s%connection_string:.*$%connection_string: postgresql://postgres@localhost/postgres?sslmode=disable%g" dendrite.yaml && \
|
sed -i "s%connection_string:.*$%connection_string: postgresql://postgres@localhost/postgres?sslmode=disable%g" dendrite.yaml && \
|
||||||
sed -i 's/max_open_conns:.*$/max_open_conns: 100/g' dendrite.yaml && \
|
sed -i 's/max_open_conns:.*$/max_open_conns: 100/g' dendrite.yaml && \
|
||||||
cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \
|
cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \
|
||||||
./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0}
|
exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0}
|
||||||
|
|
@ -276,19 +276,19 @@ type recaptchaResponse struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateUsername returns an error response if the username is invalid
|
// validateUsername returns an error response if the username is invalid
|
||||||
func validateUsername(username string) *util.JSONResponse {
|
func validateUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse {
|
||||||
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
||||||
if len(username) > maxUsernameLength {
|
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
|
||||||
return &util.JSONResponse{
|
return &util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
JSON: jsonerror.BadJSON(fmt.Sprintf("'username' >%d characters", maxUsernameLength)),
|
JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)),
|
||||||
}
|
}
|
||||||
} else if !validUsernameRegex.MatchString(username) {
|
} else if !validUsernameRegex.MatchString(localpart) {
|
||||||
return &util.JSONResponse{
|
return &util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
|
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
|
||||||
}
|
}
|
||||||
} else if username[0] == '_' { // Regex checks its not a zero length string
|
} else if localpart[0] == '_' { // Regex checks its not a zero length string
|
||||||
return &util.JSONResponse{
|
return &util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
JSON: jsonerror.InvalidUsername("Username cannot start with a '_'"),
|
JSON: jsonerror.InvalidUsername("Username cannot start with a '_'"),
|
||||||
|
|
@ -298,13 +298,13 @@ func validateUsername(username string) *util.JSONResponse {
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateApplicationServiceUsername returns an error response if the username is invalid for an application service
|
// validateApplicationServiceUsername returns an error response if the username is invalid for an application service
|
||||||
func validateApplicationServiceUsername(username string) *util.JSONResponse {
|
func validateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse {
|
||||||
if len(username) > maxUsernameLength {
|
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
|
||||||
return &util.JSONResponse{
|
return &util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
JSON: jsonerror.BadJSON(fmt.Sprintf("'username' >%d characters", maxUsernameLength)),
|
JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)),
|
||||||
}
|
}
|
||||||
} else if !validUsernameRegex.MatchString(username) {
|
} else if !validUsernameRegex.MatchString(localpart) {
|
||||||
return &util.JSONResponse{
|
return &util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
|
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
|
||||||
|
|
@ -523,7 +523,7 @@ func validateApplicationService(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check username application service is trying to register is valid
|
// Check username application service is trying to register is valid
|
||||||
if err := validateApplicationServiceUsername(username); err != nil {
|
if err := validateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -604,7 +604,7 @@ func Register(
|
||||||
case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil:
|
case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil:
|
||||||
// Spec-compliant case (the access_token is specified and the login type
|
// Spec-compliant case (the access_token is specified and the login type
|
||||||
// is correctly set, so it's an appservice registration)
|
// is correctly set, so it's an appservice registration)
|
||||||
if resErr := validateApplicationServiceUsername(r.Username); resErr != nil {
|
if resErr := validateApplicationServiceUsername(r.Username, cfg.Matrix.ServerName); resErr != nil {
|
||||||
return *resErr
|
return *resErr
|
||||||
}
|
}
|
||||||
case accessTokenErr == nil:
|
case accessTokenErr == nil:
|
||||||
|
|
@ -617,7 +617,7 @@ func Register(
|
||||||
default:
|
default:
|
||||||
// Spec-compliant case (neither the access_token nor the login type are
|
// Spec-compliant case (neither the access_token nor the login type are
|
||||||
// specified, so it's a normal user registration)
|
// specified, so it's a normal user registration)
|
||||||
if resErr := validateUsername(r.Username); resErr != nil {
|
if resErr := validateUsername(r.Username, cfg.Matrix.ServerName); resErr != nil {
|
||||||
return *resErr
|
return *resErr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -1018,7 +1018,7 @@ func RegisterAvailable(
|
||||||
// Squash username to all lowercase letters
|
// Squash username to all lowercase letters
|
||||||
username = strings.ToLower(username)
|
username = strings.ToLower(username)
|
||||||
|
|
||||||
if err := validateUsername(username); err != nil {
|
if err := validateUsername(username, cfg.Matrix.ServerName); err != nil {
|
||||||
return *err
|
return *err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1059,7 +1059,7 @@ func RegisterAvailable(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleSharedSecretRegistration(userAPI userapi.ClientUserAPI, sr *SharedSecretRegistration, req *http.Request) util.JSONResponse {
|
func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.ClientUserAPI, sr *SharedSecretRegistration, req *http.Request) util.JSONResponse {
|
||||||
ssrr, err := NewSharedSecretRegistrationRequest(req.Body)
|
ssrr, err := NewSharedSecretRegistrationRequest(req.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
|
@ -1080,7 +1080,7 @@ func handleSharedSecretRegistration(userAPI userapi.ClientUserAPI, sr *SharedSec
|
||||||
// downcase capitals
|
// downcase capitals
|
||||||
ssrr.User = strings.ToLower(ssrr.User)
|
ssrr.User = strings.ToLower(ssrr.User)
|
||||||
|
|
||||||
if resErr := validateUsername(ssrr.User); resErr != nil {
|
if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil {
|
||||||
return *resErr
|
return *resErr
|
||||||
}
|
}
|
||||||
if resErr := validatePassword(ssrr.Password); resErr != nil {
|
if resErr := validatePassword(ssrr.Password); resErr != nil {
|
||||||
|
|
|
||||||
|
|
@ -133,7 +133,7 @@ func Setup(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if req.Method == http.MethodPost {
|
if req.Method == http.MethodPost {
|
||||||
return handleSharedSecretRegistration(userAPI, sr, req)
|
return handleSharedSecretRegistration(cfg, userAPI, sr, req)
|
||||||
}
|
}
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusMethodNotAllowed,
|
Code: http.StatusMethodNotAllowed,
|
||||||
|
|
|
||||||
|
|
@ -66,10 +66,11 @@ var (
|
||||||
resetPassword = flag.Bool("reset-password", false, "Deprecated")
|
resetPassword = flag.Bool("reset-password", false, "Deprecated")
|
||||||
serverURL = flag.String("url", "https://localhost:8448", "The URL to connect to.")
|
serverURL = flag.String("url", "https://localhost:8448", "The URL to connect to.")
|
||||||
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
|
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
|
||||||
|
timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server")
|
||||||
)
|
)
|
||||||
|
|
||||||
var cl = http.Client{
|
var cl = http.Client{
|
||||||
Timeout: time.Second * 10,
|
Timeout: time.Second * 30,
|
||||||
Transport: http.DefaultTransport,
|
Transport: http.DefaultTransport,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -108,6 +109,8 @@ func main() {
|
||||||
logrus.Fatalln(err)
|
logrus.Fatalln(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cl.Timeout = *timeout
|
||||||
|
|
||||||
accessToken, err := sharedSecretRegister(cfg.ClientAPI.RegistrationSharedSecret, *serverURL, *username, pass, *isAdmin)
|
accessToken, err := sharedSecretRegister(cfg.ClientAPI.RegistrationSharedSecret, *serverURL, *username, pass, *isAdmin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Fatalln("Failed to create the account:", err.Error())
|
logrus.Fatalln("Failed to create the account:", err.Error())
|
||||||
|
|
@ -124,8 +127,8 @@ type sharedSecretRegistrationRequest struct {
|
||||||
Admin bool `json:"admin"`
|
Admin bool `json:"admin"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func sharedSecretRegister(sharedSecret, serverURL, localpart, password string, admin bool) (accesToken string, err error) {
|
func sharedSecretRegister(sharedSecret, serverURL, localpart, password string, admin bool) (accessToken string, err error) {
|
||||||
registerURL := fmt.Sprintf("%s/_synapse/admin/v1/register", serverURL)
|
registerURL := fmt.Sprintf("%s/_synapse/admin/v1/register", strings.Trim(serverURL, "/"))
|
||||||
nonceReq, err := http.NewRequest(http.MethodGet, registerURL, nil)
|
nonceReq, err := http.NewRequest(http.MethodGet, registerURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("unable to create http request: %w", err)
|
return "", fmt.Errorf("unable to create http request: %w", err)
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
@ -42,6 +43,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup"
|
"github.com/matrix-org/dendrite/setup"
|
||||||
"github.com/matrix-org/dendrite/setup/base"
|
"github.com/matrix-org/dendrite/setup/base"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
"github.com/matrix-org/dendrite/userapi"
|
"github.com/matrix-org/dendrite/userapi"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
|
|
@ -70,31 +72,83 @@ func main() {
|
||||||
var pk ed25519.PublicKey
|
var pk ed25519.PublicKey
|
||||||
var sk ed25519.PrivateKey
|
var sk ed25519.PrivateKey
|
||||||
|
|
||||||
keyfile := *instanceName + ".key"
|
// iterate through the cli args and check if the config flag was set
|
||||||
if _, err := os.Stat(keyfile); os.IsNotExist(err) {
|
configFlagSet := false
|
||||||
if pk, sk, err = ed25519.GenerateKey(nil); err != nil {
|
for _, arg := range os.Args {
|
||||||
panic(err)
|
if arg == "--config" || arg == "-config" {
|
||||||
|
configFlagSet = true
|
||||||
|
break
|
||||||
}
|
}
|
||||||
if err = os.WriteFile(keyfile, sk, 0644); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
} else if err == nil {
|
|
||||||
if sk, err = os.ReadFile(keyfile); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
if len(sk) != ed25519.PrivateKeySize {
|
|
||||||
panic("the private key is not long enough")
|
|
||||||
}
|
|
||||||
pk = sk.Public().(ed25519.PublicKey)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cfg := &config.Dendrite{}
|
||||||
|
|
||||||
|
// use custom config if config flag is set
|
||||||
|
if configFlagSet {
|
||||||
|
cfg = setup.ParseFlags(true)
|
||||||
|
sk = cfg.Global.PrivateKey
|
||||||
|
} else {
|
||||||
|
keyfile := *instanceName + ".pem"
|
||||||
|
if _, err := os.Stat(keyfile); os.IsNotExist(err) {
|
||||||
|
oldkeyfile := *instanceName + ".key"
|
||||||
|
if _, err = os.Stat(oldkeyfile); os.IsNotExist(err) {
|
||||||
|
if err = test.NewMatrixKey(keyfile); err != nil {
|
||||||
|
panic("failed to generate a new PEM key: " + err.Error())
|
||||||
|
}
|
||||||
|
if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil {
|
||||||
|
panic("failed to load PEM key: " + err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if sk, err = os.ReadFile(oldkeyfile); err != nil {
|
||||||
|
panic("failed to read the old private key: " + err.Error())
|
||||||
|
}
|
||||||
|
if len(sk) != ed25519.PrivateKeySize {
|
||||||
|
panic("the private key is not long enough")
|
||||||
|
}
|
||||||
|
if err := test.SaveMatrixKey(keyfile, sk); err != nil {
|
||||||
|
panic("failed to convert the private key to PEM format: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
var err error
|
||||||
|
if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil {
|
||||||
|
panic("failed to load PEM key: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cfg.Defaults(true)
|
||||||
|
cfg.Global.PrivateKey = sk
|
||||||
|
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
|
||||||
|
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
|
||||||
|
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
|
||||||
|
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
|
||||||
|
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
|
||||||
|
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", *instanceName))
|
||||||
|
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", *instanceName))
|
||||||
|
cfg.MSCs.MSCs = []string{"msc2836", "msc2946"}
|
||||||
|
cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", *instanceName))
|
||||||
|
cfg.ClientAPI.RegistrationDisabled = false
|
||||||
|
cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true
|
||||||
|
if err := cfg.Derive(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pk = sk.Public().(ed25519.PublicKey)
|
||||||
|
cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk))
|
||||||
|
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
|
||||||
|
|
||||||
|
base := base.NewBaseDendrite(cfg, "Monolith")
|
||||||
|
defer base.Close() // nolint: errcheck
|
||||||
|
|
||||||
pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
|
pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
|
||||||
pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
|
pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
|
||||||
pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter)
|
pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter)
|
||||||
pManager := pineconeConnections.NewConnectionManager(pRouter, nil)
|
pManager := pineconeConnections.NewConnectionManager(pRouter, nil)
|
||||||
pMulticast.Start()
|
pMulticast.Start()
|
||||||
if instancePeer != nil && *instancePeer != "" {
|
if instancePeer != nil && *instancePeer != "" {
|
||||||
pManager.AddPeer(*instancePeer)
|
for _, peer := range strings.Split(*instancePeer, ",") {
|
||||||
|
pManager.AddPeer(strings.Trim(peer, " \t\r\n"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
|
@ -125,29 +179,6 @@ func main() {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
cfg := &config.Dendrite{}
|
|
||||||
cfg.Defaults(true)
|
|
||||||
cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk))
|
|
||||||
cfg.Global.PrivateKey = sk
|
|
||||||
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
|
|
||||||
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
|
|
||||||
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
|
|
||||||
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
|
|
||||||
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
|
|
||||||
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
|
|
||||||
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", *instanceName))
|
|
||||||
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", *instanceName))
|
|
||||||
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName))
|
|
||||||
cfg.MSCs.MSCs = []string{"msc2836", "msc2946"}
|
|
||||||
cfg.ClientAPI.RegistrationDisabled = false
|
|
||||||
cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true
|
|
||||||
if err := cfg.Derive(); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
base := base.NewBaseDendrite(cfg, "Monolith")
|
|
||||||
defer base.Close() // nolint: errcheck
|
|
||||||
|
|
||||||
federation := conn.CreateFederationClient(base, pQUIC)
|
federation := conn.CreateFederationClient(base, pQUIC)
|
||||||
|
|
||||||
serverKeyAPI := &signing.YggdrasilKeys{}
|
serverKeyAPI := &signing.YggdrasilKeys{}
|
||||||
|
|
|
||||||
|
|
@ -86,7 +86,6 @@ func main() {
|
||||||
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
|
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
|
||||||
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", *instanceName))
|
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", *instanceName))
|
||||||
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", *instanceName))
|
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", *instanceName))
|
||||||
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName))
|
|
||||||
cfg.MSCs.MSCs = []string{"msc2836"}
|
cfg.MSCs.MSCs = []string{"msc2836"}
|
||||||
cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", *instanceName))
|
cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", *instanceName))
|
||||||
cfg.ClientAPI.RegistrationDisabled = false
|
cfg.ClientAPI.RegistrationDisabled = false
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,6 @@ func main() {
|
||||||
cfg.Global.ServerName = gomatrixserverlib.ServerName(*serverName)
|
cfg.Global.ServerName = gomatrixserverlib.ServerName(*serverName)
|
||||||
}
|
}
|
||||||
if *dbURI != "" {
|
if *dbURI != "" {
|
||||||
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(*dbURI)
|
|
||||||
cfg.FederationAPI.Database.ConnectionString = config.DataSource(*dbURI)
|
cfg.FederationAPI.Database.ConnectionString = config.DataSource(*dbURI)
|
||||||
cfg.KeyServer.Database.ConnectionString = config.DataSource(*dbURI)
|
cfg.KeyServer.Database.ConnectionString = config.DataSource(*dbURI)
|
||||||
cfg.MSCs.Database.ConnectionString = config.DataSource(*dbURI)
|
cfg.MSCs.Database.ConnectionString = config.DataSource(*dbURI)
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,7 @@ var (
|
||||||
authorityCertFile = flag.String("tls-authority-cert", "", "Optional: Create TLS certificate/keys based on this CA authority. Useful for integration testing.")
|
authorityCertFile = flag.String("tls-authority-cert", "", "Optional: Create TLS certificate/keys based on this CA authority. Useful for integration testing.")
|
||||||
authorityKeyFile = flag.String("tls-authority-key", "", "Optional: Create TLS certificate/keys based on this CA authority. Useful for integration testing.")
|
authorityKeyFile = flag.String("tls-authority-key", "", "Optional: Create TLS certificate/keys based on this CA authority. Useful for integration testing.")
|
||||||
serverName = flag.String("server", "", "Optional: Create TLS certificate/keys with this domain name set. Useful for integration testing.")
|
serverName = flag.String("server", "", "Optional: Create TLS certificate/keys with this domain name set. Useful for integration testing.")
|
||||||
|
keySize = flag.Int("keysize", 4096, "Optional: Create TLS RSA private key with the given key size")
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|
@ -58,12 +59,12 @@ func main() {
|
||||||
log.Fatal("Zero or both of --tls-key and --tls-cert must be supplied")
|
log.Fatal("Zero or both of --tls-key and --tls-cert must be supplied")
|
||||||
}
|
}
|
||||||
if *authorityCertFile == "" && *authorityKeyFile == "" {
|
if *authorityCertFile == "" && *authorityKeyFile == "" {
|
||||||
if err := test.NewTLSKey(*tlsKeyFile, *tlsCertFile); err != nil {
|
if err := test.NewTLSKey(*tlsKeyFile, *tlsCertFile, *keySize); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// generate the TLS cert/key based on the authority given.
|
// generate the TLS cert/key based on the authority given.
|
||||||
if err := test.NewTLSKeyWithAuthority(*serverName, *tlsKeyFile, *tlsCertFile, *authorityKeyFile, *authorityCertFile); err != nil {
|
if err := test.NewTLSKeyWithAuthority(*serverName, *tlsKeyFile, *tlsCertFile, *authorityKeyFile, *authorityCertFile, *keySize); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -132,13 +132,6 @@ app_service_api:
|
||||||
listen: http://[::]:7777 # The listen address for incoming API requests
|
listen: http://[::]:7777 # The listen address for incoming API requests
|
||||||
connect: http://app_service_api:7777 # The connect address for other components to use
|
connect: http://app_service_api:7777 # The connect address for other components to use
|
||||||
|
|
||||||
# Database configuration for this component.
|
|
||||||
database:
|
|
||||||
connection_string: postgresql://username:password@hostname/dendrite_appservice?sslmode=disable
|
|
||||||
max_open_conns: 10
|
|
||||||
max_idle_conns: 2
|
|
||||||
conn_max_lifetime: -1
|
|
||||||
|
|
||||||
# Disable the validation of TLS certificates of appservices. This is
|
# Disable the validation of TLS certificates of appservices. This is
|
||||||
# not recommended in production since it may allow appservice traffic
|
# not recommended in production since it may allow appservice traffic
|
||||||
# to be sent to an insecure endpoint.
|
# to be sent to an insecure endpoint.
|
||||||
|
|
|
||||||
|
|
@ -67,14 +67,15 @@ func NewKeyChangeConsumer(
|
||||||
// Start consuming from key servers
|
// Start consuming from key servers
|
||||||
func (t *KeyChangeConsumer) Start() error {
|
func (t *KeyChangeConsumer) Start() error {
|
||||||
return jetstream.JetStreamConsumer(
|
return jetstream.JetStreamConsumer(
|
||||||
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage,
|
t.ctx, t.jetstream, t.topic, t.durable, 1,
|
||||||
nats.DeliverAll(), nats.ManualAck(),
|
t.onMessage, nats.DeliverAll(), nats.ManualAck(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// onMessage is called in response to a message received on the
|
// onMessage is called in response to a message received on the
|
||||||
// key change events topic from the key server.
|
// key change events topic from the key server.
|
||||||
func (t *KeyChangeConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
func (t *KeyChangeConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
|
||||||
|
msg := msgs[0] // Guaranteed to exist if onMessage is called
|
||||||
var m api.DeviceMessage
|
var m api.DeviceMessage
|
||||||
if err := json.Unmarshal(msg.Data, &m); err != nil {
|
if err := json.Unmarshal(msg.Data, &m); err != nil {
|
||||||
logrus.WithError(err).Errorf("failed to read device message from key change topic")
|
logrus.WithError(err).Errorf("failed to read device message from key change topic")
|
||||||
|
|
|
||||||
|
|
@ -69,14 +69,15 @@ func (t *OutputPresenceConsumer) Start() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return jetstream.JetStreamConsumer(
|
return jetstream.JetStreamConsumer(
|
||||||
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage,
|
t.ctx, t.jetstream, t.topic, t.durable, 1, t.onMessage,
|
||||||
nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(),
|
nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// onMessage is called in response to a message received on the presence
|
// onMessage is called in response to a message received on the presence
|
||||||
// events topic from the client api.
|
// events topic from the client api.
|
||||||
func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
|
||||||
|
msg := msgs[0] // Guaranteed to exist if onMessage is called
|
||||||
// only send presence events which originated from us
|
// only send presence events which originated from us
|
||||||
userID := msg.Header.Get(jetstream.UserID)
|
userID := msg.Header.Get(jetstream.UserID)
|
||||||
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
|
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
|
|
|
||||||
|
|
@ -65,14 +65,15 @@ func NewOutputReceiptConsumer(
|
||||||
// Start consuming from the clientapi
|
// Start consuming from the clientapi
|
||||||
func (t *OutputReceiptConsumer) Start() error {
|
func (t *OutputReceiptConsumer) Start() error {
|
||||||
return jetstream.JetStreamConsumer(
|
return jetstream.JetStreamConsumer(
|
||||||
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage,
|
t.ctx, t.jetstream, t.topic, t.durable, 1, t.onMessage,
|
||||||
nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(),
|
nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// onMessage is called in response to a message received on the receipt
|
// onMessage is called in response to a message received on the receipt
|
||||||
// events topic from the client api.
|
// events topic from the client api.
|
||||||
func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
|
||||||
|
msg := msgs[0] // Guaranteed to exist if onMessage is called
|
||||||
receipt := syncTypes.OutputReceiptEvent{
|
receipt := syncTypes.OutputReceiptEvent{
|
||||||
UserID: msg.Header.Get(jetstream.UserID),
|
UserID: msg.Header.Get(jetstream.UserID),
|
||||||
RoomID: msg.Header.Get(jetstream.RoomID),
|
RoomID: msg.Header.Get(jetstream.RoomID),
|
||||||
|
|
|
||||||
|
|
@ -68,8 +68,8 @@ func NewOutputRoomEventConsumer(
|
||||||
// Start consuming from room servers
|
// Start consuming from room servers
|
||||||
func (s *OutputRoomEventConsumer) Start() error {
|
func (s *OutputRoomEventConsumer) Start() error {
|
||||||
return jetstream.JetStreamConsumer(
|
return jetstream.JetStreamConsumer(
|
||||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
s.ctx, s.jetstream, s.topic, s.durable, 1,
|
||||||
nats.DeliverAll(), nats.ManualAck(),
|
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -77,7 +77,8 @@ func (s *OutputRoomEventConsumer) Start() error {
|
||||||
// It is unsafe to call this with messages for the same room in multiple gorountines
|
// It is unsafe to call this with messages for the same room in multiple gorountines
|
||||||
// because updates it will likely fail with a types.EventIDMismatchError when it
|
// because updates it will likely fail with a types.EventIDMismatchError when it
|
||||||
// realises that it cannot update the room state using the deltas.
|
// realises that it cannot update the room state using the deltas.
|
||||||
func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
|
||||||
|
msg := msgs[0] // Guaranteed to exist if onMessage is called
|
||||||
// Parse out the event JSON
|
// Parse out the event JSON
|
||||||
var output api.OutputEvent
|
var output api.OutputEvent
|
||||||
if err := json.Unmarshal(msg.Data, &output); err != nil {
|
if err := json.Unmarshal(msg.Data, &output); err != nil {
|
||||||
|
|
|
||||||
|
|
@ -63,14 +63,15 @@ func NewOutputSendToDeviceConsumer(
|
||||||
// Start consuming from the client api
|
// Start consuming from the client api
|
||||||
func (t *OutputSendToDeviceConsumer) Start() error {
|
func (t *OutputSendToDeviceConsumer) Start() error {
|
||||||
return jetstream.JetStreamConsumer(
|
return jetstream.JetStreamConsumer(
|
||||||
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage,
|
t.ctx, t.jetstream, t.topic, t.durable, 1,
|
||||||
nats.DeliverAll(), nats.ManualAck(),
|
t.onMessage, nats.DeliverAll(), nats.ManualAck(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// onMessage is called in response to a message received on the
|
// onMessage is called in response to a message received on the
|
||||||
// send-to-device events topic from the client api.
|
// send-to-device events topic from the client api.
|
||||||
func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
|
||||||
|
msg := msgs[0] // Guaranteed to exist if onMessage is called
|
||||||
// only send send-to-device events which originated from us
|
// only send send-to-device events which originated from us
|
||||||
sender := msg.Header.Get("sender")
|
sender := msg.Header.Get("sender")
|
||||||
_, originServerName, err := gomatrixserverlib.SplitID('@', sender)
|
_, originServerName, err := gomatrixserverlib.SplitID('@', sender)
|
||||||
|
|
|
||||||
|
|
@ -62,14 +62,15 @@ func NewOutputTypingConsumer(
|
||||||
// Start consuming from the clientapi
|
// Start consuming from the clientapi
|
||||||
func (t *OutputTypingConsumer) Start() error {
|
func (t *OutputTypingConsumer) Start() error {
|
||||||
return jetstream.JetStreamConsumer(
|
return jetstream.JetStreamConsumer(
|
||||||
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage,
|
t.ctx, t.jetstream, t.topic, t.durable, 1, t.onMessage,
|
||||||
nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(),
|
nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// onMessage is called in response to a message received on the typing
|
// onMessage is called in response to a message received on the typing
|
||||||
// events topic from the client api.
|
// events topic from the client api.
|
||||||
func (t *OutputTypingConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
func (t *OutputTypingConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
|
||||||
|
msg := msgs[0] // Guaranteed to exist if onMessage is called
|
||||||
// Extract the typing event from msg.
|
// Extract the typing event from msg.
|
||||||
roomID := msg.Header.Get(jetstream.RoomID)
|
roomID := msg.Header.Get(jetstream.RoomID)
|
||||||
userID := msg.Header.Get(jetstream.UserID)
|
userID := msg.Header.Get(jetstream.UserID)
|
||||||
|
|
|
||||||
|
|
@ -329,6 +329,12 @@ func SendJoin(
|
||||||
JSON: jsonerror.NotFound("Room does not exist"),
|
JSON: jsonerror.NotFound("Room does not exist"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if !stateAndAuthChainResponse.StateKnown {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
JSON: jsonerror.Forbidden("State not known"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Check if the user is already in the room. If they're already in then
|
// Check if the user is already in the room. If they're already in then
|
||||||
// there isn't much point in sending another join event into the room.
|
// there isn't much point in sending another join event into the room.
|
||||||
|
|
|
||||||
|
|
@ -135,6 +135,12 @@ func getState(
|
||||||
return nil, nil, &resErr
|
return nil, nil, &resErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !response.StateKnown {
|
||||||
|
return nil, nil, &util.JSONResponse{
|
||||||
|
Code: http.StatusNotFound,
|
||||||
|
JSON: jsonerror.NotFound("State not known"),
|
||||||
|
}
|
||||||
|
}
|
||||||
if response.IsRejected {
|
if response.IsRejected {
|
||||||
return nil, nil, &util.JSONResponse{
|
return nil, nil, &util.JSONResponse{
|
||||||
Code: http.StatusNotFound,
|
Code: http.StatusNotFound,
|
||||||
|
|
|
||||||
|
|
@ -5,10 +5,11 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/federationapi/storage"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"go.uber.org/atomic"
|
"go.uber.org/atomic"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/federationapi/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Statistics contains information about all of the remote federated
|
// Statistics contains information about all of the remote federated
|
||||||
|
|
@ -126,13 +127,13 @@ func (s *ServerStatistics) Failure() (time.Time, bool) {
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
until, ok := s.backoffUntil.Load().(time.Time)
|
until, ok := s.backoffUntil.Load().(time.Time)
|
||||||
if ok {
|
if ok && !until.IsZero() {
|
||||||
select {
|
select {
|
||||||
case <-time.After(time.Until(until)):
|
case <-time.After(time.Until(until)):
|
||||||
case <-s.interrupt:
|
case <-s.interrupt:
|
||||||
}
|
}
|
||||||
|
s.backoffStarted.Store(false)
|
||||||
}
|
}
|
||||||
s.backoffStarted.Store(false)
|
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -110,6 +110,7 @@ func (d *Database) GetPendingEDUs(
|
||||||
return fmt.Errorf("json.Unmarshal: %w", err)
|
return fmt.Errorf("json.Unmarshal: %w", err)
|
||||||
}
|
}
|
||||||
edus[&Receipt{nid}] = &event
|
edus[&Receipt{nid}] = &event
|
||||||
|
d.Cache.StoreFederationQueuedEDU(nid, &event)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
@ -177,20 +178,18 @@ func (d *Database) GetPendingEDUServerNames(
|
||||||
return d.FederationQueueEDUs.SelectQueueEDUServerNames(ctx, nil)
|
return d.FederationQueueEDUs.SelectQueueEDUServerNames(ctx, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteExpiredEDUs deletes expired EDUs
|
// DeleteExpiredEDUs deletes expired EDUs and evicts them from the cache.
|
||||||
func (d *Database) DeleteExpiredEDUs(ctx context.Context) error {
|
func (d *Database) DeleteExpiredEDUs(ctx context.Context) error {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
var jsonNIDs []int64
|
||||||
|
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) (err error) {
|
||||||
expiredBefore := gomatrixserverlib.AsTimestamp(time.Now())
|
expiredBefore := gomatrixserverlib.AsTimestamp(time.Now())
|
||||||
jsonNIDs, err := d.FederationQueueEDUs.SelectExpiredEDUs(ctx, txn, expiredBefore)
|
jsonNIDs, err = d.FederationQueueEDUs.SelectExpiredEDUs(ctx, txn, expiredBefore)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if len(jsonNIDs) == 0 {
|
if len(jsonNIDs) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
for i := range jsonNIDs {
|
|
||||||
d.Cache.EvictFederationQueuedEDU(jsonNIDs[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = d.FederationQueueJSON.DeleteQueueJSON(ctx, txn, jsonNIDs); err != nil {
|
if err = d.FederationQueueJSON.DeleteQueueJSON(ctx, txn, jsonNIDs); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -198,4 +197,14 @@ func (d *Database) DeleteExpiredEDUs(ctx context.Context) error {
|
||||||
|
|
||||||
return d.FederationQueueEDUs.DeleteExpiredEDUs(ctx, txn, expiredBefore)
|
return d.FederationQueueEDUs.DeleteExpiredEDUs(ctx, txn, expiredBefore)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range jsonNIDs {
|
||||||
|
d.Cache.EvictFederationQueuedEDU(jsonNIDs[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Dat
|
||||||
|
|
||||||
func TestExpireEDUs(t *testing.T) {
|
func TestExpireEDUs(t *testing.T) {
|
||||||
var expireEDUTypes = map[string]time.Duration{
|
var expireEDUTypes = map[string]time.Duration{
|
||||||
gomatrixserverlib.MReceipt: time.Millisecond,
|
gomatrixserverlib.MReceipt: 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
|
||||||
2
go.mod
2
go.mod
|
|
@ -21,7 +21,7 @@ require (
|
||||||
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
|
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
|
||||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
|
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
|
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220815094957-74b7ff4ae09c
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20220830164018-c71e518537a2
|
||||||
github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9
|
github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9
|
||||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
||||||
github.com/mattn/go-sqlite3 v1.14.13
|
github.com/mattn/go-sqlite3 v1.14.13
|
||||||
|
|
|
||||||
4
go.sum
4
go.sum
|
|
@ -343,8 +343,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
|
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
|
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220815094957-74b7ff4ae09c h1:GhKmb8s9iXA9qsFD1SbiRo6Ee7cnbfcgJQ/iy43wczM=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20220830164018-c71e518537a2 h1:esbNn9hg//tAStA6TogatAJAursw23A+yfVRQsdiv70=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220815094957-74b7ff4ae09c/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20220830164018-c71e518537a2/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk=
|
||||||
github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9 h1:ed8yvWhTLk7+sNeK/eOZRTvESFTOHDRevoRoyeqPtvY=
|
github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9 h1:ed8yvWhTLk7+sNeK/eOZRTvESFTOHDRevoRoyeqPtvY=
|
||||||
github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9/go.mod h1:P4MqPf+u83OPulPJ+XTbSDbbWrdFYNY4LZ/B1PIduFE=
|
github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9/go.mod h1:P4MqPf+u83OPulPJ+XTbSDbbWrdFYNY4LZ/B1PIduFE=
|
||||||
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
|
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ var build string
|
||||||
const (
|
const (
|
||||||
VersionMajor = 0
|
VersionMajor = 0
|
||||||
VersionMinor = 9
|
VersionMinor = 9
|
||||||
VersionPatch = 3
|
VersionPatch = 5
|
||||||
VersionTag = "" // example: "rc1"
|
VersionTag = "" // example: "rc1"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -55,14 +55,15 @@ func NewDeviceListUpdateConsumer(
|
||||||
// Start consuming from key servers
|
// Start consuming from key servers
|
||||||
func (t *DeviceListUpdateConsumer) Start() error {
|
func (t *DeviceListUpdateConsumer) Start() error {
|
||||||
return jetstream.JetStreamConsumer(
|
return jetstream.JetStreamConsumer(
|
||||||
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage,
|
t.ctx, t.jetstream, t.topic, t.durable, 1,
|
||||||
nats.DeliverAll(), nats.ManualAck(),
|
t.onMessage, nats.DeliverAll(), nats.ManualAck(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// onMessage is called in response to a message received on the
|
// onMessage is called in response to a message received on the
|
||||||
// key change events topic from the key server.
|
// key change events topic from the key server.
|
||||||
func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
|
||||||
|
msg := msgs[0] // Guaranteed to exist if onMessage is called
|
||||||
var m gomatrixserverlib.DeviceListUpdateEvent
|
var m gomatrixserverlib.DeviceListUpdateEvent
|
||||||
if err := json.Unmarshal(msg.Data, &m); err != nil {
|
if err := json.Unmarshal(msg.Data, &m); err != nil {
|
||||||
logrus.WithError(err).Errorf("Failed to read from device list update input topic")
|
logrus.WithError(err).Errorf("Failed to read from device list update input topic")
|
||||||
|
|
|
||||||
|
|
@ -335,8 +335,9 @@ func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
|
||||||
retriesMu := &sync.Mutex{}
|
retriesMu := &sync.Mutex{}
|
||||||
// restarter goroutine which will inject failed servers into ch when it is time
|
// restarter goroutine which will inject failed servers into ch when it is time
|
||||||
go func() {
|
go func() {
|
||||||
|
var serversToRetry []gomatrixserverlib.ServerName
|
||||||
for {
|
for {
|
||||||
var serversToRetry []gomatrixserverlib.ServerName
|
serversToRetry = serversToRetry[:0] // reuse memory
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
retriesMu.Lock()
|
retriesMu.Lock()
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
@ -355,11 +356,17 @@ func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
for serverName := range ch {
|
for serverName := range ch {
|
||||||
|
retriesMu.Lock()
|
||||||
|
_, exists := retries[serverName]
|
||||||
|
retriesMu.Unlock()
|
||||||
|
if exists {
|
||||||
|
// Don't retry a server that we're already waiting for.
|
||||||
|
continue
|
||||||
|
}
|
||||||
waitTime, shouldRetry := u.processServer(serverName)
|
waitTime, shouldRetry := u.processServer(serverName)
|
||||||
if shouldRetry {
|
if shouldRetry {
|
||||||
retriesMu.Lock()
|
retriesMu.Lock()
|
||||||
_, exists := retries[serverName]
|
if _, exists = retries[serverName]; !exists {
|
||||||
if !exists {
|
|
||||||
retries[serverName] = time.Now().Add(waitTime)
|
retries[serverName] = time.Now().Add(waitTime)
|
||||||
}
|
}
|
||||||
retriesMu.Unlock()
|
retriesMu.Unlock()
|
||||||
|
|
|
||||||
|
|
@ -60,7 +60,7 @@ func NewInternalAPI(
|
||||||
updater := internal.NewDeviceListUpdater(db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable
|
updater := internal.NewDeviceListUpdater(db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable
|
||||||
ap.Updater = updater
|
ap.Updater = updater
|
||||||
go func() {
|
go func() {
|
||||||
if err = updater.Start(); err != nil {
|
if err := updater.Start(); err != nil {
|
||||||
logrus.WithError(err).Panicf("failed to start device list updater")
|
logrus.WithError(err).Panicf("failed to start device list updater")
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
@ -68,7 +68,7 @@ func NewInternalAPI(
|
||||||
dlConsumer := consumers.NewDeviceListUpdateConsumer(
|
dlConsumer := consumers.NewDeviceListUpdateConsumer(
|
||||||
base.ProcessContext, cfg, js, updater,
|
base.ProcessContext, cfg, js, updater,
|
||||||
)
|
)
|
||||||
if err = dlConsumer.Start(); err != nil {
|
if err := dlConsumer.Start(); err != nil {
|
||||||
logrus.WithError(err).Panic("failed to start device list consumer")
|
logrus.WithError(err).Panic("failed to start device list consumer")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,9 +5,10 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
)
|
)
|
||||||
|
|
||||||
type PerformErrorCode int
|
type PerformErrorCode int
|
||||||
|
|
@ -161,7 +162,8 @@ func (r *PerformBackfillRequest) PrevEventIDs() []string {
|
||||||
// PerformBackfillResponse is a response to PerformBackfill.
|
// PerformBackfillResponse is a response to PerformBackfill.
|
||||||
type PerformBackfillResponse struct {
|
type PerformBackfillResponse struct {
|
||||||
// Missing events, arbritrary order.
|
// Missing events, arbritrary order.
|
||||||
Events []*gomatrixserverlib.HeaderedEvent `json:"events"`
|
Events []*gomatrixserverlib.HeaderedEvent `json:"events"`
|
||||||
|
HistoryVisibility gomatrixserverlib.HistoryVisibility `json:"history_visibility"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PerformPublishRequest struct {
|
type PerformPublishRequest struct {
|
||||||
|
|
|
||||||
|
|
@ -227,6 +227,7 @@ type QueryStateAndAuthChainResponse struct {
|
||||||
// Do all the previous events exist on this roomserver?
|
// Do all the previous events exist on this roomserver?
|
||||||
// If some of previous events do not exist this will be false and StateEvents will be empty.
|
// If some of previous events do not exist this will be false and StateEvents will be empty.
|
||||||
PrevEventsExist bool `json:"prev_events_exist"`
|
PrevEventsExist bool `json:"prev_events_exist"`
|
||||||
|
StateKnown bool `json:"state_known"`
|
||||||
// The state and auth chain events that were requested.
|
// The state and auth chain events that were requested.
|
||||||
// The lists will be in an arbitrary order.
|
// The lists will be in an arbitrary order.
|
||||||
StateEvents []*gomatrixserverlib.HeaderedEvent `json:"state_events"`
|
StateEvents []*gomatrixserverlib.HeaderedEvent `json:"state_events"`
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SendEvents to the roomserver The events are written with KindNew.
|
// SendEvents to the roomserver The events are written with KindNew.
|
||||||
|
|
@ -69,6 +70,13 @@ func SendEventWithState(
|
||||||
stateEventIDs[i] = stateEvents[i].EventID()
|
stateEventIDs[i] = stateEvents[i].EventID()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logrus.WithContext(ctx).WithFields(logrus.Fields{
|
||||||
|
"room_id": event.RoomID(),
|
||||||
|
"event_id": event.EventID(),
|
||||||
|
"outliers": len(ires),
|
||||||
|
"state_ids": len(stateEventIDs),
|
||||||
|
}).Infof("Submitting %q event to roomserver with state snapshot", event.Type())
|
||||||
|
|
||||||
ires = append(ires, InputRoomEvent{
|
ires = append(ires, InputRoomEvent{
|
||||||
Kind: kind,
|
Kind: kind,
|
||||||
Event: event,
|
Event: event,
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ func CheckForSoftFail(
|
||||||
var authStateEntries []types.StateEntry
|
var authStateEntries []types.StateEntry
|
||||||
var err error
|
var err error
|
||||||
if rewritesState {
|
if rewritesState {
|
||||||
authStateEntries, err = db.StateEntriesForEventIDs(ctx, stateEventIDs)
|
authStateEntries, err = db.StateEntriesForEventIDs(ctx, stateEventIDs, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return true, fmt.Errorf("StateEntriesForEventIDs failed: %w", err)
|
return true, fmt.Errorf("StateEntriesForEventIDs failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -97,7 +97,7 @@ func CheckAuthEvents(
|
||||||
authEventIDs []string,
|
authEventIDs []string,
|
||||||
) ([]types.EventNID, error) {
|
) ([]types.EventNID, error) {
|
||||||
// Grab the numeric IDs for the supplied auth state events from the database.
|
// Grab the numeric IDs for the supplied auth state events from the database.
|
||||||
authStateEntries, err := db.StateEntriesForEventIDs(ctx, authEventIDs)
|
authStateEntries, err := db.StateEntriesForEventIDs(ctx, authEventIDs, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("db.StateEntriesForEventIDs: %w", err)
|
return nil, fmt.Errorf("db.StateEntriesForEventIDs: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -254,8 +254,15 @@ func CheckServerAllowedToSeeEvent(
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
// Something else went wrong
|
switch err.(type) {
|
||||||
return false, err
|
case types.MissingStateError:
|
||||||
|
// If there's no state then we assume it's open visibility, as Synapse does:
|
||||||
|
// https://github.com/matrix-org/synapse/blob/aec87a0f9369a3015b2a53469f88d1de274e8b71/synapse/visibility.py#L654-L655
|
||||||
|
return true, nil
|
||||||
|
default:
|
||||||
|
// Something else went wrong
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
|
return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/roomserver/internal/query"
|
"github.com/matrix-org/dendrite/roomserver/internal/query"
|
||||||
"github.com/matrix-org/dendrite/roomserver/producers"
|
"github.com/matrix-org/dendrite/roomserver/producers"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/dendrite/setup/base"
|
"github.com/matrix-org/dendrite/setup/base"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
|
|
@ -247,14 +248,24 @@ func (w *worker) _next() {
|
||||||
// it was a synchronous request.
|
// it was a synchronous request.
|
||||||
var errString string
|
var errString string
|
||||||
if err = w.r.processRoomEvent(w.r.ProcessContext.Context(), &inputRoomEvent); err != nil {
|
if err = w.r.processRoomEvent(w.r.ProcessContext.Context(), &inputRoomEvent); err != nil {
|
||||||
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
|
switch err.(type) {
|
||||||
sentry.CaptureException(err)
|
case types.RejectedError:
|
||||||
|
// Don't send events that were rejected to Sentry
|
||||||
|
logrus.WithError(err).WithFields(logrus.Fields{
|
||||||
|
"room_id": w.roomID,
|
||||||
|
"event_id": inputRoomEvent.Event.EventID(),
|
||||||
|
"type": inputRoomEvent.Event.Type(),
|
||||||
|
}).Warn("Roomserver rejected event")
|
||||||
|
default:
|
||||||
|
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
|
||||||
|
sentry.CaptureException(err)
|
||||||
|
}
|
||||||
|
logrus.WithError(err).WithFields(logrus.Fields{
|
||||||
|
"room_id": w.roomID,
|
||||||
|
"event_id": inputRoomEvent.Event.EventID(),
|
||||||
|
"type": inputRoomEvent.Event.Type(),
|
||||||
|
}).Warn("Roomserver failed to process event")
|
||||||
}
|
}
|
||||||
logrus.WithError(err).WithFields(logrus.Fields{
|
|
||||||
"room_id": w.roomID,
|
|
||||||
"event_id": inputRoomEvent.Event.EventID(),
|
|
||||||
"type": inputRoomEvent.Event.Type(),
|
|
||||||
}).Warn("Roomserver failed to process async event")
|
|
||||||
_ = msg.Term()
|
_ = msg.Term()
|
||||||
errString = err.Error()
|
errString = err.Error()
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
|
|
@ -301,7 +301,7 @@ func (r *Inputer) processRoomEvent(
|
||||||
// bother doing this if the event was already rejected as it just ends up
|
// bother doing this if the event was already rejected as it just ends up
|
||||||
// burning CPU time.
|
// burning CPU time.
|
||||||
historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared.
|
historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared.
|
||||||
if rejectionErr == nil && !isRejected && !softfail {
|
if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected {
|
||||||
var err error
|
var err error
|
||||||
historyVisibility, rejectionErr, err = r.processStateBefore(ctx, input, missingPrev)
|
historyVisibility, rejectionErr, err = r.processStateBefore(ctx, input, missingPrev)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -313,7 +313,7 @@ func (r *Inputer) processRoomEvent(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store the event.
|
// Store the event.
|
||||||
_, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected || softfail)
|
_, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("updater.StoreEvent: %w", err)
|
return fmt.Errorf("updater.StoreEvent: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -353,12 +353,18 @@ func (r *Inputer) processRoomEvent(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// We stop here if the event is rejected: We've stored it but won't update forward extremities or notify anyone about it.
|
// We stop here if the event is rejected: We've stored it but won't update
|
||||||
if isRejected || softfail {
|
// forward extremities or notify downstream components about it.
|
||||||
logger.WithError(rejectionErr).WithFields(logrus.Fields{
|
switch {
|
||||||
"soft_fail": softfail,
|
case isRejected:
|
||||||
"missing_prev": missingPrev,
|
logger.WithError(rejectionErr).Warn("Stored rejected event")
|
||||||
}).Warn("Stored rejected event")
|
if rejectionErr != nil {
|
||||||
|
return types.RejectedError(rejectionErr.Error())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case softfail:
|
||||||
|
logger.WithError(rejectionErr).Warn("Stored soft-failed event")
|
||||||
if rejectionErr != nil {
|
if rejectionErr != nil {
|
||||||
return types.RejectedError(rejectionErr.Error())
|
return types.RejectedError(rejectionErr.Error())
|
||||||
}
|
}
|
||||||
|
|
@ -661,7 +667,7 @@ func (r *Inputer) calculateAndSetState(
|
||||||
// We've been told what the state at the event is so we don't need to calculate it.
|
// We've been told what the state at the event is so we don't need to calculate it.
|
||||||
// Check that those state events are in the database and store the state.
|
// Check that those state events are in the database and store the state.
|
||||||
var entries []types.StateEntry
|
var entries []types.StateEntry
|
||||||
if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
|
if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs, true); err != nil {
|
||||||
return fmt.Errorf("updater.StateEntriesForEventIDs: %w", err)
|
return fmt.Errorf("updater.StateEntriesForEventIDs: %w", err)
|
||||||
}
|
}
|
||||||
entries = types.DeduplicateStateEntries(entries)
|
entries = types.DeduplicateStateEntries(entries)
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/getsentry/sentry-go"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
@ -140,11 +139,11 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
var entries []types.StateEntry
|
var entries []types.StateEntry
|
||||||
if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs); err != nil {
|
if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true); err != nil {
|
||||||
// attempt to fetch the missing events
|
// attempt to fetch the missing events
|
||||||
r.fetchAndStoreMissingEvents(ctx, info.RoomVersion, requester, stateIDs)
|
r.fetchAndStoreMissingEvents(ctx, info.RoomVersion, requester, stateIDs)
|
||||||
// try again
|
// try again
|
||||||
entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs)
|
entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to get state entries for event")
|
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to get state entries for event")
|
||||||
return err
|
return err
|
||||||
|
|
@ -164,6 +163,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
|
||||||
// TODO: update backwards extremities, as that should be moved from syncapi to roomserver at some point.
|
// TODO: update backwards extremities, as that should be moved from syncapi to roomserver at some point.
|
||||||
|
|
||||||
res.Events = events
|
res.Events = events
|
||||||
|
res.HistoryVisibility = requester.historyVisiblity
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -248,6 +248,7 @@ type backfillRequester struct {
|
||||||
servers []gomatrixserverlib.ServerName
|
servers []gomatrixserverlib.ServerName
|
||||||
eventIDToBeforeStateIDs map[string][]string
|
eventIDToBeforeStateIDs map[string][]string
|
||||||
eventIDMap map[string]*gomatrixserverlib.Event
|
eventIDMap map[string]*gomatrixserverlib.Event
|
||||||
|
historyVisiblity gomatrixserverlib.HistoryVisibility
|
||||||
}
|
}
|
||||||
|
|
||||||
func newBackfillRequester(
|
func newBackfillRequester(
|
||||||
|
|
@ -266,6 +267,7 @@ func newBackfillRequester(
|
||||||
eventIDMap: make(map[string]*gomatrixserverlib.Event),
|
eventIDMap: make(map[string]*gomatrixserverlib.Event),
|
||||||
bwExtrems: bwExtrems,
|
bwExtrems: bwExtrems,
|
||||||
preferServer: preferServer,
|
preferServer: preferServer,
|
||||||
|
historyVisiblity: gomatrixserverlib.HistoryVisibilityShared,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -317,7 +319,6 @@ FederationHit:
|
||||||
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = res
|
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = res
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
sentry.CaptureException(lastErr) // temporary to see if we might need to raise the server limit
|
|
||||||
return nil, lastErr
|
return nil, lastErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -395,7 +396,6 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr
|
||||||
}
|
}
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
sentry.CaptureException(lastErr) // temporary to see if we might need to raise the server limit
|
|
||||||
return nil, lastErr
|
return nil, lastErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -447,7 +447,8 @@ FindSuccessor:
|
||||||
}
|
}
|
||||||
|
|
||||||
// possibly return all joined servers depending on history visiblity
|
// possibly return all joined servers depending on history visiblity
|
||||||
memberEventsFromVis, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.thisServer)
|
memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.thisServer)
|
||||||
|
b.historyVisiblity = visibility
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules")
|
logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules")
|
||||||
return nil
|
return nil
|
||||||
|
|
@ -528,7 +529,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
|
||||||
// pull all events and then filter by that table.
|
// pull all events and then filter by that table.
|
||||||
func joinEventsFromHistoryVisibility(
|
func joinEventsFromHistoryVisibility(
|
||||||
ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry,
|
ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry,
|
||||||
thisServer gomatrixserverlib.ServerName) ([]types.Event, error) {
|
thisServer gomatrixserverlib.ServerName) ([]types.Event, gomatrixserverlib.HistoryVisibility, error) {
|
||||||
|
|
||||||
var eventNIDs []types.EventNID
|
var eventNIDs []types.EventNID
|
||||||
for _, entry := range stateEntries {
|
for _, entry := range stateEntries {
|
||||||
|
|
@ -542,7 +543,9 @@ func joinEventsFromHistoryVisibility(
|
||||||
// Get all of the events in this state
|
// Get all of the events in this state
|
||||||
stateEvents, err := db.Events(ctx, eventNIDs)
|
stateEvents, err := db.Events(ctx, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
// even though the default should be shared, restricting the visibility to joined
|
||||||
|
// feels more secure here.
|
||||||
|
return nil, gomatrixserverlib.HistoryVisibilityJoined, err
|
||||||
}
|
}
|
||||||
events := make([]*gomatrixserverlib.Event, len(stateEvents))
|
events := make([]*gomatrixserverlib.Event, len(stateEvents))
|
||||||
for i := range stateEvents {
|
for i := range stateEvents {
|
||||||
|
|
@ -551,20 +554,22 @@ func joinEventsFromHistoryVisibility(
|
||||||
|
|
||||||
// Can we see events in the room?
|
// Can we see events in the room?
|
||||||
canSeeEvents := auth.IsServerAllowed(thisServer, true, events)
|
canSeeEvents := auth.IsServerAllowed(thisServer, true, events)
|
||||||
|
visibility := gomatrixserverlib.HistoryVisibility(auth.HistoryVisibilityForRoom(events))
|
||||||
if !canSeeEvents {
|
if !canSeeEvents {
|
||||||
logrus.Infof("ServersAtEvent history not visible to us: %s", auth.HistoryVisibilityForRoom(events))
|
logrus.Infof("ServersAtEvent history not visible to us: %s", visibility)
|
||||||
return nil, nil
|
return nil, visibility, nil
|
||||||
}
|
}
|
||||||
// get joined members
|
// get joined members
|
||||||
info, err := db.RoomInfo(ctx, roomID)
|
info, err := db.RoomInfo(ctx, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, visibility, nil
|
||||||
}
|
}
|
||||||
joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false)
|
joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, visibility, err
|
||||||
}
|
}
|
||||||
return db.Events(ctx, joinEventNIDs)
|
evs, err := db.Events(ctx, joinEventNIDs)
|
||||||
|
return evs, visibility, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) {
|
func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) {
|
||||||
|
|
|
||||||
|
|
@ -503,10 +503,11 @@ func (r *Queryer) QueryStateAndAuthChain(
|
||||||
}
|
}
|
||||||
|
|
||||||
var stateEvents []*gomatrixserverlib.Event
|
var stateEvents []*gomatrixserverlib.Event
|
||||||
stateEvents, rejected, err := r.loadStateAtEventIDs(ctx, info, request.PrevEventIDs)
|
stateEvents, rejected, stateMissing, err := r.loadStateAtEventIDs(ctx, info, request.PrevEventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
response.StateKnown = !stateMissing
|
||||||
response.IsRejected = rejected
|
response.IsRejected = rejected
|
||||||
response.PrevEventsExist = true
|
response.PrevEventsExist = true
|
||||||
|
|
||||||
|
|
@ -542,15 +543,18 @@ func (r *Queryer) QueryStateAndAuthChain(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]*gomatrixserverlib.Event, bool, error) {
|
// first bool: is rejected, second bool: state missing
|
||||||
|
func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]*gomatrixserverlib.Event, bool, bool, error) {
|
||||||
roomState := state.NewStateResolution(r.DB, roomInfo)
|
roomState := state.NewStateResolution(r.DB, roomInfo)
|
||||||
prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs)
|
prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch err.(type) {
|
switch err.(type) {
|
||||||
case types.MissingEventError:
|
case types.MissingEventError:
|
||||||
return nil, false, nil
|
return nil, false, true, nil
|
||||||
|
case types.MissingStateError:
|
||||||
|
return nil, false, true, nil
|
||||||
default:
|
default:
|
||||||
return nil, false, err
|
return nil, false, false, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Currently only used on /state and /state_ids
|
// Currently only used on /state and /state_ids
|
||||||
|
|
@ -567,12 +571,11 @@ func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomI
|
||||||
ctx, prevStates,
|
ctx, prevStates,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, rejected, err
|
return nil, rejected, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
events, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries)
|
events, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries)
|
||||||
|
return events, rejected, false, err
|
||||||
return events, rejected, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type eventsFromIDs func(context.Context, []string) ([]types.Event, error)
|
type eventsFromIDs func(context.Context, []string) ([]types.Event, error)
|
||||||
|
|
|
||||||
|
|
@ -79,7 +79,7 @@ type Database interface {
|
||||||
// Look up the state entries for a list of string event IDs
|
// Look up the state entries for a list of string event IDs
|
||||||
// Returns an error if the there is an error talking to the database
|
// Returns an error if the there is an error talking to the database
|
||||||
// Returns a types.MissingEventError if the event IDs aren't in the database.
|
// Returns a types.MissingEventError if the event IDs aren't in the database.
|
||||||
StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error)
|
StateEntriesForEventIDs(ctx context.Context, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error)
|
||||||
// Look up the string event state keys for a list of numeric event state keys
|
// Look up the string event state keys for a list of numeric event state keys
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
|
EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
|
||||||
|
|
|
||||||
|
|
@ -74,7 +74,7 @@ const insertEventSQL = "" +
|
||||||
"INSERT INTO roomserver_events AS e (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)" +
|
"INSERT INTO roomserver_events AS e (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)" +
|
||||||
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" +
|
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" +
|
||||||
" ON CONFLICT ON CONSTRAINT roomserver_event_id_unique DO UPDATE" +
|
" ON CONFLICT ON CONSTRAINT roomserver_event_id_unique DO UPDATE" +
|
||||||
" SET is_rejected = $8 WHERE e.event_id = $4 AND e.is_rejected = FALSE" +
|
" SET is_rejected = $8 WHERE e.event_id = $4 AND e.is_rejected = TRUE" +
|
||||||
" RETURNING event_nid, state_snapshot_nid"
|
" RETURNING event_nid, state_snapshot_nid"
|
||||||
|
|
||||||
const selectEventSQL = "" +
|
const selectEventSQL = "" +
|
||||||
|
|
@ -88,6 +88,14 @@ const bulkSelectStateEventByIDSQL = "" +
|
||||||
" WHERE event_id = ANY($1)" +
|
" WHERE event_id = ANY($1)" +
|
||||||
" ORDER BY event_type_nid, event_state_key_nid ASC"
|
" ORDER BY event_type_nid, event_state_key_nid ASC"
|
||||||
|
|
||||||
|
// Bulk lookup of events by string ID that aren't excluded.
|
||||||
|
// Sort by the numeric IDs for event type and state key.
|
||||||
|
// This means we can use binary search to lookup entries by type and state key.
|
||||||
|
const bulkSelectStateEventByIDExcludingRejectedSQL = "" +
|
||||||
|
"SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" +
|
||||||
|
" WHERE event_id = ANY($1) AND is_rejected = FALSE" +
|
||||||
|
" ORDER BY event_type_nid, event_state_key_nid ASC"
|
||||||
|
|
||||||
// Bulk look up of events by event NID, optionally filtering by the event type
|
// Bulk look up of events by event NID, optionally filtering by the event type
|
||||||
// or event state key NIDs if provided. (The CARDINALITY check will return true
|
// or event state key NIDs if provided. (The CARDINALITY check will return true
|
||||||
// if the provided arrays are empty, ergo no filtering).
|
// if the provided arrays are empty, ergo no filtering).
|
||||||
|
|
@ -140,23 +148,24 @@ const selectEventRejectedSQL = "" +
|
||||||
"SELECT is_rejected FROM roomserver_events WHERE room_nid = $1 AND event_id = $2"
|
"SELECT is_rejected FROM roomserver_events WHERE room_nid = $1 AND event_id = $2"
|
||||||
|
|
||||||
type eventStatements struct {
|
type eventStatements struct {
|
||||||
insertEventStmt *sql.Stmt
|
insertEventStmt *sql.Stmt
|
||||||
selectEventStmt *sql.Stmt
|
selectEventStmt *sql.Stmt
|
||||||
bulkSelectStateEventByIDStmt *sql.Stmt
|
bulkSelectStateEventByIDStmt *sql.Stmt
|
||||||
bulkSelectStateEventByNIDStmt *sql.Stmt
|
bulkSelectStateEventByIDExcludingRejectedStmt *sql.Stmt
|
||||||
bulkSelectStateAtEventByIDStmt *sql.Stmt
|
bulkSelectStateEventByNIDStmt *sql.Stmt
|
||||||
updateEventStateStmt *sql.Stmt
|
bulkSelectStateAtEventByIDStmt *sql.Stmt
|
||||||
selectEventSentToOutputStmt *sql.Stmt
|
updateEventStateStmt *sql.Stmt
|
||||||
updateEventSentToOutputStmt *sql.Stmt
|
selectEventSentToOutputStmt *sql.Stmt
|
||||||
selectEventIDStmt *sql.Stmt
|
updateEventSentToOutputStmt *sql.Stmt
|
||||||
bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
|
selectEventIDStmt *sql.Stmt
|
||||||
bulkSelectEventReferenceStmt *sql.Stmt
|
bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
|
||||||
bulkSelectEventIDStmt *sql.Stmt
|
bulkSelectEventReferenceStmt *sql.Stmt
|
||||||
bulkSelectEventNIDStmt *sql.Stmt
|
bulkSelectEventIDStmt *sql.Stmt
|
||||||
bulkSelectUnsentEventNIDStmt *sql.Stmt
|
bulkSelectEventNIDStmt *sql.Stmt
|
||||||
selectMaxEventDepthStmt *sql.Stmt
|
bulkSelectUnsentEventNIDStmt *sql.Stmt
|
||||||
selectRoomNIDsForEventNIDsStmt *sql.Stmt
|
selectMaxEventDepthStmt *sql.Stmt
|
||||||
selectEventRejectedStmt *sql.Stmt
|
selectRoomNIDsForEventNIDsStmt *sql.Stmt
|
||||||
|
selectEventRejectedStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateEventsTable(db *sql.DB) error {
|
func CreateEventsTable(db *sql.DB) error {
|
||||||
|
|
@ -171,6 +180,7 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) {
|
||||||
{&s.insertEventStmt, insertEventSQL},
|
{&s.insertEventStmt, insertEventSQL},
|
||||||
{&s.selectEventStmt, selectEventSQL},
|
{&s.selectEventStmt, selectEventSQL},
|
||||||
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
|
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
|
||||||
|
{&s.bulkSelectStateEventByIDExcludingRejectedStmt, bulkSelectStateEventByIDExcludingRejectedSQL},
|
||||||
{&s.bulkSelectStateEventByNIDStmt, bulkSelectStateEventByNIDSQL},
|
{&s.bulkSelectStateEventByNIDStmt, bulkSelectStateEventByNIDSQL},
|
||||||
{&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
|
{&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
|
||||||
{&s.updateEventStateStmt, updateEventStateSQL},
|
{&s.updateEventStateStmt, updateEventStateSQL},
|
||||||
|
|
@ -221,11 +231,18 @@ func (s *eventStatements) SelectEvent(
|
||||||
}
|
}
|
||||||
|
|
||||||
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
||||||
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
// If not excluding rejected events, and any of the requested events are missing from
|
||||||
|
// the database it returns a types.MissingEventError. If excluding rejected events,
|
||||||
|
// the events will be silently omitted without error.
|
||||||
func (s *eventStatements) BulkSelectStateEventByID(
|
func (s *eventStatements) BulkSelectStateEventByID(
|
||||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
ctx context.Context, txn *sql.Tx, eventIDs []string, excludeRejected bool,
|
||||||
) ([]types.StateEntry, error) {
|
) ([]types.StateEntry, error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByIDStmt)
|
var stmt *sql.Stmt
|
||||||
|
if excludeRejected {
|
||||||
|
stmt = sqlutil.TxStmt(txn, s.bulkSelectStateEventByIDExcludingRejectedStmt)
|
||||||
|
} else {
|
||||||
|
stmt = sqlutil.TxStmt(txn, s.bulkSelectStateEventByIDStmt)
|
||||||
|
}
|
||||||
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -235,10 +252,10 @@ func (s *eventStatements) BulkSelectStateEventByID(
|
||||||
// because of the unique constraint on event IDs.
|
// because of the unique constraint on event IDs.
|
||||||
// So we can allocate an array of the correct size now.
|
// So we can allocate an array of the correct size now.
|
||||||
// We might get fewer results than IDs so we adjust the length of the slice before returning it.
|
// We might get fewer results than IDs so we adjust the length of the slice before returning it.
|
||||||
results := make([]types.StateEntry, len(eventIDs))
|
results := make([]types.StateEntry, 0, len(eventIDs))
|
||||||
i := 0
|
i := 0
|
||||||
for ; rows.Next(); i++ {
|
for ; rows.Next(); i++ {
|
||||||
result := &results[i]
|
var result types.StateEntry
|
||||||
if err = rows.Scan(
|
if err = rows.Scan(
|
||||||
&result.EventTypeNID,
|
&result.EventTypeNID,
|
||||||
&result.EventStateKeyNID,
|
&result.EventStateKeyNID,
|
||||||
|
|
@ -246,11 +263,12 @@ func (s *eventStatements) BulkSelectStateEventByID(
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
results = append(results, result)
|
||||||
}
|
}
|
||||||
if err = rows.Err(); err != nil {
|
if err = rows.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if i != len(eventIDs) {
|
if !excludeRejected && i != len(eventIDs) {
|
||||||
// If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
|
// If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
|
||||||
// We don't know which ones were missing because we don't return the string IDs in the query.
|
// We don't know which ones were missing because we don't return the string IDs in the query.
|
||||||
// However it should be possible debug this by replaying queries or entries from the input kafka logs.
|
// However it should be possible debug this by replaying queries or entries from the input kafka logs.
|
||||||
|
|
@ -328,7 +346,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
|
||||||
// Genuine create events are the only case where it's OK to have no previous state.
|
// Genuine create events are the only case where it's OK to have no previous state.
|
||||||
isCreate := result.EventTypeNID == types.MRoomCreateNID && result.EventStateKeyNID == 1
|
isCreate := result.EventTypeNID == types.MRoomCreateNID && result.EventStateKeyNID == 1
|
||||||
if result.BeforeStateSnapshotNID == 0 && !isCreate {
|
if result.BeforeStateSnapshotNID == 0 && !isCreate {
|
||||||
return nil, types.MissingEventError(
|
return nil, types.MissingStateError(
|
||||||
fmt.Sprintf("storage: missing state for event NID %d", result.EventNID),
|
fmt.Sprintf("storage: missing state for event NID %d", result.EventNID),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -113,9 +113,9 @@ func (d *Database) eventStateKeyNIDs(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) StateEntriesForEventIDs(
|
func (d *Database) StateEntriesForEventIDs(
|
||||||
ctx context.Context, eventIDs []string,
|
ctx context.Context, eventIDs []string, excludeRejected bool,
|
||||||
) ([]types.StateEntry, error) {
|
) ([]types.StateEntry, error) {
|
||||||
return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs)
|
return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs, excludeRejected)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) StateEntriesForTuples(
|
func (d *Database) StateEntriesForTuples(
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,7 @@ const insertEventSQL = `
|
||||||
INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)
|
INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||||
ON CONFLICT DO UPDATE
|
ON CONFLICT DO UPDATE
|
||||||
SET is_rejected = $8 WHERE is_rejected = 0
|
SET is_rejected = $8 WHERE is_rejected = 1
|
||||||
RETURNING event_nid, state_snapshot_nid;
|
RETURNING event_nid, state_snapshot_nid;
|
||||||
`
|
`
|
||||||
|
|
||||||
|
|
@ -65,6 +65,14 @@ const bulkSelectStateEventByIDSQL = "" +
|
||||||
" WHERE event_id IN ($1)" +
|
" WHERE event_id IN ($1)" +
|
||||||
" ORDER BY event_type_nid, event_state_key_nid ASC"
|
" ORDER BY event_type_nid, event_state_key_nid ASC"
|
||||||
|
|
||||||
|
// Bulk lookup of events by string ID that aren't rejected.
|
||||||
|
// Sort by the numeric IDs for event type and state key.
|
||||||
|
// This means we can use binary search to lookup entries by type and state key.
|
||||||
|
const bulkSelectStateEventByIDExcludingRejectedSQL = "" +
|
||||||
|
"SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" +
|
||||||
|
" WHERE event_id IN ($1) AND is_rejected = 0" +
|
||||||
|
" ORDER BY event_type_nid, event_state_key_nid ASC"
|
||||||
|
|
||||||
const bulkSelectStateEventByNIDSQL = "" +
|
const bulkSelectStateEventByNIDSQL = "" +
|
||||||
"SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" +
|
"SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" +
|
||||||
" WHERE event_nid IN ($1)"
|
" WHERE event_nid IN ($1)"
|
||||||
|
|
@ -113,19 +121,20 @@ const selectEventRejectedSQL = "" +
|
||||||
"SELECT is_rejected FROM roomserver_events WHERE room_nid = $1 AND event_id = $2"
|
"SELECT is_rejected FROM roomserver_events WHERE room_nid = $1 AND event_id = $2"
|
||||||
|
|
||||||
type eventStatements struct {
|
type eventStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertEventStmt *sql.Stmt
|
insertEventStmt *sql.Stmt
|
||||||
selectEventStmt *sql.Stmt
|
selectEventStmt *sql.Stmt
|
||||||
bulkSelectStateEventByIDStmt *sql.Stmt
|
bulkSelectStateEventByIDStmt *sql.Stmt
|
||||||
bulkSelectStateAtEventByIDStmt *sql.Stmt
|
bulkSelectStateEventByIDExcludingRejectedStmt *sql.Stmt
|
||||||
updateEventStateStmt *sql.Stmt
|
bulkSelectStateAtEventByIDStmt *sql.Stmt
|
||||||
selectEventSentToOutputStmt *sql.Stmt
|
updateEventStateStmt *sql.Stmt
|
||||||
updateEventSentToOutputStmt *sql.Stmt
|
selectEventSentToOutputStmt *sql.Stmt
|
||||||
selectEventIDStmt *sql.Stmt
|
updateEventSentToOutputStmt *sql.Stmt
|
||||||
bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
|
selectEventIDStmt *sql.Stmt
|
||||||
bulkSelectEventReferenceStmt *sql.Stmt
|
bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
|
||||||
bulkSelectEventIDStmt *sql.Stmt
|
bulkSelectEventReferenceStmt *sql.Stmt
|
||||||
selectEventRejectedStmt *sql.Stmt
|
bulkSelectEventIDStmt *sql.Stmt
|
||||||
|
selectEventRejectedStmt *sql.Stmt
|
||||||
//bulkSelectEventNIDStmt *sql.Stmt
|
//bulkSelectEventNIDStmt *sql.Stmt
|
||||||
//bulkSelectUnsentEventNIDStmt *sql.Stmt
|
//bulkSelectUnsentEventNIDStmt *sql.Stmt
|
||||||
//selectRoomNIDsForEventNIDsStmt *sql.Stmt
|
//selectRoomNIDsForEventNIDsStmt *sql.Stmt
|
||||||
|
|
@ -145,6 +154,7 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) {
|
||||||
{&s.insertEventStmt, insertEventSQL},
|
{&s.insertEventStmt, insertEventSQL},
|
||||||
{&s.selectEventStmt, selectEventSQL},
|
{&s.selectEventStmt, selectEventSQL},
|
||||||
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
|
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
|
||||||
|
{&s.bulkSelectStateEventByIDExcludingRejectedStmt, bulkSelectStateEventByIDExcludingRejectedSQL},
|
||||||
{&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
|
{&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
|
||||||
{&s.updateEventStateStmt, updateEventStateSQL},
|
{&s.updateEventStateStmt, updateEventStateSQL},
|
||||||
{&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL},
|
{&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL},
|
||||||
|
|
@ -194,16 +204,24 @@ func (s *eventStatements) SelectEvent(
|
||||||
}
|
}
|
||||||
|
|
||||||
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
||||||
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
// If not excluding rejected events, and any of the requested events are missing from
|
||||||
|
// the database it returns a types.MissingEventError. If excluding rejected events,
|
||||||
|
// the events will be silently omitted without error.
|
||||||
func (s *eventStatements) BulkSelectStateEventByID(
|
func (s *eventStatements) BulkSelectStateEventByID(
|
||||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
ctx context.Context, txn *sql.Tx, eventIDs []string, excludeRejected bool,
|
||||||
) ([]types.StateEntry, error) {
|
) ([]types.StateEntry, error) {
|
||||||
///////////////
|
///////////////
|
||||||
|
var sql string
|
||||||
|
if excludeRejected {
|
||||||
|
sql = bulkSelectStateEventByIDExcludingRejectedSQL
|
||||||
|
} else {
|
||||||
|
sql = bulkSelectStateEventByIDSQL
|
||||||
|
}
|
||||||
iEventIDs := make([]interface{}, len(eventIDs))
|
iEventIDs := make([]interface{}, len(eventIDs))
|
||||||
for k, v := range eventIDs {
|
for k, v := range eventIDs {
|
||||||
iEventIDs[k] = v
|
iEventIDs[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
|
selectOrig := strings.Replace(sql, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
|
||||||
selectPrep, err := s.db.Prepare(selectOrig)
|
selectPrep, err := s.db.Prepare(selectOrig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -221,10 +239,10 @@ func (s *eventStatements) BulkSelectStateEventByID(
|
||||||
// because of the unique constraint on event IDs.
|
// because of the unique constraint on event IDs.
|
||||||
// So we can allocate an array of the correct size now.
|
// So we can allocate an array of the correct size now.
|
||||||
// We might get fewer results than IDs so we adjust the length of the slice before returning it.
|
// We might get fewer results than IDs so we adjust the length of the slice before returning it.
|
||||||
results := make([]types.StateEntry, len(eventIDs))
|
results := make([]types.StateEntry, 0, len(eventIDs))
|
||||||
i := 0
|
i := 0
|
||||||
for ; rows.Next(); i++ {
|
for ; rows.Next(); i++ {
|
||||||
result := &results[i]
|
var result types.StateEntry
|
||||||
if err = rows.Scan(
|
if err = rows.Scan(
|
||||||
&result.EventTypeNID,
|
&result.EventTypeNID,
|
||||||
&result.EventStateKeyNID,
|
&result.EventStateKeyNID,
|
||||||
|
|
@ -232,8 +250,9 @@ func (s *eventStatements) BulkSelectStateEventByID(
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
results = append(results, result)
|
||||||
}
|
}
|
||||||
if i != len(eventIDs) {
|
if !excludeRejected && i != len(eventIDs) {
|
||||||
// If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
|
// If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
|
||||||
// We don't know which ones were missing because we don't return the string IDs in the query.
|
// We don't know which ones were missing because we don't return the string IDs in the query.
|
||||||
// However it should be possible debug this by replaying queries or entries from the input kafka logs.
|
// However it should be possible debug this by replaying queries or entries from the input kafka logs.
|
||||||
|
|
@ -343,7 +362,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
|
||||||
// Genuine create events are the only case where it's OK to have no previous state.
|
// Genuine create events are the only case where it's OK to have no previous state.
|
||||||
isCreate := result.EventTypeNID == types.MRoomCreateNID && result.EventStateKeyNID == 1
|
isCreate := result.EventTypeNID == types.MRoomCreateNID && result.EventStateKeyNID == 1
|
||||||
if result.BeforeStateSnapshotNID == 0 && !isCreate {
|
if result.BeforeStateSnapshotNID == 0 && !isCreate {
|
||||||
return nil, types.MissingEventError(
|
return nil, types.MissingStateError(
|
||||||
fmt.Sprintf("storage: missing state for event NID %d", result.EventNID),
|
fmt.Sprintf("storage: missing state for event NID %d", result.EventNID),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -102,7 +102,7 @@ func Test_EventsTable(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
stateEvents, err := tab.BulkSelectStateEventByID(ctx, nil, eventIDs)
|
stateEvents, err := tab.BulkSelectStateEventByID(ctx, nil, eventIDs, false)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, len(stateEvents), len(eventIDs))
|
assert.Equal(t, len(stateEvents), len(eventIDs))
|
||||||
nids := make([]types.EventNID, 0, len(stateEvents))
|
nids := make([]types.EventNID, 0, len(stateEvents))
|
||||||
|
|
|
||||||
|
|
@ -46,7 +46,7 @@ type Events interface {
|
||||||
SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error)
|
SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error)
|
||||||
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
||||||
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
||||||
BulkSelectStateEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateEntry, error)
|
BulkSelectStateEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error)
|
||||||
BulkSelectStateEventByNID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error)
|
BulkSelectStateEventByNID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error)
|
||||||
// BulkSelectStateAtEventByID lookups the state at a list of events by event ID.
|
// BulkSelectStateAtEventByID lookups the state at a list of events by event ID.
|
||||||
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
||||||
|
|
|
||||||
|
|
@ -25,21 +25,23 @@ import (
|
||||||
_ "net/http/pprof"
|
_ "net/http/pprof"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/getsentry/sentry-go"
|
"github.com/getsentry/sentry-go"
|
||||||
sentryhttp "github.com/getsentry/sentry-go/http"
|
sentryhttp "github.com/getsentry/sentry-go/http"
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
|
||||||
"github.com/matrix-org/dendrite/internal/httputil"
|
|
||||||
"github.com/matrix-org/dendrite/internal/pushgateway"
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
"go.uber.org/atomic"
|
"go.uber.org/atomic"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
"golang.org/x/net/http2/h2c"
|
"golang.org/x/net/http2/h2c"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
|
"github.com/matrix-org/dendrite/internal/httputil"
|
||||||
|
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
"github.com/matrix-org/dendrite/setup/process"
|
"github.com/matrix-org/dendrite/setup/process"
|
||||||
|
|
@ -47,6 +49,8 @@ import (
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/kardianos/minwinsvc"
|
"github.com/kardianos/minwinsvc"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
||||||
asinthttp "github.com/matrix-org/dendrite/appservice/inthttp"
|
asinthttp "github.com/matrix-org/dendrite/appservice/inthttp"
|
||||||
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
|
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
|
||||||
|
|
@ -58,7 +62,6 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
userapiinthttp "github.com/matrix-org/dendrite/userapi/inthttp"
|
userapiinthttp "github.com/matrix-org/dendrite/userapi/inthttp"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// BaseDendrite is a base for creating new instances of dendrite. It parses
|
// BaseDendrite is a base for creating new instances of dendrite. It parses
|
||||||
|
|
@ -87,6 +90,7 @@ type BaseDendrite struct {
|
||||||
Database *sql.DB
|
Database *sql.DB
|
||||||
DatabaseWriter sqlutil.Writer
|
DatabaseWriter sqlutil.Writer
|
||||||
EnableMetrics bool
|
EnableMetrics bool
|
||||||
|
startupLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
const NoListener = ""
|
const NoListener = ""
|
||||||
|
|
@ -394,6 +398,9 @@ func (b *BaseDendrite) SetupAndServeHTTP(
|
||||||
internalHTTPAddr, externalHTTPAddr config.HTTPAddress,
|
internalHTTPAddr, externalHTTPAddr config.HTTPAddress,
|
||||||
certFile, keyFile *string,
|
certFile, keyFile *string,
|
||||||
) {
|
) {
|
||||||
|
// Manually unlocked right before actually serving requests,
|
||||||
|
// as we don't return from this method (defer doesn't work).
|
||||||
|
b.startupLock.Lock()
|
||||||
internalAddr, _ := internalHTTPAddr.Address()
|
internalAddr, _ := internalHTTPAddr.Address()
|
||||||
externalAddr, _ := externalHTTPAddr.Address()
|
externalAddr, _ := externalHTTPAddr.Address()
|
||||||
|
|
||||||
|
|
@ -472,6 +479,7 @@ func (b *BaseDendrite) SetupAndServeHTTP(
|
||||||
externalRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(b.PublicMediaAPIMux)
|
externalRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(b.PublicMediaAPIMux)
|
||||||
externalRouter.PathPrefix(httputil.PublicWellKnownPrefix).Handler(b.PublicWellKnownAPIMux)
|
externalRouter.PathPrefix(httputil.PublicWellKnownPrefix).Handler(b.PublicWellKnownAPIMux)
|
||||||
|
|
||||||
|
b.startupLock.Unlock()
|
||||||
if internalAddr != NoListener && internalAddr != externalAddr {
|
if internalAddr != NoListener && internalAddr != externalAddr {
|
||||||
go func() {
|
go func() {
|
||||||
var internalShutdown atomic.Bool // RegisterOnShutdown can be called more than once
|
var internalShutdown atomic.Bool // RegisterOnShutdown can be called more than once
|
||||||
|
|
|
||||||
|
|
@ -224,12 +224,7 @@ func loadConfig(
|
||||||
}
|
}
|
||||||
|
|
||||||
privateKeyPath := absPath(basePath, c.Global.PrivateKeyPath)
|
privateKeyPath := absPath(basePath, c.Global.PrivateKeyPath)
|
||||||
privateKeyData, err := readFile(privateKeyPath)
|
if c.Global.KeyID, c.Global.PrivateKey, err = LoadMatrixKey(privateKeyPath, readFile); err != nil {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.Global.KeyID, c.Global.PrivateKey, err = readKeyPEM(privateKeyPath, privateKeyData, true); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -265,6 +260,14 @@ func loadConfig(
|
||||||
return &c, nil
|
return &c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func LoadMatrixKey(privateKeyPath string, readFile func(string) ([]byte, error)) (gomatrixserverlib.KeyID, ed25519.PrivateKey, error) {
|
||||||
|
privateKeyData, err := readFile(privateKeyPath)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
return readKeyPEM(privateKeyPath, privateKeyData, true)
|
||||||
|
}
|
||||||
|
|
||||||
// Derive generates data that is derived from various values provided in
|
// Derive generates data that is derived from various values provided in
|
||||||
// the config file.
|
// the config file.
|
||||||
func (config *Dendrite) Derive() error {
|
func (config *Dendrite) Derive() error {
|
||||||
|
|
|
||||||
|
|
@ -31,8 +31,6 @@ type AppServiceAPI struct {
|
||||||
|
|
||||||
InternalAPI InternalAPIOptions `yaml:"internal_api"`
|
InternalAPI InternalAPIOptions `yaml:"internal_api"`
|
||||||
|
|
||||||
Database DatabaseOptions `yaml:"database"`
|
|
||||||
|
|
||||||
// DisableTLSValidation disables the validation of X.509 TLS certs
|
// DisableTLSValidation disables the validation of X.509 TLS certs
|
||||||
// on appservice endpoints. This is not recommended in production!
|
// on appservice endpoints. This is not recommended in production!
|
||||||
DisableTLSValidation bool `yaml:"disable_tls_validation"`
|
DisableTLSValidation bool `yaml:"disable_tls_validation"`
|
||||||
|
|
@ -43,16 +41,9 @@ type AppServiceAPI struct {
|
||||||
func (c *AppServiceAPI) Defaults(generate bool) {
|
func (c *AppServiceAPI) Defaults(generate bool) {
|
||||||
c.InternalAPI.Listen = "http://localhost:7777"
|
c.InternalAPI.Listen = "http://localhost:7777"
|
||||||
c.InternalAPI.Connect = "http://localhost:7777"
|
c.InternalAPI.Connect = "http://localhost:7777"
|
||||||
c.Database.Defaults(5)
|
|
||||||
if generate {
|
|
||||||
c.Database.ConnectionString = "file:appservice.db"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *AppServiceAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
func (c *AppServiceAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
||||||
if c.Matrix.DatabaseOptions.ConnectionString == "" {
|
|
||||||
checkNotEmpty(configErrs, "app_service_api.database.connection_string", string(c.Database.ConnectionString))
|
|
||||||
}
|
|
||||||
if isMonolith { // polylith required configs below
|
if isMonolith { // polylith required configs below
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,9 +9,16 @@ import (
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// JetStreamConsumer starts a durable consumer on the given subject with the
|
||||||
|
// given durable name. The function will be called when one or more messages
|
||||||
|
// is available, up to the maximum batch size specified. If the batch is set to
|
||||||
|
// 1 then messages will be delivered one at a time. If the function is called,
|
||||||
|
// the messages array is guaranteed to be at least 1 in size. Any provided NATS
|
||||||
|
// options will be passed through to the pull subscriber creation. The consumer
|
||||||
|
// will continue to run until the context expires, at which point it will stop.
|
||||||
func JetStreamConsumer(
|
func JetStreamConsumer(
|
||||||
ctx context.Context, js nats.JetStreamContext, subj, durable string,
|
ctx context.Context, js nats.JetStreamContext, subj, durable string, batch int,
|
||||||
f func(ctx context.Context, msg *nats.Msg) bool,
|
f func(ctx context.Context, msgs []*nats.Msg) bool,
|
||||||
opts ...nats.SubOpt,
|
opts ...nats.SubOpt,
|
||||||
) error {
|
) error {
|
||||||
defer func() {
|
defer func() {
|
||||||
|
|
@ -50,7 +57,7 @@ func JetStreamConsumer(
|
||||||
// enforce its own deadline (roughly 5 seconds by default). Therefore
|
// enforce its own deadline (roughly 5 seconds by default). Therefore
|
||||||
// it is our responsibility to check whether our context expired or
|
// it is our responsibility to check whether our context expired or
|
||||||
// not when a context error is returned. Footguns. Footguns everywhere.
|
// not when a context error is returned. Footguns. Footguns everywhere.
|
||||||
msgs, err := sub.Fetch(1, nats.Context(ctx))
|
msgs, err := sub.Fetch(batch, nats.Context(ctx))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == context.Canceled || err == context.DeadlineExceeded {
|
if err == context.Canceled || err == context.DeadlineExceeded {
|
||||||
// Work out whether it was the JetStream context that expired
|
// Work out whether it was the JetStream context that expired
|
||||||
|
|
@ -74,21 +81,26 @@ func JetStreamConsumer(
|
||||||
if len(msgs) < 1 {
|
if len(msgs) < 1 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
msg := msgs[0]
|
for _, msg := range msgs {
|
||||||
if err = msg.InProgress(nats.Context(ctx)); err != nil {
|
if err = msg.InProgress(nats.Context(ctx)); err != nil {
|
||||||
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err))
|
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err))
|
||||||
sentry.CaptureException(err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if f(ctx, msg) {
|
|
||||||
if err = msg.AckSync(nats.Context(ctx)); err != nil {
|
|
||||||
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.AckSync: %w", err))
|
|
||||||
sentry.CaptureException(err)
|
sentry.CaptureException(err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if f(ctx, msgs) {
|
||||||
|
for _, msg := range msgs {
|
||||||
|
if err = msg.AckSync(nats.Context(ctx)); err != nil {
|
||||||
|
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.AckSync: %w", err))
|
||||||
|
sentry.CaptureException(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err = msg.Nak(nats.Context(ctx)); err != nil {
|
for _, msg := range msgs {
|
||||||
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err))
|
if err = msg.Nak(nats.Context(ctx)); err != nil {
|
||||||
sentry.CaptureException(err)
|
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err))
|
||||||
|
sentry.CaptureException(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -183,6 +183,7 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc
|
||||||
OutputReceiptEvent: {"SyncAPIEDUServerReceiptConsumer", "FederationAPIEDUServerConsumer"},
|
OutputReceiptEvent: {"SyncAPIEDUServerReceiptConsumer", "FederationAPIEDUServerConsumer"},
|
||||||
OutputSendToDeviceEvent: {"SyncAPIEDUServerSendToDeviceConsumer", "FederationAPIEDUServerConsumer"},
|
OutputSendToDeviceEvent: {"SyncAPIEDUServerSendToDeviceConsumer", "FederationAPIEDUServerConsumer"},
|
||||||
OutputTypingEvent: {"SyncAPIEDUServerTypingConsumer", "FederationAPIEDUServerConsumer"},
|
OutputTypingEvent: {"SyncAPIEDUServerTypingConsumer", "FederationAPIEDUServerConsumer"},
|
||||||
|
OutputRoomEvent: {"AppserviceRoomserverConsumer"},
|
||||||
} {
|
} {
|
||||||
streamName := cfg.Matrix.JetStream.Prefixed(stream)
|
streamName := cfg.Matrix.JetStream.Prefixed(stream)
|
||||||
for _, consumer := range consumers {
|
for _, consumer := range consumers {
|
||||||
|
|
|
||||||
|
|
@ -75,15 +75,16 @@ func NewOutputClientDataConsumer(
|
||||||
// Start consuming from room servers
|
// Start consuming from room servers
|
||||||
func (s *OutputClientDataConsumer) Start() error {
|
func (s *OutputClientDataConsumer) Start() error {
|
||||||
return jetstream.JetStreamConsumer(
|
return jetstream.JetStreamConsumer(
|
||||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
s.ctx, s.jetstream, s.topic, s.durable, 1,
|
||||||
nats.DeliverAll(), nats.ManualAck(),
|
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// onMessage is called when the sync server receives a new event from the client API server output log.
|
// onMessage is called when the sync server receives a new event from the client API server output log.
|
||||||
// It is not safe for this function to be called from multiple goroutines, or else the
|
// It is not safe for this function to be called from multiple goroutines, or else the
|
||||||
// sync stream position may race and be incorrectly calculated.
|
// sync stream position may race and be incorrectly calculated.
|
||||||
func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
|
||||||
|
msg := msgs[0] // Guaranteed to exist if onMessage is called
|
||||||
// Parse out the event JSON
|
// Parse out the event JSON
|
||||||
userID := msg.Header.Get(jetstream.UserID)
|
userID := msg.Header.Get(jetstream.UserID)
|
||||||
var output eventutil.AccountData
|
var output eventutil.AccountData
|
||||||
|
|
|
||||||
|
|
@ -75,12 +75,13 @@ func NewOutputKeyChangeEventConsumer(
|
||||||
// Start consuming from the key server
|
// Start consuming from the key server
|
||||||
func (s *OutputKeyChangeEventConsumer) Start() error {
|
func (s *OutputKeyChangeEventConsumer) Start() error {
|
||||||
return jetstream.JetStreamConsumer(
|
return jetstream.JetStreamConsumer(
|
||||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
s.ctx, s.jetstream, s.topic, s.durable, 1,
|
||||||
nats.DeliverAll(), nats.ManualAck(),
|
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OutputKeyChangeEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
func (s *OutputKeyChangeEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
|
||||||
|
msg := msgs[0] // Guaranteed to exist if onMessage is called
|
||||||
var m api.DeviceMessage
|
var m api.DeviceMessage
|
||||||
if err := json.Unmarshal(msg.Data, &m); err != nil {
|
if err := json.Unmarshal(msg.Data, &m); err != nil {
|
||||||
logrus.WithError(err).Errorf("failed to read device message from key change topic")
|
logrus.WithError(err).Errorf("failed to read device message from key change topic")
|
||||||
|
|
|
||||||
|
|
@ -128,12 +128,13 @@ func (s *PresenceConsumer) Start() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return jetstream.JetStreamConsumer(
|
return jetstream.JetStreamConsumer(
|
||||||
s.ctx, s.jetstream, s.presenceTopic, s.durable, s.onMessage,
|
s.ctx, s.jetstream, s.presenceTopic, s.durable, 1, s.onMessage,
|
||||||
nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(),
|
nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
func (s *PresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
|
||||||
|
msg := msgs[0] // Guaranteed to exist if onMessage is called
|
||||||
userID := msg.Header.Get(jetstream.UserID)
|
userID := msg.Header.Get(jetstream.UserID)
|
||||||
presence := msg.Header.Get("presence")
|
presence := msg.Header.Get("presence")
|
||||||
timestamp := msg.Header.Get("last_active_ts")
|
timestamp := msg.Header.Get("last_active_ts")
|
||||||
|
|
|
||||||
|
|
@ -74,12 +74,13 @@ func NewOutputReceiptEventConsumer(
|
||||||
// Start consuming receipts events.
|
// Start consuming receipts events.
|
||||||
func (s *OutputReceiptEventConsumer) Start() error {
|
func (s *OutputReceiptEventConsumer) Start() error {
|
||||||
return jetstream.JetStreamConsumer(
|
return jetstream.JetStreamConsumer(
|
||||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
s.ctx, s.jetstream, s.topic, s.durable, 1,
|
||||||
nats.DeliverAll(), nats.ManualAck(),
|
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
|
||||||
|
msg := msgs[0] // Guaranteed to exist if onMessage is called
|
||||||
output := types.OutputReceiptEvent{
|
output := types.OutputReceiptEvent{
|
||||||
UserID: msg.Header.Get(jetstream.UserID),
|
UserID: msg.Header.Get(jetstream.UserID),
|
||||||
RoomID: msg.Header.Get(jetstream.RoomID),
|
RoomID: msg.Header.Get(jetstream.RoomID),
|
||||||
|
|
|
||||||
|
|
@ -79,15 +79,16 @@ func NewOutputRoomEventConsumer(
|
||||||
// Start consuming from room servers
|
// Start consuming from room servers
|
||||||
func (s *OutputRoomEventConsumer) Start() error {
|
func (s *OutputRoomEventConsumer) Start() error {
|
||||||
return jetstream.JetStreamConsumer(
|
return jetstream.JetStreamConsumer(
|
||||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
s.ctx, s.jetstream, s.topic, s.durable, 1,
|
||||||
nats.DeliverAll(), nats.ManualAck(),
|
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// onMessage is called when the sync server receives a new event from the room server output log.
|
// onMessage is called when the sync server receives a new event from the room server output log.
|
||||||
// It is not safe for this function to be called from multiple goroutines, or else the
|
// It is not safe for this function to be called from multiple goroutines, or else the
|
||||||
// sync stream position may race and be incorrectly calculated.
|
// sync stream position may race and be incorrectly calculated.
|
||||||
func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
|
||||||
|
msg := msgs[0] // Guaranteed to exist if onMessage is called
|
||||||
// Parse out the event JSON
|
// Parse out the event JSON
|
||||||
var err error
|
var err error
|
||||||
var output api.OutputEvent
|
var output api.OutputEvent
|
||||||
|
|
|
||||||
|
|
@ -68,12 +68,13 @@ func NewOutputSendToDeviceEventConsumer(
|
||||||
// Start consuming send-to-device events.
|
// Start consuming send-to-device events.
|
||||||
func (s *OutputSendToDeviceEventConsumer) Start() error {
|
func (s *OutputSendToDeviceEventConsumer) Start() error {
|
||||||
return jetstream.JetStreamConsumer(
|
return jetstream.JetStreamConsumer(
|
||||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
s.ctx, s.jetstream, s.topic, s.durable, 1,
|
||||||
nats.DeliverAll(), nats.ManualAck(),
|
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
|
||||||
|
msg := msgs[0] // Guaranteed to exist if onMessage is called
|
||||||
userID := msg.Header.Get(jetstream.UserID)
|
userID := msg.Header.Get(jetstream.UserID)
|
||||||
_, domain, err := gomatrixserverlib.SplitID('@', userID)
|
_, domain, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -64,12 +64,13 @@ func NewOutputTypingEventConsumer(
|
||||||
// Start consuming typing events.
|
// Start consuming typing events.
|
||||||
func (s *OutputTypingEventConsumer) Start() error {
|
func (s *OutputTypingEventConsumer) Start() error {
|
||||||
return jetstream.JetStreamConsumer(
|
return jetstream.JetStreamConsumer(
|
||||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
s.ctx, s.jetstream, s.topic, s.durable, 1,
|
||||||
nats.DeliverAll(), nats.ManualAck(),
|
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OutputTypingEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
func (s *OutputTypingEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
|
||||||
|
msg := msgs[0] // Guaranteed to exist if onMessage is called
|
||||||
roomID := msg.Header.Get(jetstream.RoomID)
|
roomID := msg.Header.Get(jetstream.RoomID)
|
||||||
userID := msg.Header.Get(jetstream.UserID)
|
userID := msg.Header.Get(jetstream.UserID)
|
||||||
typing, err := strconv.ParseBool(msg.Header.Get("typing"))
|
typing, err := strconv.ParseBool(msg.Header.Get("typing"))
|
||||||
|
|
|
||||||
|
|
@ -67,8 +67,8 @@ func NewOutputNotificationDataConsumer(
|
||||||
// Start starts consumption.
|
// Start starts consumption.
|
||||||
func (s *OutputNotificationDataConsumer) Start() error {
|
func (s *OutputNotificationDataConsumer) Start() error {
|
||||||
return jetstream.JetStreamConsumer(
|
return jetstream.JetStreamConsumer(
|
||||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
s.ctx, s.jetstream, s.topic, s.durable, 1,
|
||||||
nats.DeliverAll(), nats.ManualAck(),
|
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -76,7 +76,8 @@ func (s *OutputNotificationDataConsumer) Start() error {
|
||||||
// the push server. It is not safe for this function to be called from
|
// the push server. It is not safe for this function to be called from
|
||||||
// multiple goroutines, or else the sync stream position may race and
|
// multiple goroutines, or else the sync stream position may race and
|
||||||
// be incorrectly calculated.
|
// be incorrectly calculated.
|
||||||
func (s *OutputNotificationDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
func (s *OutputNotificationDataConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
|
||||||
|
msg := msgs[0] // Guaranteed to exist if onMessage is called
|
||||||
userID := string(msg.Header.Get(jetstream.UserID))
|
userID := string(msg.Header.Get(jetstream.UserID))
|
||||||
|
|
||||||
// Parse out the event JSON
|
// Parse out the event JSON
|
||||||
|
|
|
||||||
|
|
@ -18,14 +18,16 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
||||||
keytypes "github.com/matrix-org/dendrite/keyserver/types"
|
keytypes "github.com/matrix-org/dendrite/keyserver/types"
|
||||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
"github.com/matrix-org/util"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// DeviceOTKCounts adds one-time key counts to the /sync response
|
// DeviceOTKCounts adds one-time key counts to the /sync response
|
||||||
|
|
@ -125,7 +127,7 @@ func DeviceListCatchup(
|
||||||
"from": offset,
|
"from": offset,
|
||||||
"to": toOffset,
|
"to": toOffset,
|
||||||
"response_offset": queryRes.Offset,
|
"response_offset": queryRes.Offset,
|
||||||
}).Debugf("QueryKeyChanges request result: %+v", res.DeviceLists)
|
}).Tracef("QueryKeyChanges request result: %+v", res.DeviceLists)
|
||||||
|
|
||||||
return types.StreamPosition(queryRes.Offset), hasNew, nil
|
return types.StreamPosition(queryRes.Offset), hasNew, nil
|
||||||
}
|
}
|
||||||
|
|
@ -277,6 +279,10 @@ func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID strin
|
||||||
// it's enough to know that we have our member event here, don't need to check membership content
|
// it's enough to know that we have our member event here, don't need to check membership content
|
||||||
// as it's implied by being in the respective section of the sync response.
|
// as it's implied by being in the respective section of the sync response.
|
||||||
if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID {
|
if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID {
|
||||||
|
// ignore e.g. join -> join changes
|
||||||
|
if gjson.GetBytes(ev.Unsigned, "prev_content.membership").Str == gjson.GetBytes(ev.Content, "membership").Str {
|
||||||
|
continue
|
||||||
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -350,8 +350,10 @@ func (r *messagesReq) retrieveEvents() (
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.db, r.rsAPI, events, nil, r.device.UserID, "messages")
|
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.db, r.rsAPI, events, nil, r.device.UserID, "messages")
|
||||||
logrus.WithFields(logrus.Fields{
|
logrus.WithFields(logrus.Fields{
|
||||||
"duration": time.Since(startTime),
|
"duration": time.Since(startTime),
|
||||||
"room_id": r.roomID,
|
"room_id": r.roomID,
|
||||||
|
"events_before": len(events),
|
||||||
|
"events_after": len(filteredEvents),
|
||||||
}).Debug("applied history visibility (messages)")
|
}).Debug("applied history visibility (messages)")
|
||||||
return gomatrixserverlib.HeaderedToClientEvents(filteredEvents, gomatrixserverlib.FormatAll), start, end, err
|
return gomatrixserverlib.HeaderedToClientEvents(filteredEvents, gomatrixserverlib.FormatAll), start, end, err
|
||||||
}
|
}
|
||||||
|
|
@ -513,6 +515,9 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]
|
||||||
|
|
||||||
// Store the events in the database, while marking them as unfit to show
|
// Store the events in the database, while marking them as unfit to show
|
||||||
// up in responses to sync requests.
|
// up in responses to sync requests.
|
||||||
|
if res.HistoryVisibility == "" {
|
||||||
|
res.HistoryVisibility = gomatrixserverlib.HistoryVisibilityShared
|
||||||
|
}
|
||||||
for i := range res.Events {
|
for i := range res.Events {
|
||||||
_, err = r.db.WriteEvent(
|
_, err = r.db.WriteEvent(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
|
|
@ -521,7 +526,7 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]
|
||||||
[]string{},
|
[]string{},
|
||||||
[]string{},
|
[]string{},
|
||||||
nil, true,
|
nil, true,
|
||||||
gomatrixserverlib.HistoryVisibilityShared,
|
res.HistoryVisibility,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -534,6 +539,9 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]
|
||||||
// last `limit` events
|
// last `limit` events
|
||||||
events = events[len(events)-limit:]
|
events = events[len(events)-limit:]
|
||||||
}
|
}
|
||||||
|
for _, ev := range events {
|
||||||
|
ev.Visibility = res.HistoryVisibility
|
||||||
|
}
|
||||||
|
|
||||||
return events, nil
|
return events, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -19,10 +19,11 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Database interface {
|
type Database interface {
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,8 @@ CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
|
||||||
-- The event content JSON.
|
-- The event content JSON.
|
||||||
content TEXT NOT NULL
|
content TEXT NOT NULL
|
||||||
);
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS syncapi_send_to_device_user_id_device_id_idx ON syncapi_send_to_device(user_id, device_id);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertSendToDeviceMessageSQL = `
|
const insertSendToDeviceMessageSQL = `
|
||||||
|
|
|
||||||
|
|
@ -20,15 +20,18 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite
|
// Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite
|
||||||
|
|
@ -683,7 +686,7 @@ func (d *Database) GetStateDeltas(
|
||||||
ctx context.Context, device *userapi.Device,
|
ctx context.Context, device *userapi.Device,
|
||||||
r types.Range, userID string,
|
r types.Range, userID string,
|
||||||
stateFilter *gomatrixserverlib.StateFilter,
|
stateFilter *gomatrixserverlib.StateFilter,
|
||||||
) ([]types.StateDelta, []string, error) {
|
) (deltas []types.StateDelta, joinedRoomsIDs []string, err error) {
|
||||||
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
|
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
|
||||||
// - Get membership list changes for this user in this sync response
|
// - Get membership list changes for this user in this sync response
|
||||||
// - For each room which has membership list changes:
|
// - For each room which has membership list changes:
|
||||||
|
|
@ -718,8 +721,6 @@ func (d *Database) GetStateDeltas(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var deltas []types.StateDelta
|
|
||||||
|
|
||||||
// get all the state events ever (i.e. for all available rooms) between these two positions
|
// get all the state events ever (i.e. for all available rooms) between these two positions
|
||||||
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs)
|
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -767,15 +768,11 @@ func (d *Database) GetStateDeltas(
|
||||||
}
|
}
|
||||||
|
|
||||||
// handle newly joined rooms and non-joined rooms
|
// handle newly joined rooms and non-joined rooms
|
||||||
|
newlyJoinedRooms := make(map[string]bool, len(state))
|
||||||
for roomID, stateStreamEvents := range state {
|
for roomID, stateStreamEvents := range state {
|
||||||
for _, ev := range stateStreamEvents {
|
for _, ev := range stateStreamEvents {
|
||||||
// TODO: Currently this will incorrectly add rooms which were ALREADY joined but they sent another no-op join event.
|
if membership, prevMembership := getMembershipFromEvent(ev.Event, userID); membership != "" {
|
||||||
// We should be checking if the user was already joined at fromPos and not proceed if so. As a result of this,
|
if membership == gomatrixserverlib.Join && prevMembership != membership {
|
||||||
// dupe join events will result in the entire room state coming down to the client again. This is added in
|
|
||||||
// the 'state' part of the response though, so is transparent modulo bandwidth concerns as it is not added to
|
|
||||||
// the timeline.
|
|
||||||
if membership := getMembershipFromEvent(ev.Event, userID); membership != "" {
|
|
||||||
if membership == gomatrixserverlib.Join {
|
|
||||||
// send full room state down instead of a delta
|
// send full room state down instead of a delta
|
||||||
var s []types.StreamEvent
|
var s []types.StreamEvent
|
||||||
s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter)
|
s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter)
|
||||||
|
|
@ -786,6 +783,7 @@ func (d *Database) GetStateDeltas(
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
state[roomID] = s
|
state[roomID] = s
|
||||||
|
newlyJoinedRooms[roomID] = true
|
||||||
continue // we'll add this room in when we do joined rooms
|
continue // we'll add this room in when we do joined rooms
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -806,6 +804,7 @@ func (d *Database) GetStateDeltas(
|
||||||
Membership: gomatrixserverlib.Join,
|
Membership: gomatrixserverlib.Join,
|
||||||
StateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]),
|
StateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]),
|
||||||
RoomID: joinedRoomID,
|
RoomID: joinedRoomID,
|
||||||
|
NewlyJoined: newlyJoinedRooms[joinedRoomID],
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -892,7 +891,7 @@ func (d *Database) GetStateDeltasForFullStateSync(
|
||||||
|
|
||||||
for roomID, stateStreamEvents := range state {
|
for roomID, stateStreamEvents := range state {
|
||||||
for _, ev := range stateStreamEvents {
|
for _, ev := range stateStreamEvents {
|
||||||
if membership := getMembershipFromEvent(ev.Event, userID); membership != "" {
|
if membership, _ := getMembershipFromEvent(ev.Event, userID); membership != "" {
|
||||||
if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above.
|
if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above.
|
||||||
deltas[roomID] = types.StateDelta{
|
deltas[roomID] = types.StateDelta{
|
||||||
Membership: membership,
|
Membership: membership,
|
||||||
|
|
@ -1003,15 +1002,16 @@ func (d *Database) CleanSendToDeviceUpdates(
|
||||||
|
|
||||||
// getMembershipFromEvent returns the value of content.membership iff the event is a state event
|
// getMembershipFromEvent returns the value of content.membership iff the event is a state event
|
||||||
// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned.
|
// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned.
|
||||||
func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) string {
|
func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) (string, string) {
|
||||||
if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) {
|
if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) {
|
||||||
return ""
|
return "", ""
|
||||||
}
|
}
|
||||||
membership, err := ev.Membership()
|
membership, err := ev.Membership()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return "", ""
|
||||||
}
|
}
|
||||||
return membership
|
prevMembership := gjson.GetBytes(ev.Unsigned(), "prev_content.membership").Str
|
||||||
|
return membership, prevMembership
|
||||||
}
|
}
|
||||||
|
|
||||||
// StoreReceipt stores user receipts
|
// StoreReceipt stores user receipts
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,8 @@ CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
|
||||||
-- The event content JSON.
|
-- The event content JSON.
|
||||||
content TEXT NOT NULL
|
content TEXT NOT NULL
|
||||||
);
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS syncapi_send_to_device_user_id_device_id_idx ON syncapi_send_to_device(user_id, device_id);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertSendToDeviceMessageSQL = `
|
const insertSendToDeviceMessageSQL = `
|
||||||
|
|
|
||||||
|
|
@ -178,24 +178,24 @@ func (p *PDUStreamProvider) IncrementalSync(
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
var stateDeltas []types.StateDelta
|
var stateDeltas []types.StateDelta
|
||||||
var joinedRooms []string
|
var syncJoinedRooms []string
|
||||||
|
|
||||||
stateFilter := req.Filter.Room.State
|
stateFilter := req.Filter.Room.State
|
||||||
eventFilter := req.Filter.Room.Timeline
|
eventFilter := req.Filter.Room.Timeline
|
||||||
|
|
||||||
if req.WantFullState {
|
if req.WantFullState {
|
||||||
if stateDeltas, joinedRooms, err = p.DB.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
|
if stateDeltas, syncJoinedRooms, err = p.DB.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
|
||||||
req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed")
|
req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if stateDeltas, joinedRooms, err = p.DB.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
|
if stateDeltas, syncJoinedRooms, err = p.DB.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
|
||||||
req.Log.WithError(err).Error("p.DB.GetStateDeltas failed")
|
req.Log.WithError(err).Error("p.DB.GetStateDeltas failed")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, roomID := range joinedRooms {
|
for _, roomID := range syncJoinedRooms {
|
||||||
req.Rooms[roomID] = gomatrixserverlib.Join
|
req.Rooms[roomID] = gomatrixserverlib.Join
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -209,11 +209,27 @@ func (p *PDUStreamProvider) IncrementalSync(
|
||||||
|
|
||||||
newPos = from
|
newPos = from
|
||||||
for _, delta := range stateDeltas {
|
for _, delta := range stateDeltas {
|
||||||
|
newRange := r
|
||||||
|
// If this room was joined in this sync, try to fetch
|
||||||
|
// as much timeline events as allowed by the filter.
|
||||||
|
if delta.NewlyJoined {
|
||||||
|
// Reverse the range, so we get the most recent first.
|
||||||
|
// This will be limited by the eventFilter.
|
||||||
|
newRange = types.Range{
|
||||||
|
From: r.To,
|
||||||
|
To: 0,
|
||||||
|
Backwards: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
var pos types.StreamPosition
|
var pos types.StreamPosition
|
||||||
if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, &eventFilter, &stateFilter, req.Response); err != nil {
|
if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, newRange, delta, &eventFilter, &stateFilter, req.Response); err != nil {
|
||||||
req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed")
|
req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed")
|
||||||
return to
|
return to
|
||||||
}
|
}
|
||||||
|
// Reset the position, as it is only for the special case of newly joined rooms
|
||||||
|
if delta.NewlyJoined {
|
||||||
|
pos = newRange.From
|
||||||
|
}
|
||||||
switch {
|
switch {
|
||||||
case r.Backwards && pos < newPos:
|
case r.Backwards && pos < newPos:
|
||||||
fallthrough
|
fallthrough
|
||||||
|
|
@ -287,7 +303,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
||||||
|
|
||||||
if stateFilter.LazyLoadMembers {
|
if stateFilter.LazyLoadMembers {
|
||||||
delta.StateEvents, err = p.lazyLoadMembers(
|
delta.StateEvents, err = p.lazyLoadMembers(
|
||||||
ctx, delta.RoomID, true, limited, stateFilter.IncludeRedundantMembers,
|
ctx, delta.RoomID, true, limited, stateFilter,
|
||||||
device, recentEvents, delta.StateEvents,
|
device, recentEvents, delta.StateEvents,
|
||||||
)
|
)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
|
|
@ -309,12 +325,12 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
||||||
logrus.WithError(err).Error("unable to apply history visibility filter")
|
logrus.WithError(err).Error("unable to apply history visibility filter")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(events) > 0 {
|
|
||||||
updateLatestPosition(events[len(events)-1].EventID())
|
|
||||||
}
|
|
||||||
if len(delta.StateEvents) > 0 {
|
if len(delta.StateEvents) > 0 {
|
||||||
updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID())
|
updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID())
|
||||||
}
|
}
|
||||||
|
if len(events) > 0 {
|
||||||
|
updateLatestPosition(events[len(events)-1].EventID())
|
||||||
|
}
|
||||||
|
|
||||||
switch delta.Membership {
|
switch delta.Membership {
|
||||||
case gomatrixserverlib.Join:
|
case gomatrixserverlib.Join:
|
||||||
|
|
@ -387,6 +403,8 @@ func applyHistoryVisibilityFilter(
|
||||||
logrus.WithFields(logrus.Fields{
|
logrus.WithFields(logrus.Fields{
|
||||||
"duration": time.Since(startTime),
|
"duration": time.Since(startTime),
|
||||||
"room_id": roomID,
|
"room_id": roomID,
|
||||||
|
"before": len(recentEvents),
|
||||||
|
"after": len(events),
|
||||||
}).Debug("applied history visibility (sync)")
|
}).Debug("applied history visibility (sync)")
|
||||||
return events, nil
|
return events, nil
|
||||||
}
|
}
|
||||||
|
|
@ -514,7 +532,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
stateEvents, err = p.lazyLoadMembers(ctx, roomID,
|
stateEvents, err = p.lazyLoadMembers(ctx, roomID,
|
||||||
false, limited, stateFilter.IncludeRedundantMembers,
|
false, limited, stateFilter,
|
||||||
device, recentEvents, stateEvents,
|
device, recentEvents, stateEvents,
|
||||||
)
|
)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
|
|
@ -533,7 +551,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
||||||
|
|
||||||
func (p *PDUStreamProvider) lazyLoadMembers(
|
func (p *PDUStreamProvider) lazyLoadMembers(
|
||||||
ctx context.Context, roomID string,
|
ctx context.Context, roomID string,
|
||||||
incremental, limited, includeRedundant bool,
|
incremental, limited bool, stateFilter *gomatrixserverlib.StateFilter,
|
||||||
device *userapi.Device,
|
device *userapi.Device,
|
||||||
timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent,
|
timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent,
|
||||||
) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||||
|
|
@ -563,7 +581,7 @@ func (p *PDUStreamProvider) lazyLoadMembers(
|
||||||
stateKey := *event.StateKey()
|
stateKey := *event.StateKey()
|
||||||
if _, ok := timelineUsers[stateKey]; ok || isGappedIncremental {
|
if _, ok := timelineUsers[stateKey]; ok || isGappedIncremental {
|
||||||
newStateEvents = append(newStateEvents, event)
|
newStateEvents = append(newStateEvents, event)
|
||||||
if !includeRedundant {
|
if !stateFilter.IncludeRedundantMembers {
|
||||||
p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, stateKey, event.EventID())
|
p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, stateKey, event.EventID())
|
||||||
}
|
}
|
||||||
delete(timelineUsers, stateKey)
|
delete(timelineUsers, stateKey)
|
||||||
|
|
@ -578,6 +596,7 @@ func (p *PDUStreamProvider) lazyLoadMembers(
|
||||||
}
|
}
|
||||||
// Query missing membership events
|
// Query missing membership events
|
||||||
filter := gomatrixserverlib.DefaultStateFilter()
|
filter := gomatrixserverlib.DefaultStateFilter()
|
||||||
|
filter.Limit = stateFilter.Limit
|
||||||
filter.Senders = &wantUsers
|
filter.Senders = &wantUsers
|
||||||
filter.Types = &[]string{gomatrixserverlib.MRoomMember}
|
filter.Types = &[]string{gomatrixserverlib.MRoomMember}
|
||||||
memberships, err := p.DB.GetStateEventsForRoom(ctx, roomID, &filter)
|
memberships, err := p.DB.GetStateEventsForRoom(ctx, roomID, &filter)
|
||||||
|
|
|
||||||
|
|
@ -19,9 +19,11 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/syncapi/notifier"
|
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type PresenceStreamProvider struct {
|
type PresenceStreamProvider struct {
|
||||||
|
|
@ -175,6 +177,10 @@ func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID strin
|
||||||
// it's enough to know that we have our member event here, don't need to check membership content
|
// it's enough to know that we have our member event here, don't need to check membership content
|
||||||
// as it's implied by being in the respective section of the sync response.
|
// as it's implied by being in the respective section of the sync response.
|
||||||
if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID {
|
if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID {
|
||||||
|
// ignore e.g. join -> join changes
|
||||||
|
if gjson.GetBytes(ev.Unsigned, "prev_content.membership").Str == gjson.GetBytes(ev.Content, "membership").Str {
|
||||||
|
continue
|
||||||
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -23,12 +23,13 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultSyncTimeout = time.Duration(0)
|
const defaultSyncTimeout = time.Duration(0)
|
||||||
|
|
@ -46,15 +47,9 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// TODO: read from stored filters too
|
|
||||||
|
// Create a default filter and apply a stored filter on top of it (if specified)
|
||||||
filter := gomatrixserverlib.DefaultFilter()
|
filter := gomatrixserverlib.DefaultFilter()
|
||||||
if since.IsEmpty() {
|
|
||||||
// Send as much account data down for complete syncs as possible
|
|
||||||
// by default, otherwise clients do weird things while waiting
|
|
||||||
// for the rest of the data to trickle down.
|
|
||||||
filter.AccountData.Limit = math.MaxInt32
|
|
||||||
filter.Room.AccountData.Limit = math.MaxInt32
|
|
||||||
}
|
|
||||||
filterQuery := req.URL.Query().Get("filter")
|
filterQuery := req.URL.Query().Get("filter")
|
||||||
if filterQuery != "" {
|
if filterQuery != "" {
|
||||||
if filterQuery[0] == '{' {
|
if filterQuery[0] == '{' {
|
||||||
|
|
@ -76,6 +71,17 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A loaded filter might have overwritten these values,
|
||||||
|
// so set them after loading the filter.
|
||||||
|
if since.IsEmpty() {
|
||||||
|
// Send as much account data down for complete syncs as possible
|
||||||
|
// by default, otherwise clients do weird things while waiting
|
||||||
|
// for the rest of the data to trickle down.
|
||||||
|
filter.AccountData.Limit = math.MaxInt32
|
||||||
|
filter.Room.AccountData.Limit = math.MaxInt32
|
||||||
|
filter.Room.State.Limit = math.MaxInt32
|
||||||
|
}
|
||||||
|
|
||||||
logger := util.GetLogger(req.Context()).WithFields(logrus.Fields{
|
logger := util.GetLogger(req.Context()).WithFields(logrus.Fields{
|
||||||
"user_id": device.UserID,
|
"user_id": device.UserID,
|
||||||
"device_id": device.ID,
|
"device_id": device.ID,
|
||||||
|
|
|
||||||
|
|
@ -298,8 +298,8 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
|
||||||
return giveup()
|
return giveup()
|
||||||
|
|
||||||
case <-userStreamListener.GetNotifyChannel(syncReq.Since):
|
case <-userStreamListener.GetNotifyChannel(syncReq.Since):
|
||||||
syncReq.Log.Debugln("Responding to sync after wake-up")
|
|
||||||
currentPos.ApplyUpdates(userStreamListener.GetSyncPosition())
|
currentPos.ApplyUpdates(userStreamListener.GetSyncPosition())
|
||||||
|
syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync after wake-up")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync immediately")
|
syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync immediately")
|
||||||
|
|
|
||||||
|
|
@ -154,8 +154,12 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
|
||||||
wantJoinedRooms: []string{room.ID},
|
wantJoinedRooms: []string{room.ID},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
// TODO: find a better way
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
syncUntil(t, base, alice.AccessToken, false, func(syncBody string) bool {
|
||||||
|
// wait for the last sent eventID to come down sync
|
||||||
|
path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, room.Events()[len(room.Events())-1].EventID())
|
||||||
|
return gjson.Get(syncBody, path).Exists()
|
||||||
|
})
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
@ -191,6 +195,7 @@ func TestSyncAPICreateRoomSyncEarly(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
|
func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
|
||||||
|
t.Skip("Skipped, possibly fixed")
|
||||||
user := test.NewUser(t)
|
user := test.NewUser(t)
|
||||||
room := test.NewRoom(t, user)
|
room := test.NewRoom(t, user)
|
||||||
alice := userapi.Device{
|
alice := userapi.Device{
|
||||||
|
|
@ -343,6 +348,13 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
|
||||||
|
|
||||||
// create the users
|
// create the users
|
||||||
alice := test.NewUser(t)
|
alice := test.NewUser(t)
|
||||||
|
aliceDev := userapi.Device{
|
||||||
|
ID: "ALICEID",
|
||||||
|
UserID: alice.ID,
|
||||||
|
AccessToken: "ALICE_BEARER_TOKEN",
|
||||||
|
DisplayName: "ALICE",
|
||||||
|
}
|
||||||
|
|
||||||
bob := test.NewUser(t)
|
bob := test.NewUser(t)
|
||||||
|
|
||||||
bobDev := userapi.Device{
|
bobDev := userapi.Device{
|
||||||
|
|
@ -409,7 +421,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
|
||||||
rsAPI := roomserver.NewInternalAPI(base)
|
rsAPI := roomserver.NewInternalAPI(base)
|
||||||
rsAPI.SetFederationAPI(nil, nil)
|
rsAPI.SetFederationAPI(nil, nil)
|
||||||
|
|
||||||
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{bobDev}}, rsAPI, &syncKeyAPI{})
|
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, &syncKeyAPI{})
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType)
|
testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType)
|
||||||
|
|
@ -418,12 +430,18 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
|
||||||
room := test.NewRoom(t, alice, test.RoomHistoryVisibility(tc.historyVisibility))
|
room := test.NewRoom(t, alice, test.RoomHistoryVisibility(tc.historyVisibility))
|
||||||
|
|
||||||
// send the events/messages to NATS to create the rooms
|
// send the events/messages to NATS to create the rooms
|
||||||
beforeJoinEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("Before invite in a %s room", tc.historyVisibility)})
|
beforeJoinBody := fmt.Sprintf("Before invite in a %s room", tc.historyVisibility)
|
||||||
|
beforeJoinEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": beforeJoinBody})
|
||||||
eventsToSend := append(room.Events(), beforeJoinEv)
|
eventsToSend := append(room.Events(), beforeJoinEv)
|
||||||
if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil {
|
if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil {
|
||||||
t.Fatalf("failed to send events: %v", err)
|
t.Fatalf("failed to send events: %v", err)
|
||||||
}
|
}
|
||||||
time.Sleep(100 * time.Millisecond) // TODO: find a better way
|
syncUntil(t, base, aliceDev.AccessToken, false,
|
||||||
|
func(syncBody string) bool {
|
||||||
|
path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(content.body=="%s")`, room.ID, beforeJoinBody)
|
||||||
|
return gjson.Get(syncBody, path).Exists()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
// There is only one event, we expect only to be able to see this, if the room is world_readable
|
// There is only one event, we expect only to be able to see this, if the room is world_readable
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
@ -449,14 +467,20 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
|
||||||
inviteEv := room.CreateAndInsert(t, alice, "m.room.member", map[string]interface{}{"membership": "invite"}, test.WithStateKey(bob.ID))
|
inviteEv := room.CreateAndInsert(t, alice, "m.room.member", map[string]interface{}{"membership": "invite"}, test.WithStateKey(bob.ID))
|
||||||
afterInviteEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After invite in a %s room", tc.historyVisibility)})
|
afterInviteEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After invite in a %s room", tc.historyVisibility)})
|
||||||
joinEv := room.CreateAndInsert(t, bob, "m.room.member", map[string]interface{}{"membership": "join"}, test.WithStateKey(bob.ID))
|
joinEv := room.CreateAndInsert(t, bob, "m.room.member", map[string]interface{}{"membership": "join"}, test.WithStateKey(bob.ID))
|
||||||
msgEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After join in a %s room", tc.historyVisibility)})
|
afterJoinBody := fmt.Sprintf("After join in a %s room", tc.historyVisibility)
|
||||||
|
msgEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": afterJoinBody})
|
||||||
|
|
||||||
eventsToSend = append([]*gomatrixserverlib.HeaderedEvent{}, inviteEv, afterInviteEv, joinEv, msgEv)
|
eventsToSend = append([]*gomatrixserverlib.HeaderedEvent{}, inviteEv, afterInviteEv, joinEv, msgEv)
|
||||||
|
|
||||||
if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil {
|
if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil {
|
||||||
t.Fatalf("failed to send events: %v", err)
|
t.Fatalf("failed to send events: %v", err)
|
||||||
}
|
}
|
||||||
time.Sleep(100 * time.Millisecond) // TODO: find a better way
|
syncUntil(t, base, aliceDev.AccessToken, false,
|
||||||
|
func(syncBody string) bool {
|
||||||
|
path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(content.body=="%s")`, room.ID, afterJoinBody)
|
||||||
|
return gjson.Get(syncBody, path).Exists()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
// Verify the messages after/before invite are visible or not
|
// Verify the messages after/before invite are visible or not
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
|
|
@ -511,8 +535,8 @@ func testSendToDevice(t *testing.T, dbType test.DBType) {
|
||||||
AccountType: userapi.AccountTypeUser,
|
AccountType: userapi.AccountTypeUser,
|
||||||
}
|
}
|
||||||
|
|
||||||
base, close := testrig.CreateBaseDendrite(t, dbType)
|
base, baseClose := testrig.CreateBaseDendrite(t, dbType)
|
||||||
defer close()
|
defer baseClose()
|
||||||
|
|
||||||
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
|
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
|
||||||
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
|
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
|
||||||
|
|
@ -607,7 +631,14 @@ func testSendToDevice(t *testing.T, dbType test.DBType) {
|
||||||
t.Fatalf("unable to send to device message: %v", err)
|
t.Fatalf("unable to send to device message: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
time.Sleep((time.Millisecond * 15) * time.Duration(tc.sendMessagesCount)) // wait a bit, so the messages can be processed
|
|
||||||
|
syncUntil(t, base, alice.AccessToken,
|
||||||
|
len(tc.want) == 0,
|
||||||
|
func(body string) bool {
|
||||||
|
return gjson.Get(body, fmt.Sprintf(`to_device.events.#(content.dummy=="message %d")`, msgCounter)).Exists()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
// Execute a /sync request, recording the response
|
// Execute a /sync request, recording the response
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{
|
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{
|
||||||
|
|
@ -630,6 +661,42 @@ func testSendToDevice(t *testing.T, dbType test.DBType) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func syncUntil(t *testing.T,
|
||||||
|
base *base.BaseDendrite, accessToken string,
|
||||||
|
skip bool,
|
||||||
|
checkFunc func(syncBody string) bool,
|
||||||
|
) {
|
||||||
|
if checkFunc == nil {
|
||||||
|
t.Fatalf("No checkFunc defined")
|
||||||
|
}
|
||||||
|
if skip {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// loop on /sync until we receive the last send message or timeout after 5 seconds, since we don't know if the message made it
|
||||||
|
// to the syncAPI when hitting /sync
|
||||||
|
done := make(chan bool)
|
||||||
|
defer close(done)
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{
|
||||||
|
"access_token": accessToken,
|
||||||
|
"timeout": "1000",
|
||||||
|
})))
|
||||||
|
if checkFunc(w.Body.String()) {
|
||||||
|
done <- true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second * 5):
|
||||||
|
t.Fatalf("Timed out waiting for messages")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input ...*gomatrixserverlib.HeaderedEvent) []*nats.Msg {
|
func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input ...*gomatrixserverlib.HeaderedEvent) []*nats.Msg {
|
||||||
result := make([]*nats.Msg, len(input))
|
result := make([]*nats.Msg, len(input))
|
||||||
for i, ev := range input {
|
for i, ev := range input {
|
||||||
|
|
|
||||||
|
|
@ -37,6 +37,7 @@ var (
|
||||||
type StateDelta struct {
|
type StateDelta struct {
|
||||||
RoomID string
|
RoomID string
|
||||||
StateEvents []*gomatrixserverlib.HeaderedEvent
|
StateEvents []*gomatrixserverlib.HeaderedEvent
|
||||||
|
NewlyJoined bool
|
||||||
Membership string
|
Membership string
|
||||||
// The PDU stream position of the latest membership event for this user, if applicable.
|
// The PDU stream position of the latest membership event for this user, if applicable.
|
||||||
// Can be 0 if there is no membership event in this delta.
|
// Can be 0 if there is no membership event in this delta.
|
||||||
|
|
|
||||||
|
|
@ -144,7 +144,6 @@ Server correctly handles incoming m.device_list_update
|
||||||
If remote user leaves room, changes device and rejoins we see update in sync
|
If remote user leaves room, changes device and rejoins we see update in sync
|
||||||
If remote user leaves room, changes device and rejoins we see update in /keys/changes
|
If remote user leaves room, changes device and rejoins we see update in /keys/changes
|
||||||
If remote user leaves room we no longer receive device updates
|
If remote user leaves room we no longer receive device updates
|
||||||
If a device list update goes missing, the server resyncs on the next one
|
|
||||||
Server correctly resyncs when client query keys and there is no remote cache
|
Server correctly resyncs when client query keys and there is no remote cache
|
||||||
Server correctly resyncs when server leaves and rejoins a room
|
Server correctly resyncs when server leaves and rejoins a room
|
||||||
Device list doesn't change if remote server is down
|
Device list doesn't change if remote server is down
|
||||||
|
|
@ -633,7 +632,6 @@ Test that rejected pushers are removed.
|
||||||
Trying to add push rule with no scope fails with 400
|
Trying to add push rule with no scope fails with 400
|
||||||
Trying to add push rule with invalid scope fails with 400
|
Trying to add push rule with invalid scope fails with 400
|
||||||
Forward extremities remain so even after the next events are populated as outliers
|
Forward extremities remain so even after the next events are populated as outliers
|
||||||
If a device list update goes missing, the server resyncs on the next one
|
|
||||||
uploading self-signing key notifies over federation
|
uploading self-signing key notifies over federation
|
||||||
uploading signed devices gets propagated over federation
|
uploading signed devices gets propagated over federation
|
||||||
Device list doesn't change if remote server is down
|
Device list doesn't change if remote server is down
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,7 @@ func ListenAndServe(t *testing.T, router http.Handler, withTLS bool) (apiURL str
|
||||||
if withTLS {
|
if withTLS {
|
||||||
certFile := filepath.Join(t.TempDir(), "dendrite.cert")
|
certFile := filepath.Join(t.TempDir(), "dendrite.cert")
|
||||||
keyFile := filepath.Join(t.TempDir(), "dendrite.key")
|
keyFile := filepath.Join(t.TempDir(), "dendrite.key")
|
||||||
err = NewTLSKey(keyFile, certFile)
|
err = NewTLSKey(keyFile, certFile, 1024)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to make TLS key: %s", err)
|
t.Errorf("failed to make TLS key: %s", err)
|
||||||
return
|
return
|
||||||
|
|
|
||||||
19
test/keys.go
19
test/keys.go
|
|
@ -15,6 +15,7 @@
|
||||||
package test
|
package test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
|
@ -44,6 +45,10 @@ func NewMatrixKey(matrixKeyPath string) (err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
return SaveMatrixKey(matrixKeyPath, data[3:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func SaveMatrixKey(matrixKeyPath string, data ed25519.PrivateKey) error {
|
||||||
keyOut, err := os.OpenFile(matrixKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
keyOut, err := os.OpenFile(matrixKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -62,15 +67,15 @@ func NewMatrixKey(matrixKeyPath string) (err error) {
|
||||||
Headers: map[string]string{
|
Headers: map[string]string{
|
||||||
"Key-ID": fmt.Sprintf("ed25519:%s", keyID[:6]),
|
"Key-ID": fmt.Sprintf("ed25519:%s", keyID[:6]),
|
||||||
},
|
},
|
||||||
Bytes: data[3:],
|
Bytes: data,
|
||||||
})
|
})
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
const certificateDuration = time.Hour * 24 * 365 * 10
|
const certificateDuration = time.Hour * 24 * 365 * 10
|
||||||
|
|
||||||
func generateTLSTemplate(dnsNames []string) (*rsa.PrivateKey, *x509.Certificate, error) {
|
func generateTLSTemplate(dnsNames []string, bitSize int) (*rsa.PrivateKey, *x509.Certificate, error) {
|
||||||
priv, err := rsa.GenerateKey(rand.Reader, 4096)
|
priv, err := rsa.GenerateKey(rand.Reader, bitSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -118,8 +123,8 @@ func writePrivateKey(tlsKeyPath string, priv *rsa.PrivateKey) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTLSKey generates a new RSA TLS key and certificate and writes it to a file.
|
// NewTLSKey generates a new RSA TLS key and certificate and writes it to a file.
|
||||||
func NewTLSKey(tlsKeyPath, tlsCertPath string) error {
|
func NewTLSKey(tlsKeyPath, tlsCertPath string, keySize int) error {
|
||||||
priv, template, err := generateTLSTemplate(nil)
|
priv, template, err := generateTLSTemplate(nil, keySize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -136,8 +141,8 @@ func NewTLSKey(tlsKeyPath, tlsCertPath string) error {
|
||||||
return writePrivateKey(tlsKeyPath, priv)
|
return writePrivateKey(tlsKeyPath, priv)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTLSKeyWithAuthority(serverName, tlsKeyPath, tlsCertPath, authorityKeyPath, authorityCertPath string) error {
|
func NewTLSKeyWithAuthority(serverName, tlsKeyPath, tlsCertPath, authorityKeyPath, authorityCertPath string, keySize int) error {
|
||||||
priv, template, err := generateTLSTemplate([]string{serverName})
|
priv, template, err := generateTLSTemplate([]string{serverName}, keySize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,6 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f
|
||||||
return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), func() {
|
return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), func() {
|
||||||
// cleanup db files. This risks getting out of sync as we add more database strings :(
|
// cleanup db files. This risks getting out of sync as we add more database strings :(
|
||||||
dbFiles := []config.DataSource{
|
dbFiles := []config.DataSource{
|
||||||
cfg.AppServiceAPI.Database.ConnectionString,
|
|
||||||
cfg.FederationAPI.Database.ConnectionString,
|
cfg.FederationAPI.Database.ConnectionString,
|
||||||
cfg.KeyServer.Database.ConnectionString,
|
cfg.KeyServer.Database.ConnectionString,
|
||||||
cfg.MSCs.Database.ConnectionString,
|
cfg.MSCs.Database.ConnectionString,
|
||||||
|
|
|
||||||
|
|
@ -56,15 +56,16 @@ func NewOutputReadUpdateConsumer(
|
||||||
|
|
||||||
func (s *OutputReadUpdateConsumer) Start() error {
|
func (s *OutputReadUpdateConsumer) Start() error {
|
||||||
if err := jetstream.JetStreamConsumer(
|
if err := jetstream.JetStreamConsumer(
|
||||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
s.ctx, s.jetstream, s.topic, s.durable, 1,
|
||||||
nats.DeliverAll(), nats.ManualAck(),
|
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
|
||||||
|
msg := msgs[0] // Guaranteed to exist if onMessage is called
|
||||||
var read types.ReadUpdate
|
var read types.ReadUpdate
|
||||||
if err := json.Unmarshal(msg.Data, &read); err != nil {
|
if err := json.Unmarshal(msg.Data, &read); err != nil {
|
||||||
log.WithError(err).Error("userapi clientapi consumer: message parse failure")
|
log.WithError(err).Error("userapi clientapi consumer: message parse failure")
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,10 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/nats-io/nats.go"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
"github.com/matrix-org/dendrite/internal/pushgateway"
|
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||||
"github.com/matrix-org/dendrite/internal/pushrules"
|
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||||
|
|
@ -20,9 +24,6 @@ import (
|
||||||
"github.com/matrix-org/dendrite/userapi/storage"
|
"github.com/matrix-org/dendrite/userapi/storage"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/userapi/util"
|
"github.com/matrix-org/dendrite/userapi/util"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
"github.com/nats-io/nats.go"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type OutputStreamEventConsumer struct {
|
type OutputStreamEventConsumer struct {
|
||||||
|
|
@ -64,15 +65,16 @@ func NewOutputStreamEventConsumer(
|
||||||
|
|
||||||
func (s *OutputStreamEventConsumer) Start() error {
|
func (s *OutputStreamEventConsumer) Start() error {
|
||||||
if err := jetstream.JetStreamConsumer(
|
if err := jetstream.JetStreamConsumer(
|
||||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
s.ctx, s.jetstream, s.topic, s.durable, 1,
|
||||||
nats.DeliverAll(), nats.ManualAck(),
|
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
|
||||||
|
msg := msgs[0] // Guaranteed to exist if onMessage is called
|
||||||
var output types.StreamedEvent
|
var output types.StreamedEvent
|
||||||
output.Event = &gomatrixserverlib.HeaderedEvent{}
|
output.Event = &gomatrixserverlib.HeaderedEvent{}
|
||||||
if err := json.Unmarshal(msg.Data, &output); err != nil {
|
if err := json.Unmarshal(msg.Data, &output); err != nil {
|
||||||
|
|
@ -529,7 +531,9 @@ func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomat
|
||||||
case "event_id_only":
|
case "event_id_only":
|
||||||
req = pushgateway.NotifyRequest{
|
req = pushgateway.NotifyRequest{
|
||||||
Notification: pushgateway.Notification{
|
Notification: pushgateway.Notification{
|
||||||
Counts: &pushgateway.Counts{},
|
Counts: &pushgateway.Counts{
|
||||||
|
Unread: userNumUnreadNotifs,
|
||||||
|
},
|
||||||
Devices: devices,
|
Devices: devices,
|
||||||
EventID: event.EventID(),
|
EventID: event.EventID(),
|
||||||
RoomID: event.RoomID(),
|
RoomID: event.RoomID(),
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,6 @@ import (
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/appservice/types"
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
"github.com/matrix-org/dendrite/internal/pushrules"
|
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||||
|
|
@ -454,7 +453,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
|
||||||
// Create a dummy device for AS user
|
// Create a dummy device for AS user
|
||||||
dev := api.Device{
|
dev := api.Device{
|
||||||
// Use AS dummy device ID
|
// Use AS dummy device ID
|
||||||
ID: types.AppServiceDeviceID,
|
ID: "AS_Device",
|
||||||
// AS dummy device has AS's token.
|
// AS dummy device has AS's token.
|
||||||
AccessToken: token,
|
AccessToken: token,
|
||||||
AppserviceID: appService.ID,
|
AppserviceID: appService.ID,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue