Merge branch 'main' into patch-1

This commit is contained in:
Neil Alexander 2022-09-01 09:26:46 +01:00 committed by GitHub
commit 0624d4a643
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
94 changed files with 904 additions and 1849 deletions

View file

@ -7,6 +7,7 @@ on:
pull_request: pull_request:
release: release:
types: [published] types: [published]
workflow_dispatch:
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.ref }} group: ${{ github.workflow }}-${{ github.ref }}
@ -375,6 +376,8 @@ jobs:
# Build initial Dendrite image # Build initial Dendrite image
- run: docker build -t complement-dendrite -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile . - run: docker build -t complement-dendrite -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile .
working-directory: dendrite working-directory: dendrite
env:
DOCKER_BUILDKIT: 1
# Run Complement # Run Complement
- run: | - run: |

View file

@ -1,5 +1,30 @@
# Changelog # Changelog
## Dendrite 0.9.5 (2022-08-25)
### Fixes
* The roomserver will now correctly unreject previously rejected events if necessary when reprocessing
* The handling of event soft-failure has been improved on the roomserver input by no longer applying rejection rules and still calculating state before the event if possible
* The federation `/state` and `/state_ids` endpoints should now return the correct error code when the state isn't known instead of returning a HTTP 500
* The federation `/event` should now return outlier events correctly instead of returning a HTTP 500
* A bug in the federation backoff allowing zero intervals has been corrected
* The `create-account` utility will no longer error if the homeserver URL ends in a trailing slash
* A regression in `/sync` introduced in 0.9.4 should be fixed
## Dendrite 0.9.4 (2022-08-19)
### Fixes
* A bug in the roomserver around handling rejected outliers has been fixed
* Backfilled events will now use the correct history visibility where possible
* The device list updater backoff has been fixed, which should reduce the number of outbound HTTP requests and `Failed to query device keys for some users` log entries for dead servers
* The `/sync` endpoint will no longer incorrectly return room entries for retired invites which could cause some rooms to show up in the client "Historical" section
* The `/createRoom` endpoint will now correctly populate `is_direct` in invite membership events, which may help clients to classify direct messages correctly
* The `create-account` tool will now log an error if the shared secret is not set in the Dendrite config
* A couple of minor bugs have been fixed in the membership lazy-loading
* Queued EDUs in the federation API are now cached properly
## Dendrite 0.9.3 (2022-08-15) ## Dendrite 0.9.3 (2022-08-15)
### Important ### Important

View file

@ -1,10 +0,0 @@
# Application Service
This component interfaces with external [Application
Services](https://matrix.org/docs/spec/application_service/unstable.html).
This includes any HTTP endpoints that application services call, as well as talking
to any HTTP endpoints that application services provide themselves.
## Consumers
This component consumes and filters events from the Roomserver Kafka stream, passing on any necessary events to subscribing application services.

View file

@ -18,7 +18,6 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"net/http" "net/http"
"sync"
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -28,9 +27,6 @@ import (
"github.com/matrix-org/dendrite/appservice/consumers" "github.com/matrix-org/dendrite/appservice/consumers"
"github.com/matrix-org/dendrite/appservice/inthttp" "github.com/matrix-org/dendrite/appservice/inthttp"
"github.com/matrix-org/dendrite/appservice/query" "github.com/matrix-org/dendrite/appservice/query"
"github.com/matrix-org/dendrite/appservice/storage"
"github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/appservice/workers"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
@ -59,57 +55,40 @@ func NewInternalAPI(
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
}, },
} }
js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) // Create appserivce query API with an HTTP client that will be used for all
// outbound and inbound requests (inbound only for the internal API)
appserviceQueryAPI := &query.AppServiceQueryAPI{
HTTPClient: client,
Cfg: &base.Cfg.AppServiceAPI,
}
// Create a connection to the appservice postgres DB if len(base.Cfg.Derived.ApplicationServices) == 0 {
appserviceDB, err := storage.NewDatabase(base, &base.Cfg.AppServiceAPI.Database) return appserviceQueryAPI
if err != nil {
logrus.WithError(err).Panicf("failed to connect to appservice db")
} }
// Wrap application services in a type that relates the application service and // Wrap application services in a type that relates the application service and
// a sync.Cond object that can be used to notify workers when there are new // a sync.Cond object that can be used to notify workers when there are new
// events to be sent out. // events to be sent out.
workerStates := make([]types.ApplicationServiceWorkerState, len(base.Cfg.Derived.ApplicationServices)) for _, appservice := range base.Cfg.Derived.ApplicationServices {
for i, appservice := range base.Cfg.Derived.ApplicationServices {
m := sync.Mutex{}
ws := types.ApplicationServiceWorkerState{
AppService: appservice,
Cond: sync.NewCond(&m),
}
workerStates[i] = ws
// Create bot account for this AS if it doesn't already exist // Create bot account for this AS if it doesn't already exist
if err = generateAppServiceAccount(userAPI, appservice); err != nil { if err := generateAppServiceAccount(userAPI, appservice); err != nil {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"appservice": appservice.ID, "appservice": appservice.ID,
}).WithError(err).Panicf("failed to generate bot account for appservice") }).WithError(err).Panicf("failed to generate bot account for appservice")
} }
} }
// Create appserivce query API with an HTTP client that will be used for all
// outbound and inbound requests (inbound only for the internal API)
appserviceQueryAPI := &query.AppServiceQueryAPI{
HTTPClient: client,
Cfg: base.Cfg,
}
// Only consume if we actually have ASes to track, else we'll just chew cycles needlessly. // Only consume if we actually have ASes to track, else we'll just chew cycles needlessly.
// We can't add ASes at runtime so this is safe to do. // We can't add ASes at runtime so this is safe to do.
if len(workerStates) > 0 { js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
consumer := consumers.NewOutputRoomEventConsumer( consumer := consumers.NewOutputRoomEventConsumer(
base.ProcessContext, base.Cfg, js, appserviceDB, base.ProcessContext, &base.Cfg.AppServiceAPI,
rsAPI, workerStates, client, js, rsAPI,
) )
if err := consumer.Start(); err != nil { if err := consumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start appservice roomserver consumer") logrus.WithError(err).Panicf("failed to start appservice roomserver consumer")
} }
}
// Create application service transaction workers
if err := workers.SetupTransactionWorkers(client, appserviceDB, workerStates); err != nil {
logrus.WithError(err).Panicf("failed to start app service transaction workers")
}
return appserviceQueryAPI return appserviceQueryAPI
} }

View file

@ -15,17 +15,22 @@
package consumers package consumers
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"math"
"net/http"
"net/url"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/matrix-org/dendrite/appservice/storage"
"github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -33,65 +38,83 @@ import (
// OutputRoomEventConsumer consumes events that originated in the room server. // OutputRoomEventConsumer consumes events that originated in the room server.
type OutputRoomEventConsumer struct { type OutputRoomEventConsumer struct {
ctx context.Context ctx context.Context
cfg *config.AppServiceAPI
client *http.Client
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
durable string
topic string topic string
asDB storage.Database
rsAPI api.AppserviceRoomserverAPI rsAPI api.AppserviceRoomserverAPI
serverName string }
workerStates []types.ApplicationServiceWorkerState
type appserviceState struct {
*config.ApplicationService
backoff int
} }
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call
// Start() to begin consuming from room servers. // Start() to begin consuming from room servers.
func NewOutputRoomEventConsumer( func NewOutputRoomEventConsumer(
process *process.ProcessContext, process *process.ProcessContext,
cfg *config.Dendrite, cfg *config.AppServiceAPI,
client *http.Client,
js nats.JetStreamContext, js nats.JetStreamContext,
appserviceDB storage.Database,
rsAPI api.AppserviceRoomserverAPI, rsAPI api.AppserviceRoomserverAPI,
workerStates []types.ApplicationServiceWorkerState,
) *OutputRoomEventConsumer { ) *OutputRoomEventConsumer {
return &OutputRoomEventConsumer{ return &OutputRoomEventConsumer{
ctx: process.Context(), ctx: process.Context(),
cfg: cfg,
client: client,
jetstream: js, jetstream: js,
durable: cfg.Global.JetStream.Durable("AppserviceRoomserverConsumer"), topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent),
topic: cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent),
asDB: appserviceDB,
rsAPI: rsAPI, rsAPI: rsAPI,
serverName: string(cfg.Global.ServerName),
workerStates: workerStates,
} }
} }
// Start consuming from room servers // Start consuming from room servers
func (s *OutputRoomEventConsumer) Start() error { func (s *OutputRoomEventConsumer) Start() error {
return jetstream.JetStreamConsumer( for _, as := range s.cfg.Derived.ApplicationServices {
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, appsvc := as
nats.DeliverAll(), nats.ManualAck(), state := &appserviceState{
) ApplicationService: &appsvc,
}
token := jetstream.Tokenise(as.ID)
if err := jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic,
s.cfg.Matrix.JetStream.Durable("Appservice_"+token),
50, // maximum number of events to send in a single transaction
func(ctx context.Context, msgs []*nats.Msg) bool {
return s.onMessage(ctx, state, msgs)
},
nats.DeliverNew(), nats.ManualAck(),
); err != nil {
return fmt.Errorf("failed to create %q consumer: %w", token, err)
}
}
return nil
} }
// onMessage is called when the appservice component receives a new event from // onMessage is called when the appservice component receives a new event from
// the room server output log. // the room server output log.
func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *OutputRoomEventConsumer) onMessage(
ctx context.Context, state *appserviceState, msgs []*nats.Msg,
) bool {
log.WithField("appservice", state.ID).Tracef("Appservice worker received %d message(s) from roomserver", len(msgs))
events := make([]*gomatrixserverlib.HeaderedEvent, 0, len(msgs))
for _, msg := range msgs {
// Parse out the event JSON // Parse out the event JSON
var output api.OutputEvent var output api.OutputEvent
if err := json.Unmarshal(msg.Data, &output); err != nil { if err := json.Unmarshal(msg.Data, &output); err != nil {
// If the message was invalid, log it and move on to the next message in the stream // If the message was invalid, log it and move on to the next message in the stream
log.WithError(err).Errorf("roomserver output log: message parse failure") log.WithField("appservice", state.ID).WithError(err).Errorf("Appservice failed to parse message, ignoring")
return true continue
}
switch output.Type {
case api.OutputTypeNewRoomEvent:
if output.NewRoomEvent == nil || !s.appserviceIsInterestedInEvent(ctx, output.NewRoomEvent.Event, state.ApplicationService) {
continue
} }
log.WithFields(log.Fields{
"type": output.Type,
}).Debug("Got a message in OutputRoomEventConsumer")
events := []*gomatrixserverlib.HeaderedEvent{}
if output.Type == api.OutputTypeNewRoomEvent && output.NewRoomEvent != nil {
newEventID := output.NewRoomEvent.Event.EventID()
events = append(events, output.NewRoomEvent.Event) events = append(events, output.NewRoomEvent.Event)
if len(output.NewRoomEvent.AddsStateEventIDs) > 0 { if len(output.NewRoomEvent.AddsStateEventIDs) > 0 {
newEventID := output.NewRoomEvent.Event.EventID()
eventsReq := &api.QueryEventsByIDRequest{ eventsReq := &api.QueryEventsByIDRequest{
EventIDs: make([]string, 0, len(output.NewRoomEvent.AddsStateEventIDs)), EventIDs: make([]string, 0, len(output.NewRoomEvent.AddsStateEventIDs)),
} }
@ -103,105 +126,103 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg)
} }
if len(eventsReq.EventIDs) > 0 { if len(eventsReq.EventIDs) > 0 {
if err := s.rsAPI.QueryEventsByID(s.ctx, eventsReq, eventsRes); err != nil { if err := s.rsAPI.QueryEventsByID(s.ctx, eventsReq, eventsRes); err != nil {
log.WithError(err).Errorf("s.rsAPI.QueryEventsByID failed")
return false return false
} }
events = append(events, eventsRes.Events...) events = append(events, eventsRes.Events...)
} }
} }
} else if output.Type == api.OutputTypeNewInviteEvent && output.NewInviteEvent != nil {
case api.OutputTypeNewInviteEvent:
if output.NewInviteEvent == nil {
continue
}
events = append(events, output.NewInviteEvent.Event) events = append(events, output.NewInviteEvent.Event)
} else {
log.WithFields(log.Fields{ default:
"type": output.Type, continue
}).Debug("appservice OutputRoomEventConsumer ignoring event", string(msg.Data)) }
}
// If there are no events selected for sending then we should
// ack the messages so that we don't get sent them again in the
// future.
if len(events) == 0 {
return true return true
} }
// Send event to any relevant application services // Send event to any relevant application services. If we hit
if err := s.filterRoomserverEvents(context.TODO(), events); err != nil { // an error here, return false, so that we negatively ack.
log.WithError(err).Errorf("roomserver output log: filter error") log.WithField("appservice", state.ID).Debugf("Appservice worker sending %d events(s) from roomserver", len(events))
return true return s.sendEvents(ctx, state, events) == nil
} }
return true // sendEvents passes events to the appservice by using the transactions
} // endpoint. It will block for the backoff period if necessary.
func (s *OutputRoomEventConsumer) sendEvents(
// filterRoomserverEvents takes in events and decides whether any of them need ctx context.Context, state *appserviceState,
// to be passed on to an external application service. It does this by checking
// each namespace of each registered application service, and if there is a
// match, adds the event to the queue for events to be sent to a particular
// application service.
func (s *OutputRoomEventConsumer) filterRoomserverEvents(
ctx context.Context,
events []*gomatrixserverlib.HeaderedEvent, events []*gomatrixserverlib.HeaderedEvent,
) error { ) error {
for _, ws := range s.workerStates { // Create the transaction body.
for _, event := range events { transaction, err := json.Marshal(
// Check if this event is interesting to this application service gomatrixserverlib.ApplicationServiceTransaction{
if s.appserviceIsInterestedInEvent(ctx, event, ws.AppService) { Events: gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatAll),
// Queue this event to be sent off to the application service },
if err := s.asDB.StoreEvent(ctx, ws.AppService.ID, event); err != nil { )
log.WithError(err).Warn("failed to insert incoming event into appservices database") if err != nil {
return err return err
} else {
// Tell our worker to send out new messages by updating remaining message
// count and waking them up with a broadcast
ws.NotifyNewEvents()
}
}
}
} }
// TODO: We should probably be more intelligent and pick something not
// in the control of the event. A NATS timestamp header or something maybe.
txnID := events[0].Event.OriginServerTS()
// Send the transaction to the appservice.
// https://matrix.org/docs/spec/application_service/r0.1.2#put-matrix-app-v1-transactions-txnid
address := fmt.Sprintf("%s/transactions/%d?access_token=%s", state.URL, txnID, url.QueryEscape(state.HSToken))
req, err := http.NewRequestWithContext(ctx, "PUT", address, bytes.NewBuffer(transaction))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
resp, err := s.client.Do(req)
if err != nil {
return state.backoffAndPause(err)
}
// If the response was fine then we can clear any backoffs in place and
// report that everything was OK. Otherwise, back off for a while.
switch resp.StatusCode {
case http.StatusOK:
state.backoff = 0
default:
return state.backoffAndPause(fmt.Errorf("received HTTP status code %d from appservice", resp.StatusCode))
}
return nil return nil
} }
// appserviceJoinedAtEvent returns a boolean depending on whether a given // backoff pauses the calling goroutine for a 2^some backoff exponent seconds
// appservice has membership at the time a given event was created. func (s *appserviceState) backoffAndPause(err error) error {
func (s *OutputRoomEventConsumer) appserviceJoinedAtEvent(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, appservice config.ApplicationService) bool { if s.backoff < 6 {
// TODO: This is only checking the current room state, not the state at s.backoff++
// the event in question. Pretty sure this is what Synapse does too, but
// until we have a lighter way of checking the state before the event that
// doesn't involve state res, then this is probably OK.
membershipReq := &api.QueryMembershipsForRoomRequest{
RoomID: event.RoomID(),
JoinedOnly: true,
} }
membershipRes := &api.QueryMembershipsForRoomResponse{} duration := time.Second * time.Duration(math.Pow(2, float64(s.backoff)))
log.WithField("appservice", s.ID).WithError(err).Errorf("Unable to send transaction to appservice, backing off for %s", duration.String())
// XXX: This could potentially race if the state for the event is not known yet time.Sleep(duration)
// e.g. the event came over federation but we do not have the full state persisted. return err
if err := s.rsAPI.QueryMembershipsForRoom(ctx, membershipReq, membershipRes); err == nil {
for _, ev := range membershipRes.JoinEvents {
var membership gomatrixserverlib.MemberContent
if err = json.Unmarshal(ev.Content, &membership); err != nil || ev.StateKey == nil {
continue
}
if appservice.IsInterestedInUserID(*ev.StateKey) {
return true
}
}
} else {
log.WithFields(log.Fields{
"room_id": event.RoomID(),
}).WithError(err).Errorf("Unable to get membership for room")
}
return false
} }
// appserviceIsInterestedInEvent returns a boolean depending on whether a given // appserviceIsInterestedInEvent returns a boolean depending on whether a given
// event falls within one of a given application service's namespaces. // event falls within one of a given application service's namespaces.
// //
// TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682 // TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682
func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, appservice config.ApplicationService) bool { func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, appservice *config.ApplicationService) bool {
// No reason to queue events if they'll never be sent to the application switch {
// service case appservice.URL == "":
if appservice.URL == "" {
return false return false
} case appservice.IsInterestedInUserID(event.Sender()):
return true
// Check Room ID and Sender of the event case appservice.IsInterestedInRoomID(event.RoomID()):
if appservice.IsInterestedInUserID(event.Sender()) ||
appservice.IsInterestedInRoomID(event.RoomID()) {
return true return true
} }
@ -222,6 +243,7 @@ func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Cont
} }
} else { } else {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"appservice": appservice.ID,
"room_id": event.RoomID(), "room_id": event.RoomID(),
}).WithError(err).Errorf("Unable to get aliases for room") }).WithError(err).Errorf("Unable to get aliases for room")
} }
@ -229,3 +251,44 @@ func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Cont
// Check if any of the members in the room match the appservice // Check if any of the members in the room match the appservice
return s.appserviceJoinedAtEvent(ctx, event, appservice) return s.appserviceJoinedAtEvent(ctx, event, appservice)
} }
// appserviceJoinedAtEvent returns a boolean depending on whether a given
// appservice has membership at the time a given event was created.
func (s *OutputRoomEventConsumer) appserviceJoinedAtEvent(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, appservice *config.ApplicationService) bool {
// TODO: This is only checking the current room state, not the state at
// the event in question. Pretty sure this is what Synapse does too, but
// until we have a lighter way of checking the state before the event that
// doesn't involve state res, then this is probably OK.
membershipReq := &api.QueryMembershipsForRoomRequest{
RoomID: event.RoomID(),
JoinedOnly: true,
}
membershipRes := &api.QueryMembershipsForRoomResponse{}
// XXX: This could potentially race if the state for the event is not known yet
// e.g. the event came over federation but we do not have the full state persisted.
if err := s.rsAPI.QueryMembershipsForRoom(ctx, membershipReq, membershipRes); err == nil {
for _, ev := range membershipRes.JoinEvents {
switch {
case ev.StateKey == nil:
continue
case ev.Type != gomatrixserverlib.MRoomMember:
continue
}
var membership gomatrixserverlib.MemberContent
err = json.Unmarshal(ev.Content, &membership)
switch {
case err != nil:
continue
case membership.Membership == gomatrixserverlib.Join:
return true
}
}
} else {
log.WithFields(log.Fields{
"appservice": appservice.ID,
"room_id": event.RoomID(),
}).WithError(err).Errorf("Unable to get membership for room")
}
return false
}

View file

@ -33,7 +33,7 @@ const userIDExistsPath = "/users/"
// AppServiceQueryAPI is an implementation of api.AppServiceQueryAPI // AppServiceQueryAPI is an implementation of api.AppServiceQueryAPI
type AppServiceQueryAPI struct { type AppServiceQueryAPI struct {
HTTPClient *http.Client HTTPClient *http.Client
Cfg *config.Dendrite Cfg *config.AppServiceAPI
} }
// RoomAliasExists performs a request to '/room/{roomAlias}' on all known // RoomAliasExists performs a request to '/room/{roomAlias}' on all known

View file

@ -1,30 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"context"
"github.com/matrix-org/gomatrixserverlib"
)
type Database interface {
StoreEvent(ctx context.Context, appServiceID string, event *gomatrixserverlib.HeaderedEvent) error
GetEventsWithAppServiceID(ctx context.Context, appServiceID string, limit int) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error)
CountEventsWithAppServiceID(ctx context.Context, appServiceID string) (int, error)
UpdateTxnIDForEvents(ctx context.Context, appserviceID string, maxID, txnID int) error
RemoveEventsBeforeAndIncludingID(ctx context.Context, appserviceID string, eventTableID int) error
GetLatestTxnID(ctx context.Context) (int, error)
}

View file

@ -1,256 +0,0 @@
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"encoding/json"
"time"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
const appserviceEventsSchema = `
-- Stores events to be sent to application services
CREATE TABLE IF NOT EXISTS appservice_events (
-- An auto-incrementing id unique to each event in the table
id BIGSERIAL NOT NULL PRIMARY KEY,
-- The ID of the application service the event will be sent to
as_id TEXT NOT NULL,
-- JSON representation of the event
headered_event_json TEXT NOT NULL,
-- The ID of the transaction that this event is a part of
txn_id BIGINT NOT NULL
);
CREATE INDEX IF NOT EXISTS appservice_events_as_id ON appservice_events(as_id);
`
const selectEventsByApplicationServiceIDSQL = "" +
"SELECT id, headered_event_json, txn_id " +
"FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC"
const countEventsByApplicationServiceIDSQL = "" +
"SELECT COUNT(id) FROM appservice_events WHERE as_id = $1"
const insertEventSQL = "" +
"INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " +
"VALUES ($1, $2, $3)"
const updateTxnIDForEventsSQL = "" +
"UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3"
const deleteEventsBeforeAndIncludingIDSQL = "" +
"DELETE FROM appservice_events WHERE as_id = $1 AND id <= $2"
const (
// A transaction ID number that no transaction should ever have. Used for
// checking again the default value.
invalidTxnID = -2
)
type eventsStatements struct {
selectEventsByApplicationServiceIDStmt *sql.Stmt
countEventsByApplicationServiceIDStmt *sql.Stmt
insertEventStmt *sql.Stmt
updateTxnIDForEventsStmt *sql.Stmt
deleteEventsBeforeAndIncludingIDStmt *sql.Stmt
}
func (s *eventsStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(appserviceEventsSchema)
if err != nil {
return
}
if s.selectEventsByApplicationServiceIDStmt, err = db.Prepare(selectEventsByApplicationServiceIDSQL); err != nil {
return
}
if s.countEventsByApplicationServiceIDStmt, err = db.Prepare(countEventsByApplicationServiceIDSQL); err != nil {
return
}
if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil {
return
}
if s.updateTxnIDForEventsStmt, err = db.Prepare(updateTxnIDForEventsSQL); err != nil {
return
}
if s.deleteEventsBeforeAndIncludingIDStmt, err = db.Prepare(deleteEventsBeforeAndIncludingIDSQL); err != nil {
return
}
return
}
// selectEventsByApplicationServiceID takes in an application service ID and
// returns a slice of events that need to be sent to that application service,
// as well as an int later used to remove these same events from the database
// once successfully sent to an application service.
func (s *eventsStatements) selectEventsByApplicationServiceID(
ctx context.Context,
applicationServiceID string,
limit int,
) (
txnID, maxID int,
events []gomatrixserverlib.HeaderedEvent,
eventsRemaining bool,
err error,
) {
defer func() {
if err != nil {
log.WithFields(log.Fields{
"appservice": applicationServiceID,
}).WithError(err).Fatalf("appservice unable to select new events to send")
}
}()
// Retrieve events from the database. Unsuccessfully sent events first
eventRows, err := s.selectEventsByApplicationServiceIDStmt.QueryContext(ctx, applicationServiceID)
if err != nil {
return
}
defer checkNamedErr(eventRows.Close, &err)
events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit)
if err != nil {
return
}
return
}
// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil
func checkNamedErr(fn func() error, err *error) {
if e := fn(); e != nil && *err == nil {
*err = e
}
}
func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) {
// Get current time for use in calculating event age
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
// Iterate through each row and store event contents
// If txn_id changes dramatically, we've switched from collecting old events to
// new ones. Send back those events first.
lastTxnID := invalidTxnID
for eventsProcessed := 0; eventRows.Next(); {
var event gomatrixserverlib.HeaderedEvent
var eventJSON []byte
var id int
err = eventRows.Scan(
&id,
&eventJSON,
&txnID,
)
if err != nil {
return nil, 0, 0, false, err
}
// Unmarshal eventJSON
if err = json.Unmarshal(eventJSON, &event); err != nil {
return nil, 0, 0, false, err
}
// If txnID has changed on this event from the previous event, then we've
// reached the end of a transaction's events. Return only those events.
if lastTxnID > invalidTxnID && lastTxnID != txnID {
return events, maxID, lastTxnID, true, nil
}
lastTxnID = txnID
// Limit events that aren't part of an old transaction
if txnID == -1 {
// Return if we've hit the limit
if eventsProcessed++; eventsProcessed > limit {
return events, maxID, lastTxnID, true, nil
}
}
if id > maxID {
maxID = id
}
// Portion of the event that is unsigned due to rapid change
// TODO: Consider removing age as not many app services use it
if err = event.SetUnsignedField("age", nowMilli-int64(event.OriginServerTS())); err != nil {
return nil, 0, 0, false, err
}
events = append(events, event)
}
return
}
// countEventsByApplicationServiceID inserts an event mapped to its corresponding application service
// IDs into the db.
func (s *eventsStatements) countEventsByApplicationServiceID(
ctx context.Context,
appServiceID string,
) (int, error) {
var count int
err := s.countEventsByApplicationServiceIDStmt.QueryRowContext(ctx, appServiceID).Scan(&count)
if err != nil && err != sql.ErrNoRows {
return 0, err
}
return count, nil
}
// insertEvent inserts an event mapped to its corresponding application service
// IDs into the db.
func (s *eventsStatements) insertEvent(
ctx context.Context,
appServiceID string,
event *gomatrixserverlib.HeaderedEvent,
) (err error) {
// Convert event to JSON before inserting
eventJSON, err := json.Marshal(event)
if err != nil {
return err
}
_, err = s.insertEventStmt.ExecContext(
ctx,
appServiceID,
eventJSON,
-1, // No transaction ID yet
)
return
}
// updateTxnIDForEvents sets the transactionID for a collection of events. Done
// before sending them to an AppService. Referenced before sending to make sure
// we aren't constructing multiple transactions with the same events.
func (s *eventsStatements) updateTxnIDForEvents(
ctx context.Context,
appserviceID string,
maxID, txnID int,
) (err error) {
_, err = s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID)
return
}
// deleteEventsBeforeAndIncludingID removes events matching given IDs from the database.
func (s *eventsStatements) deleteEventsBeforeAndIncludingID(
ctx context.Context,
appserviceID string,
eventTableID int,
) (err error) {
_, err = s.deleteEventsBeforeAndIncludingIDStmt.ExecContext(ctx, appserviceID, eventTableID)
return
}

View file

@ -1,115 +0,0 @@
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
// Import postgres database driver
_ "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
// Database stores events intended to be later sent to application services
type Database struct {
events eventsStatements
txnID txnStatements
db *sql.DB
writer sqlutil.Writer
}
// NewDatabase opens a new database
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*Database, error) {
var result Database
var err error
if result.db, result.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil {
return nil, err
}
if err = result.prepare(); err != nil {
return nil, err
}
return &result, nil
}
func (d *Database) prepare() error {
if err := d.events.prepare(d.db); err != nil {
return err
}
return d.txnID.prepare(d.db)
}
// StoreEvent takes in a gomatrixserverlib.HeaderedEvent and stores it in the database
// for a transaction worker to pull and later send to an application service.
func (d *Database) StoreEvent(
ctx context.Context,
appServiceID string,
event *gomatrixserverlib.HeaderedEvent,
) error {
return d.events.insertEvent(ctx, appServiceID, event)
}
// GetEventsWithAppServiceID returns a slice of events and their IDs intended to
// be sent to an application service given its ID.
func (d *Database) GetEventsWithAppServiceID(
ctx context.Context,
appServiceID string,
limit int,
) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) {
return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit)
}
// CountEventsWithAppServiceID returns the number of events destined for an
// application service given its ID.
func (d *Database) CountEventsWithAppServiceID(
ctx context.Context,
appServiceID string,
) (int, error) {
return d.events.countEventsByApplicationServiceID(ctx, appServiceID)
}
// UpdateTxnIDForEvents takes in an application service ID and a
// and stores them in the DB, unless the pair already exists, in
// which case it updates them.
func (d *Database) UpdateTxnIDForEvents(
ctx context.Context,
appserviceID string,
maxID, txnID int,
) error {
return d.events.updateTxnIDForEvents(ctx, appserviceID, maxID, txnID)
}
// RemoveEventsBeforeAndIncludingID removes all events from the database that
// are less than or equal to a given maximum ID. IDs here are implemented as a
// serial, thus this should always delete events in chronological order.
func (d *Database) RemoveEventsBeforeAndIncludingID(
ctx context.Context,
appserviceID string,
eventTableID int,
) error {
return d.events.deleteEventsBeforeAndIncludingID(ctx, appserviceID, eventTableID)
}
// GetLatestTxnID returns the latest available transaction id
func (d *Database) GetLatestTxnID(
ctx context.Context,
) (int, error) {
return d.txnID.selectTxnID(ctx)
}

View file

@ -1,53 +0,0 @@
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
)
const txnIDSchema = `
-- Keeps a count of the current transaction ID
CREATE SEQUENCE IF NOT EXISTS txn_id_counter START 1;
`
const selectTxnIDSQL = "SELECT nextval('txn_id_counter')"
type txnStatements struct {
selectTxnIDStmt *sql.Stmt
}
func (s *txnStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(txnIDSchema)
if err != nil {
return
}
if s.selectTxnIDStmt, err = db.Prepare(selectTxnIDSQL); err != nil {
return
}
return
}
// selectTxnID selects the latest ascending transaction ID
func (s *txnStatements) selectTxnID(
ctx context.Context,
) (txnID int, err error) {
err = s.selectTxnIDStmt.QueryRowContext(ctx).Scan(&txnID)
return
}

View file

@ -1,267 +0,0 @@
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
const appserviceEventsSchema = `
-- Stores events to be sent to application services
CREATE TABLE IF NOT EXISTS appservice_events (
-- An auto-incrementing id unique to each event in the table
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- The ID of the application service the event will be sent to
as_id TEXT NOT NULL,
-- JSON representation of the event
headered_event_json TEXT NOT NULL,
-- The ID of the transaction that this event is a part of
txn_id INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS appservice_events_as_id ON appservice_events(as_id);
`
const selectEventsByApplicationServiceIDSQL = "" +
"SELECT id, headered_event_json, txn_id " +
"FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC"
const countEventsByApplicationServiceIDSQL = "" +
"SELECT COUNT(id) FROM appservice_events WHERE as_id = $1"
const insertEventSQL = "" +
"INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " +
"VALUES ($1, $2, $3)"
const updateTxnIDForEventsSQL = "" +
"UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3"
const deleteEventsBeforeAndIncludingIDSQL = "" +
"DELETE FROM appservice_events WHERE as_id = $1 AND id <= $2"
const (
// A transaction ID number that no transaction should ever have. Used for
// checking again the default value.
invalidTxnID = -2
)
type eventsStatements struct {
db *sql.DB
writer sqlutil.Writer
selectEventsByApplicationServiceIDStmt *sql.Stmt
countEventsByApplicationServiceIDStmt *sql.Stmt
insertEventStmt *sql.Stmt
updateTxnIDForEventsStmt *sql.Stmt
deleteEventsBeforeAndIncludingIDStmt *sql.Stmt
}
func (s *eventsStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db
s.writer = writer
_, err = db.Exec(appserviceEventsSchema)
if err != nil {
return
}
if s.selectEventsByApplicationServiceIDStmt, err = db.Prepare(selectEventsByApplicationServiceIDSQL); err != nil {
return
}
if s.countEventsByApplicationServiceIDStmt, err = db.Prepare(countEventsByApplicationServiceIDSQL); err != nil {
return
}
if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil {
return
}
if s.updateTxnIDForEventsStmt, err = db.Prepare(updateTxnIDForEventsSQL); err != nil {
return
}
if s.deleteEventsBeforeAndIncludingIDStmt, err = db.Prepare(deleteEventsBeforeAndIncludingIDSQL); err != nil {
return
}
return
}
// selectEventsByApplicationServiceID takes in an application service ID and
// returns a slice of events that need to be sent to that application service,
// as well as an int later used to remove these same events from the database
// once successfully sent to an application service.
func (s *eventsStatements) selectEventsByApplicationServiceID(
ctx context.Context,
applicationServiceID string,
limit int,
) (
txnID, maxID int,
events []gomatrixserverlib.HeaderedEvent,
eventsRemaining bool,
err error,
) {
defer func() {
if err != nil {
log.WithFields(log.Fields{
"appservice": applicationServiceID,
}).WithError(err).Fatalf("appservice unable to select new events to send")
}
}()
// Retrieve events from the database. Unsuccessfully sent events first
eventRows, err := s.selectEventsByApplicationServiceIDStmt.QueryContext(ctx, applicationServiceID)
if err != nil {
return
}
defer checkNamedErr(eventRows.Close, &err)
events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit)
if err != nil {
return
}
return
}
// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil
func checkNamedErr(fn func() error, err *error) {
if e := fn(); e != nil && *err == nil {
*err = e
}
}
func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) {
// Get current time for use in calculating event age
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
// Iterate through each row and store event contents
// If txn_id changes dramatically, we've switched from collecting old events to
// new ones. Send back those events first.
lastTxnID := invalidTxnID
for eventsProcessed := 0; eventRows.Next(); {
var event gomatrixserverlib.HeaderedEvent
var eventJSON []byte
var id int
err = eventRows.Scan(
&id,
&eventJSON,
&txnID,
)
if err != nil {
return nil, 0, 0, false, err
}
// Unmarshal eventJSON
if err = json.Unmarshal(eventJSON, &event); err != nil {
return nil, 0, 0, false, err
}
// If txnID has changed on this event from the previous event, then we've
// reached the end of a transaction's events. Return only those events.
if lastTxnID > invalidTxnID && lastTxnID != txnID {
return events, maxID, lastTxnID, true, nil
}
lastTxnID = txnID
// Limit events that aren't part of an old transaction
if txnID == -1 {
// Return if we've hit the limit
if eventsProcessed++; eventsProcessed > limit {
return events, maxID, lastTxnID, true, nil
}
}
if id > maxID {
maxID = id
}
// Portion of the event that is unsigned due to rapid change
// TODO: Consider removing age as not many app services use it
if err = event.SetUnsignedField("age", nowMilli-int64(event.OriginServerTS())); err != nil {
return nil, 0, 0, false, err
}
events = append(events, event)
}
return
}
// countEventsByApplicationServiceID inserts an event mapped to its corresponding application service
// IDs into the db.
func (s *eventsStatements) countEventsByApplicationServiceID(
ctx context.Context,
appServiceID string,
) (int, error) {
var count int
err := s.countEventsByApplicationServiceIDStmt.QueryRowContext(ctx, appServiceID).Scan(&count)
if err != nil && err != sql.ErrNoRows {
return 0, err
}
return count, nil
}
// insertEvent inserts an event mapped to its corresponding application service
// IDs into the db.
func (s *eventsStatements) insertEvent(
ctx context.Context,
appServiceID string,
event *gomatrixserverlib.HeaderedEvent,
) (err error) {
// Convert event to JSON before inserting
eventJSON, err := json.Marshal(event)
if err != nil {
return err
}
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
_, err := s.insertEventStmt.ExecContext(
ctx,
appServiceID,
eventJSON,
-1, // No transaction ID yet
)
return err
})
}
// updateTxnIDForEvents sets the transactionID for a collection of events. Done
// before sending them to an AppService. Referenced before sending to make sure
// we aren't constructing multiple transactions with the same events.
func (s *eventsStatements) updateTxnIDForEvents(
ctx context.Context,
appserviceID string,
maxID, txnID int,
) (err error) {
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
_, err := s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID)
return err
})
}
// deleteEventsBeforeAndIncludingID removes events matching given IDs from the database.
func (s *eventsStatements) deleteEventsBeforeAndIncludingID(
ctx context.Context,
appserviceID string,
eventTableID int,
) (err error) {
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
_, err := s.deleteEventsBeforeAndIncludingIDStmt.ExecContext(ctx, appserviceID, eventTableID)
return err
})
}

View file

@ -1,114 +0,0 @@
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
// Import SQLite database driver
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
// Database stores events intended to be later sent to application services
type Database struct {
events eventsStatements
txnID txnStatements
db *sql.DB
writer sqlutil.Writer
}
// NewDatabase opens a new database
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*Database, error) {
var result Database
var err error
if result.db, result.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil {
return nil, err
}
if err = result.prepare(); err != nil {
return nil, err
}
return &result, nil
}
func (d *Database) prepare() error {
if err := d.events.prepare(d.db, d.writer); err != nil {
return err
}
return d.txnID.prepare(d.db, d.writer)
}
// StoreEvent takes in a gomatrixserverlib.HeaderedEvent and stores it in the database
// for a transaction worker to pull and later send to an application service.
func (d *Database) StoreEvent(
ctx context.Context,
appServiceID string,
event *gomatrixserverlib.HeaderedEvent,
) error {
return d.events.insertEvent(ctx, appServiceID, event)
}
// GetEventsWithAppServiceID returns a slice of events and their IDs intended to
// be sent to an application service given its ID.
func (d *Database) GetEventsWithAppServiceID(
ctx context.Context,
appServiceID string,
limit int,
) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) {
return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit)
}
// CountEventsWithAppServiceID returns the number of events destined for an
// application service given its ID.
func (d *Database) CountEventsWithAppServiceID(
ctx context.Context,
appServiceID string,
) (int, error) {
return d.events.countEventsByApplicationServiceID(ctx, appServiceID)
}
// UpdateTxnIDForEvents takes in an application service ID and a
// and stores them in the DB, unless the pair already exists, in
// which case it updates them.
func (d *Database) UpdateTxnIDForEvents(
ctx context.Context,
appserviceID string,
maxID, txnID int,
) error {
return d.events.updateTxnIDForEvents(ctx, appserviceID, maxID, txnID)
}
// RemoveEventsBeforeAndIncludingID removes all events from the database that
// are less than or equal to a given maximum ID. IDs here are implemented as a
// serial, thus this should always delete events in chronological order.
func (d *Database) RemoveEventsBeforeAndIncludingID(
ctx context.Context,
appserviceID string,
eventTableID int,
) error {
return d.events.deleteEventsBeforeAndIncludingID(ctx, appserviceID, eventTableID)
}
// GetLatestTxnID returns the latest available transaction id
func (d *Database) GetLatestTxnID(
ctx context.Context,
) (int, error) {
return d.txnID.selectTxnID(ctx)
}

View file

@ -1,82 +0,0 @@
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const txnIDSchema = `
-- Keeps a count of the current transaction ID
CREATE TABLE IF NOT EXISTS appservice_counters (
name TEXT PRIMARY KEY NOT NULL,
last_id INTEGER DEFAULT 1
);
INSERT OR IGNORE INTO appservice_counters (name, last_id) VALUES('txn_id', 1);
`
const selectTxnIDSQL = `
SELECT last_id FROM appservice_counters WHERE name='txn_id'
`
const updateTxnIDSQL = `
UPDATE appservice_counters SET last_id=last_id+1 WHERE name='txn_id'
`
type txnStatements struct {
db *sql.DB
writer sqlutil.Writer
selectTxnIDStmt *sql.Stmt
updateTxnIDStmt *sql.Stmt
}
func (s *txnStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db
s.writer = writer
_, err = db.Exec(txnIDSchema)
if err != nil {
return
}
if s.selectTxnIDStmt, err = db.Prepare(selectTxnIDSQL); err != nil {
return
}
if s.updateTxnIDStmt, err = db.Prepare(updateTxnIDSQL); err != nil {
return
}
return
}
// selectTxnID selects the latest ascending transaction ID
func (s *txnStatements) selectTxnID(
ctx context.Context,
) (txnID int, err error) {
err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
err := s.selectTxnIDStmt.QueryRowContext(ctx).Scan(&txnID)
if err != nil {
return err
}
_, err = s.updateTxnIDStmt.ExecContext(ctx)
return err
})
return
}

View file

@ -1,40 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !wasm
// +build !wasm
package storage
import (
"fmt"
"github.com/matrix-org/dendrite/appservice/storage/postgres"
"github.com/matrix-org/dendrite/appservice/storage/sqlite3"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
)
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets DB connection parameters
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(base, dbProperties)
case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(base, dbProperties)
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View file

@ -1,34 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"fmt"
"github.com/matrix-org/dendrite/appservice/storage/sqlite3"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
)
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(base, dbProperties)
case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation")
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View file

@ -1,64 +0,0 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package types
import (
"sync"
"github.com/matrix-org/dendrite/setup/config"
)
const (
// AppServiceDeviceID is the AS dummy device ID
AppServiceDeviceID = "AS_Device"
)
// ApplicationServiceWorkerState is a type that couples an application service,
// a lockable condition as well as some other state variables, allowing the
// roomserver to notify appservice workers when there are events ready to send
// externally to application services.
type ApplicationServiceWorkerState struct {
AppService config.ApplicationService
Cond *sync.Cond
// Events ready to be sent
EventsReady bool
// Backoff exponent (2^x secs). Max 6, aka 64s.
Backoff int
}
// NotifyNewEvents wakes up all waiting goroutines, notifying that events remain
// in the event queue for this application service worker.
func (a *ApplicationServiceWorkerState) NotifyNewEvents() {
a.Cond.L.Lock()
a.EventsReady = true
a.Cond.Broadcast()
a.Cond.L.Unlock()
}
// FinishEventProcessing marks all events of this worker as being sent to the
// application service.
func (a *ApplicationServiceWorkerState) FinishEventProcessing() {
a.Cond.L.Lock()
a.EventsReady = false
a.Cond.L.Unlock()
}
// WaitForNewEvents causes the calling goroutine to wait on the worker state's
// condition for a broadcast or similar wakeup, if there are no events ready.
func (a *ApplicationServiceWorkerState) WaitForNewEvents() {
a.Cond.L.Lock()
if !a.EventsReady {
a.Cond.Wait()
}
a.Cond.L.Unlock()
}

View file

@ -1,236 +0,0 @@
// Copyright 2018 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package workers
import (
"bytes"
"context"
"encoding/json"
"fmt"
"math"
"net/http"
"net/url"
"time"
"github.com/matrix-org/dendrite/appservice/storage"
"github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
var (
// Maximum size of events sent in each transaction.
transactionBatchSize = 50
)
// SetupTransactionWorkers spawns a separate goroutine for each application
// service. Each of these "workers" handle taking all events intended for their
// app service, batch them up into a single transaction (up to a max transaction
// size), then send that off to the AS's /transactions/{txnID} endpoint. It also
// handles exponentially backing off in case the AS isn't currently available.
func SetupTransactionWorkers(
client *http.Client,
appserviceDB storage.Database,
workerStates []types.ApplicationServiceWorkerState,
) error {
// Create a worker that handles transmitting events to a single homeserver
for _, workerState := range workerStates {
// Don't create a worker if this AS doesn't want to receive events
if workerState.AppService.URL != "" {
go worker(client, appserviceDB, workerState)
}
}
return nil
}
// worker is a goroutine that sends any queued events to the application service
// it is given.
func worker(client *http.Client, db storage.Database, ws types.ApplicationServiceWorkerState) {
log.WithFields(log.Fields{
"appservice": ws.AppService.ID,
}).Info("Starting application service")
ctx := context.Background()
// Initial check for any leftover events to send from last time
eventCount, err := db.CountEventsWithAppServiceID(ctx, ws.AppService.ID)
if err != nil {
log.WithFields(log.Fields{
"appservice": ws.AppService.ID,
}).WithError(err).Fatal("appservice worker unable to read queued events from DB")
return
}
if eventCount > 0 {
ws.NotifyNewEvents()
}
// Loop forever and keep waiting for more events to send
for {
// Wait for more events if we've sent all the events in the database
ws.WaitForNewEvents()
// Batch events up into a transaction
transactionJSON, txnID, maxEventID, eventsRemaining, err := createTransaction(ctx, db, ws.AppService.ID)
if err != nil {
log.WithFields(log.Fields{
"appservice": ws.AppService.ID,
}).WithError(err).Fatal("appservice worker unable to create transaction")
return
}
// Send the events off to the application service
// Backoff if the application service does not respond
err = send(client, ws.AppService, txnID, transactionJSON)
if err != nil {
log.WithFields(log.Fields{
"appservice": ws.AppService.ID,
}).WithError(err).Error("unable to send event")
// Backoff
backoff(&ws, err)
continue
}
// We sent successfully, hooray!
ws.Backoff = 0
// Transactions have a maximum event size, so there may still be some events
// left over to send. Keep sending until none are left
if !eventsRemaining {
ws.FinishEventProcessing()
}
// Remove sent events from the DB
err = db.RemoveEventsBeforeAndIncludingID(ctx, ws.AppService.ID, maxEventID)
if err != nil {
log.WithFields(log.Fields{
"appservice": ws.AppService.ID,
}).WithError(err).Fatal("unable to remove appservice events from the database")
return
}
}
}
// backoff pauses the calling goroutine for a 2^some backoff exponent seconds
func backoff(ws *types.ApplicationServiceWorkerState, err error) {
// Calculate how long to backoff for
backoffDuration := time.Duration(math.Pow(2, float64(ws.Backoff)))
backoffSeconds := time.Second * backoffDuration
log.WithFields(log.Fields{
"appservice": ws.AppService.ID,
}).WithError(err).Warnf("unable to send transactions successfully, backing off for %ds",
backoffDuration)
ws.Backoff++
if ws.Backoff > 6 {
ws.Backoff = 6
}
// Backoff
time.Sleep(backoffSeconds)
}
// createTransaction takes in a slice of AS events, stores them in an AS
// transaction, and JSON-encodes the results.
func createTransaction(
ctx context.Context,
db storage.Database,
appserviceID string,
) (
transactionJSON []byte,
txnID, maxID int,
eventsRemaining bool,
err error,
) {
// Retrieve the latest events from the DB (will return old events if they weren't successfully sent)
txnID, maxID, events, eventsRemaining, err := db.GetEventsWithAppServiceID(ctx, appserviceID, transactionBatchSize)
if err != nil {
log.WithFields(log.Fields{
"appservice": appserviceID,
}).WithError(err).Fatalf("appservice worker unable to read queued events from DB")
return
}
// Check if these events do not already have a transaction ID
if txnID == -1 {
// If not, grab next available ID from the DB
txnID, err = db.GetLatestTxnID(ctx)
if err != nil {
return nil, 0, 0, false, err
}
// Mark new events with current transactionID
if err = db.UpdateTxnIDForEvents(ctx, appserviceID, maxID, txnID); err != nil {
return nil, 0, 0, false, err
}
}
var ev []*gomatrixserverlib.HeaderedEvent
for i := range events {
ev = append(ev, &events[i])
}
// Create a transaction and store the events inside
transaction := gomatrixserverlib.ApplicationServiceTransaction{
Events: gomatrixserverlib.HeaderedToClientEvents(ev, gomatrixserverlib.FormatAll),
}
transactionJSON, err = json.Marshal(transaction)
if err != nil {
return
}
return
}
// send sends events to an application service. Returns an error if an OK was not
// received back from the application service or the request timed out.
func send(
client *http.Client,
appservice config.ApplicationService,
txnID int,
transaction []byte,
) (err error) {
// PUT a transaction to our AS
// https://matrix.org/docs/spec/application_service/r0.1.2#put-matrix-app-v1-transactions-txnid
address := fmt.Sprintf("%s/transactions/%d?access_token=%s", appservice.URL, txnID, url.QueryEscape(appservice.HSToken))
req, err := http.NewRequest("PUT", address, bytes.NewBuffer(transaction))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return err
}
defer checkNamedErr(resp.Body.Close, &err)
// Check the AS received the events correctly
if resp.StatusCode != http.StatusOK {
// TODO: Handle non-200 error codes from application services
return fmt.Errorf("non-OK status code %d returned from AS", resp.StatusCode)
}
return nil
}
// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil
func checkNamedErr(fn func() error, err *error) {
if e := fn(); e != nil && *err == nil {
*err = e
}
}

View file

@ -255,7 +255,6 @@ func (m *DendriteMonolith) Start() {
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix))
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix)) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix))
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-federationsender.db", m.StorageDirectory, prefix)) cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-federationsender.db", m.StorageDirectory, prefix))
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-appservice.db", m.StorageDirectory, prefix))
cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory)) cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory))
cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory)) cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory))
cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} cfg.MSCs.MSCs = []string{"msc2836", "msc2946"}

View file

@ -94,7 +94,6 @@ func (m *DendriteMonolith) Start() {
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-roomserver.db", m.StorageDirectory)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-roomserver.db", m.StorageDirectory))
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-keyserver.db", m.StorageDirectory)) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-keyserver.db", m.StorageDirectory))
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-federationsender.db", m.StorageDirectory)) cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-federationsender.db", m.StorageDirectory))
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-appservice.db", m.StorageDirectory))
cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory)) cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory))
cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory)) cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory))
cfg.ClientAPI.RegistrationDisabled = false cfg.ClientAPI.RegistrationDisabled = false

View file

@ -1,3 +1,5 @@
#syntax=docker/dockerfile:1.2
FROM golang:1.18-stretch as build FROM golang:1.18-stretch as build
RUN apt-get update && apt-get install -y sqlite3 RUN apt-get update && apt-get install -y sqlite3
WORKDIR /build WORKDIR /build
@ -8,14 +10,12 @@ RUN mkdir /dendrite
# Utilise Docker caching when downloading dependencies, this stops us needlessly # Utilise Docker caching when downloading dependencies, this stops us needlessly
# downloading dependencies every time. # downloading dependencies every time.
COPY go.mod . RUN --mount=target=. \
COPY go.sum . --mount=type=cache,target=/go/pkg/mod \
RUN go mod download --mount=type=cache,target=/root/.cache/go-build \
go build -o /dendrite ./cmd/generate-config && \
COPY . . go build -o /dendrite ./cmd/generate-keys && \
RUN go build -o /dendrite ./cmd/dendrite-monolith-server go build -o /dendrite ./cmd/dendrite-monolith-server
RUN go build -o /dendrite ./cmd/generate-keys
RUN go build -o /dendrite ./cmd/generate-config
WORKDIR /dendrite WORKDIR /dendrite
RUN ./generate-keys --private-key matrix_key.pem RUN ./generate-keys --private-key matrix_key.pem
@ -26,7 +26,7 @@ EXPOSE 8008 8448
# At runtime, generate TLS cert based on the CA now mounted at /ca # At runtime, generate TLS cert based on the CA now mounted at /ca
# At runtime, replace the SERVER_NAME with what we are told # At runtime, replace the SERVER_NAME with what we are told
CMD ./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \ CMD ./generate-keys -keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \
./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \ ./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \
cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \
./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0}

View file

@ -1,3 +1,5 @@
#syntax=docker/dockerfile:1.2
# A local development Complement dockerfile, to be used with host mounts # A local development Complement dockerfile, to be used with host mounts
# /cache -> Contains the entire dendrite code at Dockerfile build time. Builds binaries but only keeps the generate-* ones. Pre-compilation saves time. # /cache -> Contains the entire dendrite code at Dockerfile build time. Builds binaries but only keeps the generate-* ones. Pre-compilation saves time.
# /dendrite -> Host-mounted sources # /dendrite -> Host-mounted sources
@ -9,11 +11,10 @@
FROM golang:1.18-stretch FROM golang:1.18-stretch
RUN apt-get update && apt-get install -y sqlite3 RUN apt-get update && apt-get install -y sqlite3
WORKDIR /runtime
ENV SERVER_NAME=localhost ENV SERVER_NAME=localhost
EXPOSE 8008 8448 EXPOSE 8008 8448
WORKDIR /runtime
# This script compiles Dendrite for us. # This script compiles Dendrite for us.
RUN echo '\ RUN echo '\
#!/bin/bash -eux \n\ #!/bin/bash -eux \n\
@ -29,25 +30,23 @@ RUN echo '\
RUN echo '\ RUN echo '\
#!/bin/bash -eu \n\ #!/bin/bash -eu \n\
./generate-keys --private-key matrix_key.pem \n\ ./generate-keys --private-key matrix_key.pem \n\
./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key \n\ ./generate-keys -keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key \n\
./generate-config -server $SERVER_NAME --ci > dendrite.yaml \n\ ./generate-config -server $SERVER_NAME --ci > dendrite.yaml \n\
cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates \n\ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates \n\
./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\
' > run.sh && chmod +x run.sh ' > run.sh && chmod +x run.sh
WORKDIR /cache WORKDIR /cache
# Pre-download deps; we don't need to do this if the GOPATH is mounted.
COPY go.mod .
COPY go.sum .
RUN go mod download
# Build the monolith in /cache - we won't actually use this but will rely on build artifacts to speed # Build the monolith in /cache - we won't actually use this but will rely on build artifacts to speed
# up the real compilation. Build the generate-* binaries in the true /runtime locations. # up the real compilation. Build the generate-* binaries in the true /runtime locations.
# If the generate-* source is changed, this dockerfile needs re-running. # If the generate-* source is changed, this dockerfile needs re-running.
COPY . . RUN --mount=target=. \
RUN go build ./cmd/dendrite-monolith-server && go build -o /runtime ./cmd/generate-keys && go build -o /runtime ./cmd/generate-config --mount=type=cache,target=/go/pkg/mod \
--mount=type=cache,target=/root/.cache/go-build \
go build -o /runtime ./cmd/generate-config && \
go build -o /runtime ./cmd/generate-keys
WORKDIR /runtime WORKDIR /runtime
CMD /runtime/compile.sh && /runtime/run.sh CMD /runtime/compile.sh && exec /runtime/run.sh

View file

@ -1,3 +1,5 @@
#syntax=docker/dockerfile:1.2
FROM golang:1.18-stretch as build FROM golang:1.18-stretch as build
RUN apt-get update && apt-get install -y postgresql RUN apt-get update && apt-get install -y postgresql
WORKDIR /build WORKDIR /build
@ -26,14 +28,12 @@ RUN mkdir /dendrite
# Utilise Docker caching when downloading dependencies, this stops us needlessly # Utilise Docker caching when downloading dependencies, this stops us needlessly
# downloading dependencies every time. # downloading dependencies every time.
COPY go.mod . RUN --mount=target=. \
COPY go.sum . --mount=type=cache,target=/go/pkg/mod \
RUN go mod download --mount=type=cache,target=/root/.cache/go-build \
go build -o /dendrite ./cmd/generate-config && \
COPY . . go build -o /dendrite ./cmd/generate-keys && \
RUN go build -o /dendrite ./cmd/dendrite-monolith-server go build -o /dendrite ./cmd/dendrite-monolith-server
RUN go build -o /dendrite ./cmd/generate-keys
RUN go build -o /dendrite ./cmd/generate-config
WORKDIR /dendrite WORKDIR /dendrite
RUN ./generate-keys --private-key matrix_key.pem RUN ./generate-keys --private-key matrix_key.pem
@ -45,10 +45,10 @@ EXPOSE 8008 8448
# At runtime, generate TLS cert based on the CA now mounted at /ca # At runtime, generate TLS cert based on the CA now mounted at /ca
# At runtime, replace the SERVER_NAME with what we are told # At runtime, replace the SERVER_NAME with what we are told
CMD /build/run_postgres.sh && ./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \ CMD /build/run_postgres.sh && ./generate-keys --keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \
./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \ ./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \
# Replace the connection string with a single postgres DB, using user/db = 'postgres' and no password, bump max_conns # Replace the connection string with a single postgres DB, using user/db = 'postgres' and no password, bump max_conns
sed -i "s%connection_string:.*$%connection_string: postgresql://postgres@localhost/postgres?sslmode=disable%g" dendrite.yaml && \ sed -i "s%connection_string:.*$%connection_string: postgresql://postgres@localhost/postgres?sslmode=disable%g" dendrite.yaml && \
sed -i 's/max_open_conns:.*$/max_open_conns: 100/g' dendrite.yaml && \ sed -i 's/max_open_conns:.*$/max_open_conns: 100/g' dendrite.yaml && \
cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \
./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0}

View file

@ -276,19 +276,19 @@ type recaptchaResponse struct {
} }
// validateUsername returns an error response if the username is invalid // validateUsername returns an error response if the username is invalid
func validateUsername(username string) *util.JSONResponse { func validateUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse {
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
if len(username) > maxUsernameLength { if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
return &util.JSONResponse{ return &util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(fmt.Sprintf("'username' >%d characters", maxUsernameLength)), JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)),
} }
} else if !validUsernameRegex.MatchString(username) { } else if !validUsernameRegex.MatchString(localpart) {
return &util.JSONResponse{ return &util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"), JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
} }
} else if username[0] == '_' { // Regex checks its not a zero length string } else if localpart[0] == '_' { // Regex checks its not a zero length string
return &util.JSONResponse{ return &util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username cannot start with a '_'"), JSON: jsonerror.InvalidUsername("Username cannot start with a '_'"),
@ -298,13 +298,13 @@ func validateUsername(username string) *util.JSONResponse {
} }
// validateApplicationServiceUsername returns an error response if the username is invalid for an application service // validateApplicationServiceUsername returns an error response if the username is invalid for an application service
func validateApplicationServiceUsername(username string) *util.JSONResponse { func validateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse {
if len(username) > maxUsernameLength { if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
return &util.JSONResponse{ return &util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(fmt.Sprintf("'username' >%d characters", maxUsernameLength)), JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)),
} }
} else if !validUsernameRegex.MatchString(username) { } else if !validUsernameRegex.MatchString(localpart) {
return &util.JSONResponse{ return &util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"), JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
@ -523,7 +523,7 @@ func validateApplicationService(
} }
// Check username application service is trying to register is valid // Check username application service is trying to register is valid
if err := validateApplicationServiceUsername(username); err != nil { if err := validateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil {
return "", err return "", err
} }
@ -604,7 +604,7 @@ func Register(
case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil: case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil:
// Spec-compliant case (the access_token is specified and the login type // Spec-compliant case (the access_token is specified and the login type
// is correctly set, so it's an appservice registration) // is correctly set, so it's an appservice registration)
if resErr := validateApplicationServiceUsername(r.Username); resErr != nil { if resErr := validateApplicationServiceUsername(r.Username, cfg.Matrix.ServerName); resErr != nil {
return *resErr return *resErr
} }
case accessTokenErr == nil: case accessTokenErr == nil:
@ -617,7 +617,7 @@ func Register(
default: default:
// Spec-compliant case (neither the access_token nor the login type are // Spec-compliant case (neither the access_token nor the login type are
// specified, so it's a normal user registration) // specified, so it's a normal user registration)
if resErr := validateUsername(r.Username); resErr != nil { if resErr := validateUsername(r.Username, cfg.Matrix.ServerName); resErr != nil {
return *resErr return *resErr
} }
} }
@ -1018,7 +1018,7 @@ func RegisterAvailable(
// Squash username to all lowercase letters // Squash username to all lowercase letters
username = strings.ToLower(username) username = strings.ToLower(username)
if err := validateUsername(username); err != nil { if err := validateUsername(username, cfg.Matrix.ServerName); err != nil {
return *err return *err
} }
@ -1059,7 +1059,7 @@ func RegisterAvailable(
} }
} }
func handleSharedSecretRegistration(userAPI userapi.ClientUserAPI, sr *SharedSecretRegistration, req *http.Request) util.JSONResponse { func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.ClientUserAPI, sr *SharedSecretRegistration, req *http.Request) util.JSONResponse {
ssrr, err := NewSharedSecretRegistrationRequest(req.Body) ssrr, err := NewSharedSecretRegistrationRequest(req.Body)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -1080,7 +1080,7 @@ func handleSharedSecretRegistration(userAPI userapi.ClientUserAPI, sr *SharedSec
// downcase capitals // downcase capitals
ssrr.User = strings.ToLower(ssrr.User) ssrr.User = strings.ToLower(ssrr.User)
if resErr := validateUsername(ssrr.User); resErr != nil { if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil {
return *resErr return *resErr
} }
if resErr := validatePassword(ssrr.Password); resErr != nil { if resErr := validatePassword(ssrr.Password); resErr != nil {

View file

@ -133,7 +133,7 @@ func Setup(
} }
} }
if req.Method == http.MethodPost { if req.Method == http.MethodPost {
return handleSharedSecretRegistration(userAPI, sr, req) return handleSharedSecretRegistration(cfg, userAPI, sr, req)
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusMethodNotAllowed, Code: http.StatusMethodNotAllowed,

View file

@ -66,10 +66,11 @@ var (
resetPassword = flag.Bool("reset-password", false, "Deprecated") resetPassword = flag.Bool("reset-password", false, "Deprecated")
serverURL = flag.String("url", "https://localhost:8448", "The URL to connect to.") serverURL = flag.String("url", "https://localhost:8448", "The URL to connect to.")
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server")
) )
var cl = http.Client{ var cl = http.Client{
Timeout: time.Second * 10, Timeout: time.Second * 30,
Transport: http.DefaultTransport, Transport: http.DefaultTransport,
} }
@ -108,6 +109,8 @@ func main() {
logrus.Fatalln(err) logrus.Fatalln(err)
} }
cl.Timeout = *timeout
accessToken, err := sharedSecretRegister(cfg.ClientAPI.RegistrationSharedSecret, *serverURL, *username, pass, *isAdmin) accessToken, err := sharedSecretRegister(cfg.ClientAPI.RegistrationSharedSecret, *serverURL, *username, pass, *isAdmin)
if err != nil { if err != nil {
logrus.Fatalln("Failed to create the account:", err.Error()) logrus.Fatalln("Failed to create the account:", err.Error())
@ -124,8 +127,8 @@ type sharedSecretRegistrationRequest struct {
Admin bool `json:"admin"` Admin bool `json:"admin"`
} }
func sharedSecretRegister(sharedSecret, serverURL, localpart, password string, admin bool) (accesToken string, err error) { func sharedSecretRegister(sharedSecret, serverURL, localpart, password string, admin bool) (accessToken string, err error) {
registerURL := fmt.Sprintf("%s/_synapse/admin/v1/register", serverURL) registerURL := fmt.Sprintf("%s/_synapse/admin/v1/register", strings.Trim(serverURL, "/"))
nonceReq, err := http.NewRequest(http.MethodGet, registerURL, nil) nonceReq, err := http.NewRequest(http.MethodGet, registerURL, nil)
if err != nil { if err != nil {
return "", fmt.Errorf("unable to create http request: %w", err) return "", fmt.Errorf("unable to create http request: %w", err)

View file

@ -24,6 +24,7 @@ import (
"net" "net"
"net/http" "net/http"
"os" "os"
"strings"
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -42,6 +43,7 @@ import (
"github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -70,23 +72,73 @@ func main() {
var pk ed25519.PublicKey var pk ed25519.PublicKey
var sk ed25519.PrivateKey var sk ed25519.PrivateKey
keyfile := *instanceName + ".key" // iterate through the cli args and check if the config flag was set
configFlagSet := false
for _, arg := range os.Args {
if arg == "--config" || arg == "-config" {
configFlagSet = true
break
}
}
cfg := &config.Dendrite{}
// use custom config if config flag is set
if configFlagSet {
cfg = setup.ParseFlags(true)
sk = cfg.Global.PrivateKey
} else {
keyfile := *instanceName + ".pem"
if _, err := os.Stat(keyfile); os.IsNotExist(err) { if _, err := os.Stat(keyfile); os.IsNotExist(err) {
if pk, sk, err = ed25519.GenerateKey(nil); err != nil { oldkeyfile := *instanceName + ".key"
panic(err) if _, err = os.Stat(oldkeyfile); os.IsNotExist(err) {
if err = test.NewMatrixKey(keyfile); err != nil {
panic("failed to generate a new PEM key: " + err.Error())
} }
if err = os.WriteFile(keyfile, sk, 0644); err != nil { if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil {
panic(err) panic("failed to load PEM key: " + err.Error())
} }
} else if err == nil { } else {
if sk, err = os.ReadFile(keyfile); err != nil { if sk, err = os.ReadFile(oldkeyfile); err != nil {
panic(err) panic("failed to read the old private key: " + err.Error())
} }
if len(sk) != ed25519.PrivateKeySize { if len(sk) != ed25519.PrivateKeySize {
panic("the private key is not long enough") panic("the private key is not long enough")
} }
pk = sk.Public().(ed25519.PublicKey) if err := test.SaveMatrixKey(keyfile, sk); err != nil {
panic("failed to convert the private key to PEM format: " + err.Error())
} }
}
} else {
var err error
if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil {
panic("failed to load PEM key: " + err.Error())
}
}
cfg.Defaults(true)
cfg.Global.PrivateKey = sk
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", *instanceName))
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", *instanceName))
cfg.MSCs.MSCs = []string{"msc2836", "msc2946"}
cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", *instanceName))
cfg.ClientAPI.RegistrationDisabled = false
cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true
if err := cfg.Derive(); err != nil {
panic(err)
}
}
pk = sk.Public().(ed25519.PublicKey)
cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk))
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
base := base.NewBaseDendrite(cfg, "Monolith")
defer base.Close() // nolint: errcheck
pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false) pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"}) pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
@ -94,7 +146,9 @@ func main() {
pManager := pineconeConnections.NewConnectionManager(pRouter, nil) pManager := pineconeConnections.NewConnectionManager(pRouter, nil)
pMulticast.Start() pMulticast.Start()
if instancePeer != nil && *instancePeer != "" { if instancePeer != nil && *instancePeer != "" {
pManager.AddPeer(*instancePeer) for _, peer := range strings.Split(*instancePeer, ",") {
pManager.AddPeer(strings.Trim(peer, " \t\r\n"))
}
} }
go func() { go func() {
@ -125,29 +179,6 @@ func main() {
} }
}() }()
cfg := &config.Dendrite{}
cfg.Defaults(true)
cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk))
cfg.Global.PrivateKey = sk
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", *instanceName))
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", *instanceName))
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName))
cfg.MSCs.MSCs = []string{"msc2836", "msc2946"}
cfg.ClientAPI.RegistrationDisabled = false
cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true
if err := cfg.Derive(); err != nil {
panic(err)
}
base := base.NewBaseDendrite(cfg, "Monolith")
defer base.Close() // nolint: errcheck
federation := conn.CreateFederationClient(base, pQUIC) federation := conn.CreateFederationClient(base, pQUIC)
serverKeyAPI := &signing.YggdrasilKeys{} serverKeyAPI := &signing.YggdrasilKeys{}

View file

@ -86,7 +86,6 @@ func main() {
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", *instanceName)) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", *instanceName))
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", *instanceName)) cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", *instanceName))
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName))
cfg.MSCs.MSCs = []string{"msc2836"} cfg.MSCs.MSCs = []string{"msc2836"}
cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", *instanceName)) cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", *instanceName))
cfg.ClientAPI.RegistrationDisabled = false cfg.ClientAPI.RegistrationDisabled = false

View file

@ -24,7 +24,6 @@ func main() {
cfg.Global.ServerName = gomatrixserverlib.ServerName(*serverName) cfg.Global.ServerName = gomatrixserverlib.ServerName(*serverName)
} }
if *dbURI != "" { if *dbURI != "" {
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(*dbURI)
cfg.FederationAPI.Database.ConnectionString = config.DataSource(*dbURI) cfg.FederationAPI.Database.ConnectionString = config.DataSource(*dbURI)
cfg.KeyServer.Database.ConnectionString = config.DataSource(*dbURI) cfg.KeyServer.Database.ConnectionString = config.DataSource(*dbURI)
cfg.MSCs.Database.ConnectionString = config.DataSource(*dbURI) cfg.MSCs.Database.ConnectionString = config.DataSource(*dbURI)

View file

@ -38,6 +38,7 @@ var (
authorityCertFile = flag.String("tls-authority-cert", "", "Optional: Create TLS certificate/keys based on this CA authority. Useful for integration testing.") authorityCertFile = flag.String("tls-authority-cert", "", "Optional: Create TLS certificate/keys based on this CA authority. Useful for integration testing.")
authorityKeyFile = flag.String("tls-authority-key", "", "Optional: Create TLS certificate/keys based on this CA authority. Useful for integration testing.") authorityKeyFile = flag.String("tls-authority-key", "", "Optional: Create TLS certificate/keys based on this CA authority. Useful for integration testing.")
serverName = flag.String("server", "", "Optional: Create TLS certificate/keys with this domain name set. Useful for integration testing.") serverName = flag.String("server", "", "Optional: Create TLS certificate/keys with this domain name set. Useful for integration testing.")
keySize = flag.Int("keysize", 4096, "Optional: Create TLS RSA private key with the given key size")
) )
func main() { func main() {
@ -58,12 +59,12 @@ func main() {
log.Fatal("Zero or both of --tls-key and --tls-cert must be supplied") log.Fatal("Zero or both of --tls-key and --tls-cert must be supplied")
} }
if *authorityCertFile == "" && *authorityKeyFile == "" { if *authorityCertFile == "" && *authorityKeyFile == "" {
if err := test.NewTLSKey(*tlsKeyFile, *tlsCertFile); err != nil { if err := test.NewTLSKey(*tlsKeyFile, *tlsCertFile, *keySize); err != nil {
panic(err) panic(err)
} }
} else { } else {
// generate the TLS cert/key based on the authority given. // generate the TLS cert/key based on the authority given.
if err := test.NewTLSKeyWithAuthority(*serverName, *tlsKeyFile, *tlsCertFile, *authorityKeyFile, *authorityCertFile); err != nil { if err := test.NewTLSKeyWithAuthority(*serverName, *tlsKeyFile, *tlsCertFile, *authorityKeyFile, *authorityCertFile, *keySize); err != nil {
panic(err) panic(err)
} }
} }

View file

@ -132,13 +132,6 @@ app_service_api:
listen: http://[::]:7777 # The listen address for incoming API requests listen: http://[::]:7777 # The listen address for incoming API requests
connect: http://app_service_api:7777 # The connect address for other components to use connect: http://app_service_api:7777 # The connect address for other components to use
# Database configuration for this component.
database:
connection_string: postgresql://username:password@hostname/dendrite_appservice?sslmode=disable
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
# Disable the validation of TLS certificates of appservices. This is # Disable the validation of TLS certificates of appservices. This is
# not recommended in production since it may allow appservice traffic # not recommended in production since it may allow appservice traffic
# to be sent to an insecure endpoint. # to be sent to an insecure endpoint.

View file

@ -67,14 +67,15 @@ func NewKeyChangeConsumer(
// Start consuming from key servers // Start consuming from key servers
func (t *KeyChangeConsumer) Start() error { func (t *KeyChangeConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, t.ctx, t.jetstream, t.topic, t.durable, 1,
nats.DeliverAll(), nats.ManualAck(), t.onMessage, nats.DeliverAll(), nats.ManualAck(),
) )
} }
// onMessage is called in response to a message received on the // onMessage is called in response to a message received on the
// key change events topic from the key server. // key change events topic from the key server.
func (t *KeyChangeConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (t *KeyChangeConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var m api.DeviceMessage var m api.DeviceMessage
if err := json.Unmarshal(msg.Data, &m); err != nil { if err := json.Unmarshal(msg.Data, &m); err != nil {
logrus.WithError(err).Errorf("failed to read device message from key change topic") logrus.WithError(err).Errorf("failed to read device message from key change topic")

View file

@ -69,14 +69,15 @@ func (t *OutputPresenceConsumer) Start() error {
return nil return nil
} }
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, t.ctx, t.jetstream, t.topic, t.durable, 1, t.onMessage,
nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(), nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(),
) )
} }
// onMessage is called in response to a message received on the presence // onMessage is called in response to a message received on the presence
// events topic from the client api. // events topic from the client api.
func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
// only send presence events which originated from us // only send presence events which originated from us
userID := msg.Header.Get(jetstream.UserID) userID := msg.Header.Get(jetstream.UserID)
_, serverName, err := gomatrixserverlib.SplitID('@', userID) _, serverName, err := gomatrixserverlib.SplitID('@', userID)

View file

@ -65,14 +65,15 @@ func NewOutputReceiptConsumer(
// Start consuming from the clientapi // Start consuming from the clientapi
func (t *OutputReceiptConsumer) Start() error { func (t *OutputReceiptConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, t.ctx, t.jetstream, t.topic, t.durable, 1, t.onMessage,
nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(), nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(),
) )
} }
// onMessage is called in response to a message received on the receipt // onMessage is called in response to a message received on the receipt
// events topic from the client api. // events topic from the client api.
func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
receipt := syncTypes.OutputReceiptEvent{ receipt := syncTypes.OutputReceiptEvent{
UserID: msg.Header.Get(jetstream.UserID), UserID: msg.Header.Get(jetstream.UserID),
RoomID: msg.Header.Get(jetstream.RoomID), RoomID: msg.Header.Get(jetstream.RoomID),

View file

@ -68,8 +68,8 @@ func NewOutputRoomEventConsumer(
// Start consuming from room servers // Start consuming from room servers
func (s *OutputRoomEventConsumer) Start() error { func (s *OutputRoomEventConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, s.ctx, s.jetstream, s.topic, s.durable, 1,
nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
) )
} }
@ -77,7 +77,8 @@ func (s *OutputRoomEventConsumer) Start() error {
// It is unsafe to call this with messages for the same room in multiple gorountines // It is unsafe to call this with messages for the same room in multiple gorountines
// because updates it will likely fail with a types.EventIDMismatchError when it // because updates it will likely fail with a types.EventIDMismatchError when it
// realises that it cannot update the room state using the deltas. // realises that it cannot update the room state using the deltas.
func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
// Parse out the event JSON // Parse out the event JSON
var output api.OutputEvent var output api.OutputEvent
if err := json.Unmarshal(msg.Data, &output); err != nil { if err := json.Unmarshal(msg.Data, &output); err != nil {

View file

@ -63,14 +63,15 @@ func NewOutputSendToDeviceConsumer(
// Start consuming from the client api // Start consuming from the client api
func (t *OutputSendToDeviceConsumer) Start() error { func (t *OutputSendToDeviceConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, t.ctx, t.jetstream, t.topic, t.durable, 1,
nats.DeliverAll(), nats.ManualAck(), t.onMessage, nats.DeliverAll(), nats.ManualAck(),
) )
} }
// onMessage is called in response to a message received on the // onMessage is called in response to a message received on the
// send-to-device events topic from the client api. // send-to-device events topic from the client api.
func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
// only send send-to-device events which originated from us // only send send-to-device events which originated from us
sender := msg.Header.Get("sender") sender := msg.Header.Get("sender")
_, originServerName, err := gomatrixserverlib.SplitID('@', sender) _, originServerName, err := gomatrixserverlib.SplitID('@', sender)

View file

@ -62,14 +62,15 @@ func NewOutputTypingConsumer(
// Start consuming from the clientapi // Start consuming from the clientapi
func (t *OutputTypingConsumer) Start() error { func (t *OutputTypingConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, t.ctx, t.jetstream, t.topic, t.durable, 1, t.onMessage,
nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(), nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(),
) )
} }
// onMessage is called in response to a message received on the typing // onMessage is called in response to a message received on the typing
// events topic from the client api. // events topic from the client api.
func (t *OutputTypingConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (t *OutputTypingConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
// Extract the typing event from msg. // Extract the typing event from msg.
roomID := msg.Header.Get(jetstream.RoomID) roomID := msg.Header.Get(jetstream.RoomID)
userID := msg.Header.Get(jetstream.UserID) userID := msg.Header.Get(jetstream.UserID)

View file

@ -329,6 +329,12 @@ func SendJoin(
JSON: jsonerror.NotFound("Room does not exist"), JSON: jsonerror.NotFound("Room does not exist"),
} }
} }
if !stateAndAuthChainResponse.StateKnown {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("State not known"),
}
}
// Check if the user is already in the room. If they're already in then // Check if the user is already in the room. If they're already in then
// there isn't much point in sending another join event into the room. // there isn't much point in sending another join event into the room.

View file

@ -135,6 +135,12 @@ func getState(
return nil, nil, &resErr return nil, nil, &resErr
} }
if !response.StateKnown {
return nil, nil, &util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("State not known"),
}
}
if response.IsRejected { if response.IsRejected {
return nil, nil, &util.JSONResponse{ return nil, nil, &util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,

View file

@ -5,10 +5,11 @@ import (
"sync" "sync"
"time" "time"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"go.uber.org/atomic" "go.uber.org/atomic"
"github.com/matrix-org/dendrite/federationapi/storage"
) )
// Statistics contains information about all of the remote federated // Statistics contains information about all of the remote federated
@ -126,13 +127,13 @@ func (s *ServerStatistics) Failure() (time.Time, bool) {
go func() { go func() {
until, ok := s.backoffUntil.Load().(time.Time) until, ok := s.backoffUntil.Load().(time.Time)
if ok { if ok && !until.IsZero() {
select { select {
case <-time.After(time.Until(until)): case <-time.After(time.Until(until)):
case <-s.interrupt: case <-s.interrupt:
} }
}
s.backoffStarted.Store(false) s.backoffStarted.Store(false)
}
}() }()
} }

View file

@ -110,6 +110,7 @@ func (d *Database) GetPendingEDUs(
return fmt.Errorf("json.Unmarshal: %w", err) return fmt.Errorf("json.Unmarshal: %w", err)
} }
edus[&Receipt{nid}] = &event edus[&Receipt{nid}] = &event
d.Cache.StoreFederationQueuedEDU(nid, &event)
} }
return nil return nil
@ -177,20 +178,18 @@ func (d *Database) GetPendingEDUServerNames(
return d.FederationQueueEDUs.SelectQueueEDUServerNames(ctx, nil) return d.FederationQueueEDUs.SelectQueueEDUServerNames(ctx, nil)
} }
// DeleteExpiredEDUs deletes expired EDUs // DeleteExpiredEDUs deletes expired EDUs and evicts them from the cache.
func (d *Database) DeleteExpiredEDUs(ctx context.Context) error { func (d *Database) DeleteExpiredEDUs(ctx context.Context) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var jsonNIDs []int64
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) (err error) {
expiredBefore := gomatrixserverlib.AsTimestamp(time.Now()) expiredBefore := gomatrixserverlib.AsTimestamp(time.Now())
jsonNIDs, err := d.FederationQueueEDUs.SelectExpiredEDUs(ctx, txn, expiredBefore) jsonNIDs, err = d.FederationQueueEDUs.SelectExpiredEDUs(ctx, txn, expiredBefore)
if err != nil { if err != nil {
return err return err
} }
if len(jsonNIDs) == 0 { if len(jsonNIDs) == 0 {
return nil return nil
} }
for i := range jsonNIDs {
d.Cache.EvictFederationQueuedEDU(jsonNIDs[i])
}
if err = d.FederationQueueJSON.DeleteQueueJSON(ctx, txn, jsonNIDs); err != nil { if err = d.FederationQueueJSON.DeleteQueueJSON(ctx, txn, jsonNIDs); err != nil {
return err return err
@ -198,4 +197,14 @@ func (d *Database) DeleteExpiredEDUs(ctx context.Context) error {
return d.FederationQueueEDUs.DeleteExpiredEDUs(ctx, txn, expiredBefore) return d.FederationQueueEDUs.DeleteExpiredEDUs(ctx, txn, expiredBefore)
}) })
if err != nil {
return err
}
for i := range jsonNIDs {
d.Cache.EvictFederationQueuedEDU(jsonNIDs[i])
}
return nil
} }

View file

@ -31,7 +31,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Dat
func TestExpireEDUs(t *testing.T) { func TestExpireEDUs(t *testing.T) {
var expireEDUTypes = map[string]time.Duration{ var expireEDUTypes = map[string]time.Duration{
gomatrixserverlib.MReceipt: time.Millisecond, gomatrixserverlib.MReceipt: 0,
} }
ctx := context.Background() ctx := context.Background()

2
go.mod
View file

@ -21,7 +21,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
github.com/matrix-org/gomatrixserverlib v0.0.0-20220815094957-74b7ff4ae09c github.com/matrix-org/gomatrixserverlib v0.0.0-20220830164018-c71e518537a2
github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9 github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.13 github.com/mattn/go-sqlite3 v1.14.13

4
go.sum
View file

@ -343,8 +343,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220815094957-74b7ff4ae09c h1:GhKmb8s9iXA9qsFD1SbiRo6Ee7cnbfcgJQ/iy43wczM= github.com/matrix-org/gomatrixserverlib v0.0.0-20220830164018-c71e518537a2 h1:esbNn9hg//tAStA6TogatAJAursw23A+yfVRQsdiv70=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220815094957-74b7ff4ae09c/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk= github.com/matrix-org/gomatrixserverlib v0.0.0-20220830164018-c71e518537a2/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk=
github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9 h1:ed8yvWhTLk7+sNeK/eOZRTvESFTOHDRevoRoyeqPtvY= github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9 h1:ed8yvWhTLk7+sNeK/eOZRTvESFTOHDRevoRoyeqPtvY=
github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9/go.mod h1:P4MqPf+u83OPulPJ+XTbSDbbWrdFYNY4LZ/B1PIduFE= github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9/go.mod h1:P4MqPf+u83OPulPJ+XTbSDbbWrdFYNY4LZ/B1PIduFE=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=

View file

@ -17,7 +17,7 @@ var build string
const ( const (
VersionMajor = 0 VersionMajor = 0
VersionMinor = 9 VersionMinor = 9
VersionPatch = 3 VersionPatch = 5
VersionTag = "" // example: "rc1" VersionTag = "" // example: "rc1"
) )

View file

@ -55,14 +55,15 @@ func NewDeviceListUpdateConsumer(
// Start consuming from key servers // Start consuming from key servers
func (t *DeviceListUpdateConsumer) Start() error { func (t *DeviceListUpdateConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, t.ctx, t.jetstream, t.topic, t.durable, 1,
nats.DeliverAll(), nats.ManualAck(), t.onMessage, nats.DeliverAll(), nats.ManualAck(),
) )
} }
// onMessage is called in response to a message received on the // onMessage is called in response to a message received on the
// key change events topic from the key server. // key change events topic from the key server.
func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var m gomatrixserverlib.DeviceListUpdateEvent var m gomatrixserverlib.DeviceListUpdateEvent
if err := json.Unmarshal(msg.Data, &m); err != nil { if err := json.Unmarshal(msg.Data, &m); err != nil {
logrus.WithError(err).Errorf("Failed to read from device list update input topic") logrus.WithError(err).Errorf("Failed to read from device list update input topic")

View file

@ -335,8 +335,9 @@ func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
retriesMu := &sync.Mutex{} retriesMu := &sync.Mutex{}
// restarter goroutine which will inject failed servers into ch when it is time // restarter goroutine which will inject failed servers into ch when it is time
go func() { go func() {
for {
var serversToRetry []gomatrixserverlib.ServerName var serversToRetry []gomatrixserverlib.ServerName
for {
serversToRetry = serversToRetry[:0] // reuse memory
time.Sleep(time.Second) time.Sleep(time.Second)
retriesMu.Lock() retriesMu.Lock()
now := time.Now() now := time.Now()
@ -355,11 +356,17 @@ func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
} }
}() }()
for serverName := range ch { for serverName := range ch {
retriesMu.Lock()
_, exists := retries[serverName]
retriesMu.Unlock()
if exists {
// Don't retry a server that we're already waiting for.
continue
}
waitTime, shouldRetry := u.processServer(serverName) waitTime, shouldRetry := u.processServer(serverName)
if shouldRetry { if shouldRetry {
retriesMu.Lock() retriesMu.Lock()
_, exists := retries[serverName] if _, exists = retries[serverName]; !exists {
if !exists {
retries[serverName] = time.Now().Add(waitTime) retries[serverName] = time.Now().Add(waitTime)
} }
retriesMu.Unlock() retriesMu.Unlock()

View file

@ -60,7 +60,7 @@ func NewInternalAPI(
updater := internal.NewDeviceListUpdater(db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable updater := internal.NewDeviceListUpdater(db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable
ap.Updater = updater ap.Updater = updater
go func() { go func() {
if err = updater.Start(); err != nil { if err := updater.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start device list updater") logrus.WithError(err).Panicf("failed to start device list updater")
} }
}() }()
@ -68,7 +68,7 @@ func NewInternalAPI(
dlConsumer := consumers.NewDeviceListUpdateConsumer( dlConsumer := consumers.NewDeviceListUpdateConsumer(
base.ProcessContext, cfg, js, updater, base.ProcessContext, cfg, js, updater,
) )
if err = dlConsumer.Start(); err != nil { if err := dlConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start device list consumer") logrus.WithError(err).Panic("failed to start device list consumer")
} }

View file

@ -5,9 +5,10 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
) )
type PerformErrorCode int type PerformErrorCode int
@ -162,6 +163,7 @@ func (r *PerformBackfillRequest) PrevEventIDs() []string {
type PerformBackfillResponse struct { type PerformBackfillResponse struct {
// Missing events, arbritrary order. // Missing events, arbritrary order.
Events []*gomatrixserverlib.HeaderedEvent `json:"events"` Events []*gomatrixserverlib.HeaderedEvent `json:"events"`
HistoryVisibility gomatrixserverlib.HistoryVisibility `json:"history_visibility"`
} }
type PerformPublishRequest struct { type PerformPublishRequest struct {

View file

@ -227,6 +227,7 @@ type QueryStateAndAuthChainResponse struct {
// Do all the previous events exist on this roomserver? // Do all the previous events exist on this roomserver?
// If some of previous events do not exist this will be false and StateEvents will be empty. // If some of previous events do not exist this will be false and StateEvents will be empty.
PrevEventsExist bool `json:"prev_events_exist"` PrevEventsExist bool `json:"prev_events_exist"`
StateKnown bool `json:"state_known"`
// The state and auth chain events that were requested. // The state and auth chain events that were requested.
// The lists will be in an arbitrary order. // The lists will be in an arbitrary order.
StateEvents []*gomatrixserverlib.HeaderedEvent `json:"state_events"` StateEvents []*gomatrixserverlib.HeaderedEvent `json:"state_events"`

View file

@ -19,6 +19,7 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
// SendEvents to the roomserver The events are written with KindNew. // SendEvents to the roomserver The events are written with KindNew.
@ -69,6 +70,13 @@ func SendEventWithState(
stateEventIDs[i] = stateEvents[i].EventID() stateEventIDs[i] = stateEvents[i].EventID()
} }
logrus.WithContext(ctx).WithFields(logrus.Fields{
"room_id": event.RoomID(),
"event_id": event.EventID(),
"outliers": len(ires),
"state_ids": len(stateEventIDs),
}).Infof("Submitting %q event to roomserver with state snapshot", event.Type())
ires = append(ires, InputRoomEvent{ ires = append(ires, InputRoomEvent{
Kind: kind, Kind: kind,
Event: event, Event: event,

View file

@ -39,7 +39,7 @@ func CheckForSoftFail(
var authStateEntries []types.StateEntry var authStateEntries []types.StateEntry
var err error var err error
if rewritesState { if rewritesState {
authStateEntries, err = db.StateEntriesForEventIDs(ctx, stateEventIDs) authStateEntries, err = db.StateEntriesForEventIDs(ctx, stateEventIDs, true)
if err != nil { if err != nil {
return true, fmt.Errorf("StateEntriesForEventIDs failed: %w", err) return true, fmt.Errorf("StateEntriesForEventIDs failed: %w", err)
} }
@ -97,7 +97,7 @@ func CheckAuthEvents(
authEventIDs []string, authEventIDs []string,
) ([]types.EventNID, error) { ) ([]types.EventNID, error) {
// Grab the numeric IDs for the supplied auth state events from the database. // Grab the numeric IDs for the supplied auth state events from the database.
authStateEntries, err := db.StateEntriesForEventIDs(ctx, authEventIDs) authStateEntries, err := db.StateEntriesForEventIDs(ctx, authEventIDs, true)
if err != nil { if err != nil {
return nil, fmt.Errorf("db.StateEntriesForEventIDs: %w", err) return nil, fmt.Errorf("db.StateEntriesForEventIDs: %w", err)
} }

View file

@ -253,10 +253,17 @@ func CheckServerAllowedToSeeEvent(
if err != nil { if err != nil {
return false, err return false, err
} }
default:
switch err.(type) {
case types.MissingStateError:
// If there's no state then we assume it's open visibility, as Synapse does:
// https://github.com/matrix-org/synapse/blob/aec87a0f9369a3015b2a53469f88d1de274e8b71/synapse/visibility.py#L654-L655
return true, nil
default: default:
// Something else went wrong // Something else went wrong
return false, err return false, err
} }
}
return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
} }

View file

@ -36,6 +36,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/internal/query"
"github.com/matrix-org/dendrite/roomserver/producers" "github.com/matrix-org/dendrite/roomserver/producers"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
@ -247,6 +248,15 @@ func (w *worker) _next() {
// it was a synchronous request. // it was a synchronous request.
var errString string var errString string
if err = w.r.processRoomEvent(w.r.ProcessContext.Context(), &inputRoomEvent); err != nil { if err = w.r.processRoomEvent(w.r.ProcessContext.Context(), &inputRoomEvent); err != nil {
switch err.(type) {
case types.RejectedError:
// Don't send events that were rejected to Sentry
logrus.WithError(err).WithFields(logrus.Fields{
"room_id": w.roomID,
"event_id": inputRoomEvent.Event.EventID(),
"type": inputRoomEvent.Event.Type(),
}).Warn("Roomserver rejected event")
default:
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
sentry.CaptureException(err) sentry.CaptureException(err)
} }
@ -254,7 +264,8 @@ func (w *worker) _next() {
"room_id": w.roomID, "room_id": w.roomID,
"event_id": inputRoomEvent.Event.EventID(), "event_id": inputRoomEvent.Event.EventID(),
"type": inputRoomEvent.Event.Type(), "type": inputRoomEvent.Event.Type(),
}).Warn("Roomserver failed to process async event") }).Warn("Roomserver failed to process event")
}
_ = msg.Term() _ = msg.Term()
errString = err.Error() errString = err.Error()
} else { } else {

View file

@ -301,7 +301,7 @@ func (r *Inputer) processRoomEvent(
// bother doing this if the event was already rejected as it just ends up // bother doing this if the event was already rejected as it just ends up
// burning CPU time. // burning CPU time.
historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared. historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared.
if rejectionErr == nil && !isRejected && !softfail { if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected {
var err error var err error
historyVisibility, rejectionErr, err = r.processStateBefore(ctx, input, missingPrev) historyVisibility, rejectionErr, err = r.processStateBefore(ctx, input, missingPrev)
if err != nil { if err != nil {
@ -313,7 +313,7 @@ func (r *Inputer) processRoomEvent(
} }
// Store the event. // Store the event.
_, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected || softfail) _, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected)
if err != nil { if err != nil {
return fmt.Errorf("updater.StoreEvent: %w", err) return fmt.Errorf("updater.StoreEvent: %w", err)
} }
@ -353,12 +353,18 @@ func (r *Inputer) processRoomEvent(
} }
} }
// We stop here if the event is rejected: We've stored it but won't update forward extremities or notify anyone about it. // We stop here if the event is rejected: We've stored it but won't update
if isRejected || softfail { // forward extremities or notify downstream components about it.
logger.WithError(rejectionErr).WithFields(logrus.Fields{ switch {
"soft_fail": softfail, case isRejected:
"missing_prev": missingPrev, logger.WithError(rejectionErr).Warn("Stored rejected event")
}).Warn("Stored rejected event") if rejectionErr != nil {
return types.RejectedError(rejectionErr.Error())
}
return nil
case softfail:
logger.WithError(rejectionErr).Warn("Stored soft-failed event")
if rejectionErr != nil { if rejectionErr != nil {
return types.RejectedError(rejectionErr.Error()) return types.RejectedError(rejectionErr.Error())
} }
@ -661,7 +667,7 @@ func (r *Inputer) calculateAndSetState(
// We've been told what the state at the event is so we don't need to calculate it. // We've been told what the state at the event is so we don't need to calculate it.
// Check that those state events are in the database and store the state. // Check that those state events are in the database and store the state.
var entries []types.StateEntry var entries []types.StateEntry
if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs, true); err != nil {
return fmt.Errorf("updater.StateEntriesForEventIDs: %w", err) return fmt.Errorf("updater.StateEntriesForEventIDs: %w", err)
} }
entries = types.DeduplicateStateEntries(entries) entries = types.DeduplicateStateEntries(entries)

View file

@ -18,7 +18,6 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -140,11 +139,11 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
continue continue
} }
var entries []types.StateEntry var entries []types.StateEntry
if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs); err != nil { if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true); err != nil {
// attempt to fetch the missing events // attempt to fetch the missing events
r.fetchAndStoreMissingEvents(ctx, info.RoomVersion, requester, stateIDs) r.fetchAndStoreMissingEvents(ctx, info.RoomVersion, requester, stateIDs)
// try again // try again
entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs) entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true)
if err != nil { if err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to get state entries for event") logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to get state entries for event")
return err return err
@ -164,6 +163,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
// TODO: update backwards extremities, as that should be moved from syncapi to roomserver at some point. // TODO: update backwards extremities, as that should be moved from syncapi to roomserver at some point.
res.Events = events res.Events = events
res.HistoryVisibility = requester.historyVisiblity
return nil return nil
} }
@ -248,6 +248,7 @@ type backfillRequester struct {
servers []gomatrixserverlib.ServerName servers []gomatrixserverlib.ServerName
eventIDToBeforeStateIDs map[string][]string eventIDToBeforeStateIDs map[string][]string
eventIDMap map[string]*gomatrixserverlib.Event eventIDMap map[string]*gomatrixserverlib.Event
historyVisiblity gomatrixserverlib.HistoryVisibility
} }
func newBackfillRequester( func newBackfillRequester(
@ -266,6 +267,7 @@ func newBackfillRequester(
eventIDMap: make(map[string]*gomatrixserverlib.Event), eventIDMap: make(map[string]*gomatrixserverlib.Event),
bwExtrems: bwExtrems, bwExtrems: bwExtrems,
preferServer: preferServer, preferServer: preferServer,
historyVisiblity: gomatrixserverlib.HistoryVisibilityShared,
} }
} }
@ -317,7 +319,6 @@ FederationHit:
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = res b.eventIDToBeforeStateIDs[targetEvent.EventID()] = res
return res, nil return res, nil
} }
sentry.CaptureException(lastErr) // temporary to see if we might need to raise the server limit
return nil, lastErr return nil, lastErr
} }
@ -395,7 +396,6 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr
} }
return result, nil return result, nil
} }
sentry.CaptureException(lastErr) // temporary to see if we might need to raise the server limit
return nil, lastErr return nil, lastErr
} }
@ -447,7 +447,8 @@ FindSuccessor:
} }
// possibly return all joined servers depending on history visiblity // possibly return all joined servers depending on history visiblity
memberEventsFromVis, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.thisServer) memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.thisServer)
b.historyVisiblity = visibility
if err != nil { if err != nil {
logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules") logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules")
return nil return nil
@ -528,7 +529,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
// pull all events and then filter by that table. // pull all events and then filter by that table.
func joinEventsFromHistoryVisibility( func joinEventsFromHistoryVisibility(
ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry, ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry,
thisServer gomatrixserverlib.ServerName) ([]types.Event, error) { thisServer gomatrixserverlib.ServerName) ([]types.Event, gomatrixserverlib.HistoryVisibility, error) {
var eventNIDs []types.EventNID var eventNIDs []types.EventNID
for _, entry := range stateEntries { for _, entry := range stateEntries {
@ -542,7 +543,9 @@ func joinEventsFromHistoryVisibility(
// Get all of the events in this state // Get all of the events in this state
stateEvents, err := db.Events(ctx, eventNIDs) stateEvents, err := db.Events(ctx, eventNIDs)
if err != nil { if err != nil {
return nil, err // even though the default should be shared, restricting the visibility to joined
// feels more secure here.
return nil, gomatrixserverlib.HistoryVisibilityJoined, err
} }
events := make([]*gomatrixserverlib.Event, len(stateEvents)) events := make([]*gomatrixserverlib.Event, len(stateEvents))
for i := range stateEvents { for i := range stateEvents {
@ -551,20 +554,22 @@ func joinEventsFromHistoryVisibility(
// Can we see events in the room? // Can we see events in the room?
canSeeEvents := auth.IsServerAllowed(thisServer, true, events) canSeeEvents := auth.IsServerAllowed(thisServer, true, events)
visibility := gomatrixserverlib.HistoryVisibility(auth.HistoryVisibilityForRoom(events))
if !canSeeEvents { if !canSeeEvents {
logrus.Infof("ServersAtEvent history not visible to us: %s", auth.HistoryVisibilityForRoom(events)) logrus.Infof("ServersAtEvent history not visible to us: %s", visibility)
return nil, nil return nil, visibility, nil
} }
// get joined members // get joined members
info, err := db.RoomInfo(ctx, roomID) info, err := db.RoomInfo(ctx, roomID)
if err != nil { if err != nil {
return nil, err return nil, visibility, nil
} }
joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false) joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false)
if err != nil { if err != nil {
return nil, err return nil, visibility, err
} }
return db.Events(ctx, joinEventNIDs) evs, err := db.Events(ctx, joinEventNIDs)
return evs, visibility, err
} }
func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) { func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) {

View file

@ -503,10 +503,11 @@ func (r *Queryer) QueryStateAndAuthChain(
} }
var stateEvents []*gomatrixserverlib.Event var stateEvents []*gomatrixserverlib.Event
stateEvents, rejected, err := r.loadStateAtEventIDs(ctx, info, request.PrevEventIDs) stateEvents, rejected, stateMissing, err := r.loadStateAtEventIDs(ctx, info, request.PrevEventIDs)
if err != nil { if err != nil {
return err return err
} }
response.StateKnown = !stateMissing
response.IsRejected = rejected response.IsRejected = rejected
response.PrevEventsExist = true response.PrevEventsExist = true
@ -542,15 +543,18 @@ func (r *Queryer) QueryStateAndAuthChain(
return err return err
} }
func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]*gomatrixserverlib.Event, bool, error) { // first bool: is rejected, second bool: state missing
func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]*gomatrixserverlib.Event, bool, bool, error) {
roomState := state.NewStateResolution(r.DB, roomInfo) roomState := state.NewStateResolution(r.DB, roomInfo)
prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs)
if err != nil { if err != nil {
switch err.(type) { switch err.(type) {
case types.MissingEventError: case types.MissingEventError:
return nil, false, nil return nil, false, true, nil
case types.MissingStateError:
return nil, false, true, nil
default: default:
return nil, false, err return nil, false, false, err
} }
} }
// Currently only used on /state and /state_ids // Currently only used on /state and /state_ids
@ -567,12 +571,11 @@ func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomI
ctx, prevStates, ctx, prevStates,
) )
if err != nil { if err != nil {
return nil, rejected, err return nil, rejected, false, err
} }
events, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries) events, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries)
return events, rejected, false, err
return events, rejected, err
} }
type eventsFromIDs func(context.Context, []string) ([]types.Event, error) type eventsFromIDs func(context.Context, []string) ([]types.Event, error)

View file

@ -79,7 +79,7 @@ type Database interface {
// Look up the state entries for a list of string event IDs // Look up the state entries for a list of string event IDs
// Returns an error if the there is an error talking to the database // Returns an error if the there is an error talking to the database
// Returns a types.MissingEventError if the event IDs aren't in the database. // Returns a types.MissingEventError if the event IDs aren't in the database.
StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) StateEntriesForEventIDs(ctx context.Context, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error)
// Look up the string event state keys for a list of numeric event state keys // Look up the string event state keys for a list of numeric event state keys
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)

View file

@ -74,7 +74,7 @@ const insertEventSQL = "" +
"INSERT INTO roomserver_events AS e (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)" + "INSERT INTO roomserver_events AS e (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" +
" ON CONFLICT ON CONSTRAINT roomserver_event_id_unique DO UPDATE" + " ON CONFLICT ON CONSTRAINT roomserver_event_id_unique DO UPDATE" +
" SET is_rejected = $8 WHERE e.event_id = $4 AND e.is_rejected = FALSE" + " SET is_rejected = $8 WHERE e.event_id = $4 AND e.is_rejected = TRUE" +
" RETURNING event_nid, state_snapshot_nid" " RETURNING event_nid, state_snapshot_nid"
const selectEventSQL = "" + const selectEventSQL = "" +
@ -88,6 +88,14 @@ const bulkSelectStateEventByIDSQL = "" +
" WHERE event_id = ANY($1)" + " WHERE event_id = ANY($1)" +
" ORDER BY event_type_nid, event_state_key_nid ASC" " ORDER BY event_type_nid, event_state_key_nid ASC"
// Bulk lookup of events by string ID that aren't excluded.
// Sort by the numeric IDs for event type and state key.
// This means we can use binary search to lookup entries by type and state key.
const bulkSelectStateEventByIDExcludingRejectedSQL = "" +
"SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" +
" WHERE event_id = ANY($1) AND is_rejected = FALSE" +
" ORDER BY event_type_nid, event_state_key_nid ASC"
// Bulk look up of events by event NID, optionally filtering by the event type // Bulk look up of events by event NID, optionally filtering by the event type
// or event state key NIDs if provided. (The CARDINALITY check will return true // or event state key NIDs if provided. (The CARDINALITY check will return true
// if the provided arrays are empty, ergo no filtering). // if the provided arrays are empty, ergo no filtering).
@ -143,6 +151,7 @@ type eventStatements struct {
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
selectEventStmt *sql.Stmt selectEventStmt *sql.Stmt
bulkSelectStateEventByIDStmt *sql.Stmt bulkSelectStateEventByIDStmt *sql.Stmt
bulkSelectStateEventByIDExcludingRejectedStmt *sql.Stmt
bulkSelectStateEventByNIDStmt *sql.Stmt bulkSelectStateEventByNIDStmt *sql.Stmt
bulkSelectStateAtEventByIDStmt *sql.Stmt bulkSelectStateAtEventByIDStmt *sql.Stmt
updateEventStateStmt *sql.Stmt updateEventStateStmt *sql.Stmt
@ -171,6 +180,7 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) {
{&s.insertEventStmt, insertEventSQL}, {&s.insertEventStmt, insertEventSQL},
{&s.selectEventStmt, selectEventSQL}, {&s.selectEventStmt, selectEventSQL},
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
{&s.bulkSelectStateEventByIDExcludingRejectedStmt, bulkSelectStateEventByIDExcludingRejectedSQL},
{&s.bulkSelectStateEventByNIDStmt, bulkSelectStateEventByNIDSQL}, {&s.bulkSelectStateEventByNIDStmt, bulkSelectStateEventByNIDSQL},
{&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL}, {&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
{&s.updateEventStateStmt, updateEventStateSQL}, {&s.updateEventStateStmt, updateEventStateSQL},
@ -221,11 +231,18 @@ func (s *eventStatements) SelectEvent(
} }
// bulkSelectStateEventByID lookups a list of state events by event ID. // bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError // If not excluding rejected events, and any of the requested events are missing from
// the database it returns a types.MissingEventError. If excluding rejected events,
// the events will be silently omitted without error.
func (s *eventStatements) BulkSelectStateEventByID( func (s *eventStatements) BulkSelectStateEventByID(
ctx context.Context, txn *sql.Tx, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string, excludeRejected bool,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByIDStmt) var stmt *sql.Stmt
if excludeRejected {
stmt = sqlutil.TxStmt(txn, s.bulkSelectStateEventByIDExcludingRejectedStmt)
} else {
stmt = sqlutil.TxStmt(txn, s.bulkSelectStateEventByIDStmt)
}
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil { if err != nil {
return nil, err return nil, err
@ -235,10 +252,10 @@ func (s *eventStatements) BulkSelectStateEventByID(
// because of the unique constraint on event IDs. // because of the unique constraint on event IDs.
// So we can allocate an array of the correct size now. // So we can allocate an array of the correct size now.
// We might get fewer results than IDs so we adjust the length of the slice before returning it. // We might get fewer results than IDs so we adjust the length of the slice before returning it.
results := make([]types.StateEntry, len(eventIDs)) results := make([]types.StateEntry, 0, len(eventIDs))
i := 0 i := 0
for ; rows.Next(); i++ { for ; rows.Next(); i++ {
result := &results[i] var result types.StateEntry
if err = rows.Scan( if err = rows.Scan(
&result.EventTypeNID, &result.EventTypeNID,
&result.EventStateKeyNID, &result.EventStateKeyNID,
@ -246,11 +263,12 @@ func (s *eventStatements) BulkSelectStateEventByID(
); err != nil { ); err != nil {
return nil, err return nil, err
} }
results = append(results, result)
} }
if err = rows.Err(); err != nil { if err = rows.Err(); err != nil {
return nil, err return nil, err
} }
if i != len(eventIDs) { if !excludeRejected && i != len(eventIDs) {
// If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have. // If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
// We don't know which ones were missing because we don't return the string IDs in the query. // We don't know which ones were missing because we don't return the string IDs in the query.
// However it should be possible debug this by replaying queries or entries from the input kafka logs. // However it should be possible debug this by replaying queries or entries from the input kafka logs.
@ -328,7 +346,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
// Genuine create events are the only case where it's OK to have no previous state. // Genuine create events are the only case where it's OK to have no previous state.
isCreate := result.EventTypeNID == types.MRoomCreateNID && result.EventStateKeyNID == 1 isCreate := result.EventTypeNID == types.MRoomCreateNID && result.EventStateKeyNID == 1
if result.BeforeStateSnapshotNID == 0 && !isCreate { if result.BeforeStateSnapshotNID == 0 && !isCreate {
return nil, types.MissingEventError( return nil, types.MissingStateError(
fmt.Sprintf("storage: missing state for event NID %d", result.EventNID), fmt.Sprintf("storage: missing state for event NID %d", result.EventNID),
) )
} }

View file

@ -113,9 +113,9 @@ func (d *Database) eventStateKeyNIDs(
} }
func (d *Database) StateEntriesForEventIDs( func (d *Database) StateEntriesForEventIDs(
ctx context.Context, eventIDs []string, ctx context.Context, eventIDs []string, excludeRejected bool,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs) return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs, excludeRejected)
} }
func (d *Database) StateEntriesForTuples( func (d *Database) StateEntriesForTuples(

View file

@ -50,7 +50,7 @@ const insertEventSQL = `
INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected) INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT DO UPDATE ON CONFLICT DO UPDATE
SET is_rejected = $8 WHERE is_rejected = 0 SET is_rejected = $8 WHERE is_rejected = 1
RETURNING event_nid, state_snapshot_nid; RETURNING event_nid, state_snapshot_nid;
` `
@ -65,6 +65,14 @@ const bulkSelectStateEventByIDSQL = "" +
" WHERE event_id IN ($1)" + " WHERE event_id IN ($1)" +
" ORDER BY event_type_nid, event_state_key_nid ASC" " ORDER BY event_type_nid, event_state_key_nid ASC"
// Bulk lookup of events by string ID that aren't rejected.
// Sort by the numeric IDs for event type and state key.
// This means we can use binary search to lookup entries by type and state key.
const bulkSelectStateEventByIDExcludingRejectedSQL = "" +
"SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" +
" WHERE event_id IN ($1) AND is_rejected = 0" +
" ORDER BY event_type_nid, event_state_key_nid ASC"
const bulkSelectStateEventByNIDSQL = "" + const bulkSelectStateEventByNIDSQL = "" +
"SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" +
" WHERE event_nid IN ($1)" " WHERE event_nid IN ($1)"
@ -117,6 +125,7 @@ type eventStatements struct {
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
selectEventStmt *sql.Stmt selectEventStmt *sql.Stmt
bulkSelectStateEventByIDStmt *sql.Stmt bulkSelectStateEventByIDStmt *sql.Stmt
bulkSelectStateEventByIDExcludingRejectedStmt *sql.Stmt
bulkSelectStateAtEventByIDStmt *sql.Stmt bulkSelectStateAtEventByIDStmt *sql.Stmt
updateEventStateStmt *sql.Stmt updateEventStateStmt *sql.Stmt
selectEventSentToOutputStmt *sql.Stmt selectEventSentToOutputStmt *sql.Stmt
@ -145,6 +154,7 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) {
{&s.insertEventStmt, insertEventSQL}, {&s.insertEventStmt, insertEventSQL},
{&s.selectEventStmt, selectEventSQL}, {&s.selectEventStmt, selectEventSQL},
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
{&s.bulkSelectStateEventByIDExcludingRejectedStmt, bulkSelectStateEventByIDExcludingRejectedSQL},
{&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL}, {&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
{&s.updateEventStateStmt, updateEventStateSQL}, {&s.updateEventStateStmt, updateEventStateSQL},
{&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL}, {&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL},
@ -194,16 +204,24 @@ func (s *eventStatements) SelectEvent(
} }
// bulkSelectStateEventByID lookups a list of state events by event ID. // bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError // If not excluding rejected events, and any of the requested events are missing from
// the database it returns a types.MissingEventError. If excluding rejected events,
// the events will be silently omitted without error.
func (s *eventStatements) BulkSelectStateEventByID( func (s *eventStatements) BulkSelectStateEventByID(
ctx context.Context, txn *sql.Tx, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string, excludeRejected bool,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
/////////////// ///////////////
var sql string
if excludeRejected {
sql = bulkSelectStateEventByIDExcludingRejectedSQL
} else {
sql = bulkSelectStateEventByIDSQL
}
iEventIDs := make([]interface{}, len(eventIDs)) iEventIDs := make([]interface{}, len(eventIDs))
for k, v := range eventIDs { for k, v := range eventIDs {
iEventIDs[k] = v iEventIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) selectOrig := strings.Replace(sql, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
selectPrep, err := s.db.Prepare(selectOrig) selectPrep, err := s.db.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err
@ -221,10 +239,10 @@ func (s *eventStatements) BulkSelectStateEventByID(
// because of the unique constraint on event IDs. // because of the unique constraint on event IDs.
// So we can allocate an array of the correct size now. // So we can allocate an array of the correct size now.
// We might get fewer results than IDs so we adjust the length of the slice before returning it. // We might get fewer results than IDs so we adjust the length of the slice before returning it.
results := make([]types.StateEntry, len(eventIDs)) results := make([]types.StateEntry, 0, len(eventIDs))
i := 0 i := 0
for ; rows.Next(); i++ { for ; rows.Next(); i++ {
result := &results[i] var result types.StateEntry
if err = rows.Scan( if err = rows.Scan(
&result.EventTypeNID, &result.EventTypeNID,
&result.EventStateKeyNID, &result.EventStateKeyNID,
@ -232,8 +250,9 @@ func (s *eventStatements) BulkSelectStateEventByID(
); err != nil { ); err != nil {
return nil, err return nil, err
} }
results = append(results, result)
} }
if i != len(eventIDs) { if !excludeRejected && i != len(eventIDs) {
// If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have. // If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
// We don't know which ones were missing because we don't return the string IDs in the query. // We don't know which ones were missing because we don't return the string IDs in the query.
// However it should be possible debug this by replaying queries or entries from the input kafka logs. // However it should be possible debug this by replaying queries or entries from the input kafka logs.
@ -343,7 +362,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
// Genuine create events are the only case where it's OK to have no previous state. // Genuine create events are the only case where it's OK to have no previous state.
isCreate := result.EventTypeNID == types.MRoomCreateNID && result.EventStateKeyNID == 1 isCreate := result.EventTypeNID == types.MRoomCreateNID && result.EventStateKeyNID == 1
if result.BeforeStateSnapshotNID == 0 && !isCreate { if result.BeforeStateSnapshotNID == 0 && !isCreate {
return nil, types.MissingEventError( return nil, types.MissingStateError(
fmt.Sprintf("storage: missing state for event NID %d", result.EventNID), fmt.Sprintf("storage: missing state for event NID %d", result.EventNID),
) )
} }

View file

@ -102,7 +102,7 @@ func Test_EventsTable(t *testing.T) {
}) })
} }
stateEvents, err := tab.BulkSelectStateEventByID(ctx, nil, eventIDs) stateEvents, err := tab.BulkSelectStateEventByID(ctx, nil, eventIDs, false)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, len(stateEvents), len(eventIDs)) assert.Equal(t, len(stateEvents), len(eventIDs))
nids := make([]types.EventNID, 0, len(stateEvents)) nids := make([]types.EventNID, 0, len(stateEvents))

View file

@ -46,7 +46,7 @@ type Events interface {
SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error) SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error)
// bulkSelectStateEventByID lookups a list of state events by event ID. // bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError // If any of the requested events are missing from the database it returns a types.MissingEventError
BulkSelectStateEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateEntry, error) BulkSelectStateEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error)
BulkSelectStateEventByNID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error) BulkSelectStateEventByNID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error)
// BulkSelectStateAtEventByID lookups the state at a list of events by event ID. // BulkSelectStateAtEventByID lookups the state at a list of events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError. // If any of the requested events are missing from the database it returns a types.MissingEventError.

View file

@ -25,21 +25,23 @@ import (
_ "net/http/pprof" _ "net/http/pprof"
"os" "os"
"os/signal" "os/signal"
"sync"
"syscall" "syscall"
"time" "time"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
sentryhttp "github.com/getsentry/sentry-go/http" sentryhttp "github.com/getsentry/sentry-go/http"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"go.uber.org/atomic" "go.uber.org/atomic"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/h2c" "golang.org/x/net/http2/h2c"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
@ -47,6 +49,8 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/kardianos/minwinsvc" "github.com/kardianos/minwinsvc"
"github.com/sirupsen/logrus"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
asinthttp "github.com/matrix-org/dendrite/appservice/inthttp" asinthttp "github.com/matrix-org/dendrite/appservice/inthttp"
federationAPI "github.com/matrix-org/dendrite/federationapi/api" federationAPI "github.com/matrix-org/dendrite/federationapi/api"
@ -58,7 +62,6 @@ import (
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
userapiinthttp "github.com/matrix-org/dendrite/userapi/inthttp" userapiinthttp "github.com/matrix-org/dendrite/userapi/inthttp"
"github.com/sirupsen/logrus"
) )
// BaseDendrite is a base for creating new instances of dendrite. It parses // BaseDendrite is a base for creating new instances of dendrite. It parses
@ -87,6 +90,7 @@ type BaseDendrite struct {
Database *sql.DB Database *sql.DB
DatabaseWriter sqlutil.Writer DatabaseWriter sqlutil.Writer
EnableMetrics bool EnableMetrics bool
startupLock sync.Mutex
} }
const NoListener = "" const NoListener = ""
@ -394,6 +398,9 @@ func (b *BaseDendrite) SetupAndServeHTTP(
internalHTTPAddr, externalHTTPAddr config.HTTPAddress, internalHTTPAddr, externalHTTPAddr config.HTTPAddress,
certFile, keyFile *string, certFile, keyFile *string,
) { ) {
// Manually unlocked right before actually serving requests,
// as we don't return from this method (defer doesn't work).
b.startupLock.Lock()
internalAddr, _ := internalHTTPAddr.Address() internalAddr, _ := internalHTTPAddr.Address()
externalAddr, _ := externalHTTPAddr.Address() externalAddr, _ := externalHTTPAddr.Address()
@ -472,6 +479,7 @@ func (b *BaseDendrite) SetupAndServeHTTP(
externalRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(b.PublicMediaAPIMux) externalRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(b.PublicMediaAPIMux)
externalRouter.PathPrefix(httputil.PublicWellKnownPrefix).Handler(b.PublicWellKnownAPIMux) externalRouter.PathPrefix(httputil.PublicWellKnownPrefix).Handler(b.PublicWellKnownAPIMux)
b.startupLock.Unlock()
if internalAddr != NoListener && internalAddr != externalAddr { if internalAddr != NoListener && internalAddr != externalAddr {
go func() { go func() {
var internalShutdown atomic.Bool // RegisterOnShutdown can be called more than once var internalShutdown atomic.Bool // RegisterOnShutdown can be called more than once

View file

@ -224,12 +224,7 @@ func loadConfig(
} }
privateKeyPath := absPath(basePath, c.Global.PrivateKeyPath) privateKeyPath := absPath(basePath, c.Global.PrivateKeyPath)
privateKeyData, err := readFile(privateKeyPath) if c.Global.KeyID, c.Global.PrivateKey, err = LoadMatrixKey(privateKeyPath, readFile); err != nil {
if err != nil {
return nil, err
}
if c.Global.KeyID, c.Global.PrivateKey, err = readKeyPEM(privateKeyPath, privateKeyData, true); err != nil {
return nil, err return nil, err
} }
@ -265,6 +260,14 @@ func loadConfig(
return &c, nil return &c, nil
} }
func LoadMatrixKey(privateKeyPath string, readFile func(string) ([]byte, error)) (gomatrixserverlib.KeyID, ed25519.PrivateKey, error) {
privateKeyData, err := readFile(privateKeyPath)
if err != nil {
return "", nil, err
}
return readKeyPEM(privateKeyPath, privateKeyData, true)
}
// Derive generates data that is derived from various values provided in // Derive generates data that is derived from various values provided in
// the config file. // the config file.
func (config *Dendrite) Derive() error { func (config *Dendrite) Derive() error {

View file

@ -31,8 +31,6 @@ type AppServiceAPI struct {
InternalAPI InternalAPIOptions `yaml:"internal_api"` InternalAPI InternalAPIOptions `yaml:"internal_api"`
Database DatabaseOptions `yaml:"database"`
// DisableTLSValidation disables the validation of X.509 TLS certs // DisableTLSValidation disables the validation of X.509 TLS certs
// on appservice endpoints. This is not recommended in production! // on appservice endpoints. This is not recommended in production!
DisableTLSValidation bool `yaml:"disable_tls_validation"` DisableTLSValidation bool `yaml:"disable_tls_validation"`
@ -43,16 +41,9 @@ type AppServiceAPI struct {
func (c *AppServiceAPI) Defaults(generate bool) { func (c *AppServiceAPI) Defaults(generate bool) {
c.InternalAPI.Listen = "http://localhost:7777" c.InternalAPI.Listen = "http://localhost:7777"
c.InternalAPI.Connect = "http://localhost:7777" c.InternalAPI.Connect = "http://localhost:7777"
c.Database.Defaults(5)
if generate {
c.Database.ConnectionString = "file:appservice.db"
}
} }
func (c *AppServiceAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { func (c *AppServiceAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
if c.Matrix.DatabaseOptions.ConnectionString == "" {
checkNotEmpty(configErrs, "app_service_api.database.connection_string", string(c.Database.ConnectionString))
}
if isMonolith { // polylith required configs below if isMonolith { // polylith required configs below
return return
} }

View file

@ -9,9 +9,16 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// JetStreamConsumer starts a durable consumer on the given subject with the
// given durable name. The function will be called when one or more messages
// is available, up to the maximum batch size specified. If the batch is set to
// 1 then messages will be delivered one at a time. If the function is called,
// the messages array is guaranteed to be at least 1 in size. Any provided NATS
// options will be passed through to the pull subscriber creation. The consumer
// will continue to run until the context expires, at which point it will stop.
func JetStreamConsumer( func JetStreamConsumer(
ctx context.Context, js nats.JetStreamContext, subj, durable string, ctx context.Context, js nats.JetStreamContext, subj, durable string, batch int,
f func(ctx context.Context, msg *nats.Msg) bool, f func(ctx context.Context, msgs []*nats.Msg) bool,
opts ...nats.SubOpt, opts ...nats.SubOpt,
) error { ) error {
defer func() { defer func() {
@ -50,7 +57,7 @@ func JetStreamConsumer(
// enforce its own deadline (roughly 5 seconds by default). Therefore // enforce its own deadline (roughly 5 seconds by default). Therefore
// it is our responsibility to check whether our context expired or // it is our responsibility to check whether our context expired or
// not when a context error is returned. Footguns. Footguns everywhere. // not when a context error is returned. Footguns. Footguns everywhere.
msgs, err := sub.Fetch(1, nats.Context(ctx)) msgs, err := sub.Fetch(batch, nats.Context(ctx))
if err != nil { if err != nil {
if err == context.Canceled || err == context.DeadlineExceeded { if err == context.Canceled || err == context.DeadlineExceeded {
// Work out whether it was the JetStream context that expired // Work out whether it was the JetStream context that expired
@ -74,24 +81,29 @@ func JetStreamConsumer(
if len(msgs) < 1 { if len(msgs) < 1 {
continue continue
} }
msg := msgs[0] for _, msg := range msgs {
if err = msg.InProgress(nats.Context(ctx)); err != nil { if err = msg.InProgress(nats.Context(ctx)); err != nil {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err)) logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err))
sentry.CaptureException(err) sentry.CaptureException(err)
continue continue
} }
if f(ctx, msg) { }
if f(ctx, msgs) {
for _, msg := range msgs {
if err = msg.AckSync(nats.Context(ctx)); err != nil { if err = msg.AckSync(nats.Context(ctx)); err != nil {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.AckSync: %w", err)) logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.AckSync: %w", err))
sentry.CaptureException(err) sentry.CaptureException(err)
} }
}
} else { } else {
for _, msg := range msgs {
if err = msg.Nak(nats.Context(ctx)); err != nil { if err = msg.Nak(nats.Context(ctx)); err != nil {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err)) logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err))
sentry.CaptureException(err) sentry.CaptureException(err)
} }
} }
} }
}
}() }()
return nil return nil
} }

View file

@ -183,6 +183,7 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc
OutputReceiptEvent: {"SyncAPIEDUServerReceiptConsumer", "FederationAPIEDUServerConsumer"}, OutputReceiptEvent: {"SyncAPIEDUServerReceiptConsumer", "FederationAPIEDUServerConsumer"},
OutputSendToDeviceEvent: {"SyncAPIEDUServerSendToDeviceConsumer", "FederationAPIEDUServerConsumer"}, OutputSendToDeviceEvent: {"SyncAPIEDUServerSendToDeviceConsumer", "FederationAPIEDUServerConsumer"},
OutputTypingEvent: {"SyncAPIEDUServerTypingConsumer", "FederationAPIEDUServerConsumer"}, OutputTypingEvent: {"SyncAPIEDUServerTypingConsumer", "FederationAPIEDUServerConsumer"},
OutputRoomEvent: {"AppserviceRoomserverConsumer"},
} { } {
streamName := cfg.Matrix.JetStream.Prefixed(stream) streamName := cfg.Matrix.JetStream.Prefixed(stream)
for _, consumer := range consumers { for _, consumer := range consumers {

View file

@ -75,15 +75,16 @@ func NewOutputClientDataConsumer(
// Start consuming from room servers // Start consuming from room servers
func (s *OutputClientDataConsumer) Start() error { func (s *OutputClientDataConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, s.ctx, s.jetstream, s.topic, s.durable, 1,
nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
) )
} }
// onMessage is called when the sync server receives a new event from the client API server output log. // onMessage is called when the sync server receives a new event from the client API server output log.
// It is not safe for this function to be called from multiple goroutines, or else the // It is not safe for this function to be called from multiple goroutines, or else the
// sync stream position may race and be incorrectly calculated. // sync stream position may race and be incorrectly calculated.
func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
// Parse out the event JSON // Parse out the event JSON
userID := msg.Header.Get(jetstream.UserID) userID := msg.Header.Get(jetstream.UserID)
var output eventutil.AccountData var output eventutil.AccountData

View file

@ -75,12 +75,13 @@ func NewOutputKeyChangeEventConsumer(
// Start consuming from the key server // Start consuming from the key server
func (s *OutputKeyChangeEventConsumer) Start() error { func (s *OutputKeyChangeEventConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, s.ctx, s.jetstream, s.topic, s.durable, 1,
nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
) )
} }
func (s *OutputKeyChangeEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *OutputKeyChangeEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var m api.DeviceMessage var m api.DeviceMessage
if err := json.Unmarshal(msg.Data, &m); err != nil { if err := json.Unmarshal(msg.Data, &m); err != nil {
logrus.WithError(err).Errorf("failed to read device message from key change topic") logrus.WithError(err).Errorf("failed to read device message from key change topic")

View file

@ -128,12 +128,13 @@ func (s *PresenceConsumer) Start() error {
return nil return nil
} }
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.presenceTopic, s.durable, s.onMessage, s.ctx, s.jetstream, s.presenceTopic, s.durable, 1, s.onMessage,
nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(), nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(),
) )
} }
func (s *PresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *PresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
userID := msg.Header.Get(jetstream.UserID) userID := msg.Header.Get(jetstream.UserID)
presence := msg.Header.Get("presence") presence := msg.Header.Get("presence")
timestamp := msg.Header.Get("last_active_ts") timestamp := msg.Header.Get("last_active_ts")

View file

@ -74,12 +74,13 @@ func NewOutputReceiptEventConsumer(
// Start consuming receipts events. // Start consuming receipts events.
func (s *OutputReceiptEventConsumer) Start() error { func (s *OutputReceiptEventConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, s.ctx, s.jetstream, s.topic, s.durable, 1,
nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
) )
} }
func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
output := types.OutputReceiptEvent{ output := types.OutputReceiptEvent{
UserID: msg.Header.Get(jetstream.UserID), UserID: msg.Header.Get(jetstream.UserID),
RoomID: msg.Header.Get(jetstream.RoomID), RoomID: msg.Header.Get(jetstream.RoomID),

View file

@ -79,15 +79,16 @@ func NewOutputRoomEventConsumer(
// Start consuming from room servers // Start consuming from room servers
func (s *OutputRoomEventConsumer) Start() error { func (s *OutputRoomEventConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, s.ctx, s.jetstream, s.topic, s.durable, 1,
nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
) )
} }
// onMessage is called when the sync server receives a new event from the room server output log. // onMessage is called when the sync server receives a new event from the room server output log.
// It is not safe for this function to be called from multiple goroutines, or else the // It is not safe for this function to be called from multiple goroutines, or else the
// sync stream position may race and be incorrectly calculated. // sync stream position may race and be incorrectly calculated.
func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
// Parse out the event JSON // Parse out the event JSON
var err error var err error
var output api.OutputEvent var output api.OutputEvent

View file

@ -68,12 +68,13 @@ func NewOutputSendToDeviceEventConsumer(
// Start consuming send-to-device events. // Start consuming send-to-device events.
func (s *OutputSendToDeviceEventConsumer) Start() error { func (s *OutputSendToDeviceEventConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, s.ctx, s.jetstream, s.topic, s.durable, 1,
nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
) )
} }
func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
userID := msg.Header.Get(jetstream.UserID) userID := msg.Header.Get(jetstream.UserID)
_, domain, err := gomatrixserverlib.SplitID('@', userID) _, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {

View file

@ -64,12 +64,13 @@ func NewOutputTypingEventConsumer(
// Start consuming typing events. // Start consuming typing events.
func (s *OutputTypingEventConsumer) Start() error { func (s *OutputTypingEventConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, s.ctx, s.jetstream, s.topic, s.durable, 1,
nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
) )
} }
func (s *OutputTypingEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *OutputTypingEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
roomID := msg.Header.Get(jetstream.RoomID) roomID := msg.Header.Get(jetstream.RoomID)
userID := msg.Header.Get(jetstream.UserID) userID := msg.Header.Get(jetstream.UserID)
typing, err := strconv.ParseBool(msg.Header.Get("typing")) typing, err := strconv.ParseBool(msg.Header.Get("typing"))

View file

@ -67,8 +67,8 @@ func NewOutputNotificationDataConsumer(
// Start starts consumption. // Start starts consumption.
func (s *OutputNotificationDataConsumer) Start() error { func (s *OutputNotificationDataConsumer) Start() error {
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, s.ctx, s.jetstream, s.topic, s.durable, 1,
nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
) )
} }
@ -76,7 +76,8 @@ func (s *OutputNotificationDataConsumer) Start() error {
// the push server. It is not safe for this function to be called from // the push server. It is not safe for this function to be called from
// multiple goroutines, or else the sync stream position may race and // multiple goroutines, or else the sync stream position may race and
// be incorrectly calculated. // be incorrectly calculated.
func (s *OutputNotificationDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *OutputNotificationDataConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
userID := string(msg.Header.Get(jetstream.UserID)) userID := string(msg.Header.Get(jetstream.UserID))
// Parse out the event JSON // Parse out the event JSON

View file

@ -18,14 +18,16 @@ import (
"context" "context"
"strings" "strings"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
keyapi "github.com/matrix-org/dendrite/keyserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api"
keytypes "github.com/matrix-org/dendrite/keyserver/types" keytypes "github.com/matrix-org/dendrite/keyserver/types"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
// DeviceOTKCounts adds one-time key counts to the /sync response // DeviceOTKCounts adds one-time key counts to the /sync response
@ -125,7 +127,7 @@ func DeviceListCatchup(
"from": offset, "from": offset,
"to": toOffset, "to": toOffset,
"response_offset": queryRes.Offset, "response_offset": queryRes.Offset,
}).Debugf("QueryKeyChanges request result: %+v", res.DeviceLists) }).Tracef("QueryKeyChanges request result: %+v", res.DeviceLists)
return types.StreamPosition(queryRes.Offset), hasNew, nil return types.StreamPosition(queryRes.Offset), hasNew, nil
} }
@ -277,6 +279,10 @@ func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID strin
// it's enough to know that we have our member event here, don't need to check membership content // it's enough to know that we have our member event here, don't need to check membership content
// as it's implied by being in the respective section of the sync response. // as it's implied by being in the respective section of the sync response.
if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID { if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID {
// ignore e.g. join -> join changes
if gjson.GetBytes(ev.Unsigned, "prev_content.membership").Str == gjson.GetBytes(ev.Content, "membership").Str {
continue
}
return true return true
} }
} }

View file

@ -352,6 +352,8 @@ func (r *messagesReq) retrieveEvents() (
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"duration": time.Since(startTime), "duration": time.Since(startTime),
"room_id": r.roomID, "room_id": r.roomID,
"events_before": len(events),
"events_after": len(filteredEvents),
}).Debug("applied history visibility (messages)") }).Debug("applied history visibility (messages)")
return gomatrixserverlib.HeaderedToClientEvents(filteredEvents, gomatrixserverlib.FormatAll), start, end, err return gomatrixserverlib.HeaderedToClientEvents(filteredEvents, gomatrixserverlib.FormatAll), start, end, err
} }
@ -513,6 +515,9 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]
// Store the events in the database, while marking them as unfit to show // Store the events in the database, while marking them as unfit to show
// up in responses to sync requests. // up in responses to sync requests.
if res.HistoryVisibility == "" {
res.HistoryVisibility = gomatrixserverlib.HistoryVisibilityShared
}
for i := range res.Events { for i := range res.Events {
_, err = r.db.WriteEvent( _, err = r.db.WriteEvent(
context.Background(), context.Background(),
@ -521,7 +526,7 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]
[]string{}, []string{},
[]string{}, []string{},
nil, true, nil, true,
gomatrixserverlib.HistoryVisibilityShared, res.HistoryVisibility,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -534,6 +539,9 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]
// last `limit` events // last `limit` events
events = events[len(events)-limit:] events = events[len(events)-limit:]
} }
for _, ev := range events {
ev.Visibility = res.HistoryVisibility
}
return events, nil return events, nil
} }

View file

@ -19,10 +19,11 @@ import (
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
) )
type Database interface { type Database interface {

View file

@ -41,6 +41,8 @@ CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
-- The event content JSON. -- The event content JSON.
content TEXT NOT NULL content TEXT NOT NULL
); );
CREATE INDEX IF NOT EXISTS syncapi_send_to_device_user_id_device_id_idx ON syncapi_send_to_device(user_id, device_id);
` `
const insertSendToDeviceMessageSQL = ` const insertSendToDeviceMessageSQL = `

View file

@ -20,15 +20,18 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/tidwall/gjson"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
) )
// Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite // Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite
@ -683,7 +686,7 @@ func (d *Database) GetStateDeltas(
ctx context.Context, device *userapi.Device, ctx context.Context, device *userapi.Device,
r types.Range, userID string, r types.Range, userID string,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
) ([]types.StateDelta, []string, error) { ) (deltas []types.StateDelta, joinedRoomsIDs []string, err error) {
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
// - Get membership list changes for this user in this sync response // - Get membership list changes for this user in this sync response
// - For each room which has membership list changes: // - For each room which has membership list changes:
@ -718,8 +721,6 @@ func (d *Database) GetStateDeltas(
} }
} }
var deltas []types.StateDelta
// get all the state events ever (i.e. for all available rooms) between these two positions // get all the state events ever (i.e. for all available rooms) between these two positions
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs) stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs)
if err != nil { if err != nil {
@ -767,15 +768,11 @@ func (d *Database) GetStateDeltas(
} }
// handle newly joined rooms and non-joined rooms // handle newly joined rooms and non-joined rooms
newlyJoinedRooms := make(map[string]bool, len(state))
for roomID, stateStreamEvents := range state { for roomID, stateStreamEvents := range state {
for _, ev := range stateStreamEvents { for _, ev := range stateStreamEvents {
// TODO: Currently this will incorrectly add rooms which were ALREADY joined but they sent another no-op join event. if membership, prevMembership := getMembershipFromEvent(ev.Event, userID); membership != "" {
// We should be checking if the user was already joined at fromPos and not proceed if so. As a result of this, if membership == gomatrixserverlib.Join && prevMembership != membership {
// dupe join events will result in the entire room state coming down to the client again. This is added in
// the 'state' part of the response though, so is transparent modulo bandwidth concerns as it is not added to
// the timeline.
if membership := getMembershipFromEvent(ev.Event, userID); membership != "" {
if membership == gomatrixserverlib.Join {
// send full room state down instead of a delta // send full room state down instead of a delta
var s []types.StreamEvent var s []types.StreamEvent
s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter) s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter)
@ -786,6 +783,7 @@ func (d *Database) GetStateDeltas(
return nil, nil, err return nil, nil, err
} }
state[roomID] = s state[roomID] = s
newlyJoinedRooms[roomID] = true
continue // we'll add this room in when we do joined rooms continue // we'll add this room in when we do joined rooms
} }
@ -806,6 +804,7 @@ func (d *Database) GetStateDeltas(
Membership: gomatrixserverlib.Join, Membership: gomatrixserverlib.Join,
StateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]), StateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]),
RoomID: joinedRoomID, RoomID: joinedRoomID,
NewlyJoined: newlyJoinedRooms[joinedRoomID],
}) })
} }
@ -892,7 +891,7 @@ func (d *Database) GetStateDeltasForFullStateSync(
for roomID, stateStreamEvents := range state { for roomID, stateStreamEvents := range state {
for _, ev := range stateStreamEvents { for _, ev := range stateStreamEvents {
if membership := getMembershipFromEvent(ev.Event, userID); membership != "" { if membership, _ := getMembershipFromEvent(ev.Event, userID); membership != "" {
if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above. if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above.
deltas[roomID] = types.StateDelta{ deltas[roomID] = types.StateDelta{
Membership: membership, Membership: membership,
@ -1003,15 +1002,16 @@ func (d *Database) CleanSendToDeviceUpdates(
// getMembershipFromEvent returns the value of content.membership iff the event is a state event // getMembershipFromEvent returns the value of content.membership iff the event is a state event
// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned. // with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned.
func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) string { func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) (string, string) {
if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) { if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) {
return "" return "", ""
} }
membership, err := ev.Membership() membership, err := ev.Membership()
if err != nil { if err != nil {
return "" return "", ""
} }
return membership prevMembership := gjson.GetBytes(ev.Unsigned(), "prev_content.membership").Str
return membership, prevMembership
} }
// StoreReceipt stores user receipts // StoreReceipt stores user receipts

View file

@ -39,6 +39,8 @@ CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
-- The event content JSON. -- The event content JSON.
content TEXT NOT NULL content TEXT NOT NULL
); );
CREATE INDEX IF NOT EXISTS syncapi_send_to_device_user_id_device_id_idx ON syncapi_send_to_device(user_id, device_id);
` `
const insertSendToDeviceMessageSQL = ` const insertSendToDeviceMessageSQL = `

View file

@ -178,24 +178,24 @@ func (p *PDUStreamProvider) IncrementalSync(
var err error var err error
var stateDeltas []types.StateDelta var stateDeltas []types.StateDelta
var joinedRooms []string var syncJoinedRooms []string
stateFilter := req.Filter.Room.State stateFilter := req.Filter.Room.State
eventFilter := req.Filter.Room.Timeline eventFilter := req.Filter.Room.Timeline
if req.WantFullState { if req.WantFullState {
if stateDeltas, joinedRooms, err = p.DB.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { if stateDeltas, syncJoinedRooms, err = p.DB.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed") req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed")
return return
} }
} else { } else {
if stateDeltas, joinedRooms, err = p.DB.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { if stateDeltas, syncJoinedRooms, err = p.DB.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
req.Log.WithError(err).Error("p.DB.GetStateDeltas failed") req.Log.WithError(err).Error("p.DB.GetStateDeltas failed")
return return
} }
} }
for _, roomID := range joinedRooms { for _, roomID := range syncJoinedRooms {
req.Rooms[roomID] = gomatrixserverlib.Join req.Rooms[roomID] = gomatrixserverlib.Join
} }
@ -209,11 +209,27 @@ func (p *PDUStreamProvider) IncrementalSync(
newPos = from newPos = from
for _, delta := range stateDeltas { for _, delta := range stateDeltas {
newRange := r
// If this room was joined in this sync, try to fetch
// as much timeline events as allowed by the filter.
if delta.NewlyJoined {
// Reverse the range, so we get the most recent first.
// This will be limited by the eventFilter.
newRange = types.Range{
From: r.To,
To: 0,
Backwards: true,
}
}
var pos types.StreamPosition var pos types.StreamPosition
if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, &eventFilter, &stateFilter, req.Response); err != nil { if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, newRange, delta, &eventFilter, &stateFilter, req.Response); err != nil {
req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed") req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed")
return to return to
} }
// Reset the position, as it is only for the special case of newly joined rooms
if delta.NewlyJoined {
pos = newRange.From
}
switch { switch {
case r.Backwards && pos < newPos: case r.Backwards && pos < newPos:
fallthrough fallthrough
@ -287,7 +303,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
if stateFilter.LazyLoadMembers { if stateFilter.LazyLoadMembers {
delta.StateEvents, err = p.lazyLoadMembers( delta.StateEvents, err = p.lazyLoadMembers(
ctx, delta.RoomID, true, limited, stateFilter.IncludeRedundantMembers, ctx, delta.RoomID, true, limited, stateFilter,
device, recentEvents, delta.StateEvents, device, recentEvents, delta.StateEvents,
) )
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
@ -309,12 +325,12 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
logrus.WithError(err).Error("unable to apply history visibility filter") logrus.WithError(err).Error("unable to apply history visibility filter")
} }
if len(events) > 0 {
updateLatestPosition(events[len(events)-1].EventID())
}
if len(delta.StateEvents) > 0 { if len(delta.StateEvents) > 0 {
updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID()) updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID())
} }
if len(events) > 0 {
updateLatestPosition(events[len(events)-1].EventID())
}
switch delta.Membership { switch delta.Membership {
case gomatrixserverlib.Join: case gomatrixserverlib.Join:
@ -387,6 +403,8 @@ func applyHistoryVisibilityFilter(
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"duration": time.Since(startTime), "duration": time.Since(startTime),
"room_id": roomID, "room_id": roomID,
"before": len(recentEvents),
"after": len(events),
}).Debug("applied history visibility (sync)") }).Debug("applied history visibility (sync)")
return events, nil return events, nil
} }
@ -514,7 +532,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
return nil, err return nil, err
} }
stateEvents, err = p.lazyLoadMembers(ctx, roomID, stateEvents, err = p.lazyLoadMembers(ctx, roomID,
false, limited, stateFilter.IncludeRedundantMembers, false, limited, stateFilter,
device, recentEvents, stateEvents, device, recentEvents, stateEvents,
) )
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
@ -533,7 +551,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
func (p *PDUStreamProvider) lazyLoadMembers( func (p *PDUStreamProvider) lazyLoadMembers(
ctx context.Context, roomID string, ctx context.Context, roomID string,
incremental, limited, includeRedundant bool, incremental, limited bool, stateFilter *gomatrixserverlib.StateFilter,
device *userapi.Device, device *userapi.Device,
timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent, timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent,
) ([]*gomatrixserverlib.HeaderedEvent, error) { ) ([]*gomatrixserverlib.HeaderedEvent, error) {
@ -563,7 +581,7 @@ func (p *PDUStreamProvider) lazyLoadMembers(
stateKey := *event.StateKey() stateKey := *event.StateKey()
if _, ok := timelineUsers[stateKey]; ok || isGappedIncremental { if _, ok := timelineUsers[stateKey]; ok || isGappedIncremental {
newStateEvents = append(newStateEvents, event) newStateEvents = append(newStateEvents, event)
if !includeRedundant { if !stateFilter.IncludeRedundantMembers {
p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, stateKey, event.EventID()) p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, stateKey, event.EventID())
} }
delete(timelineUsers, stateKey) delete(timelineUsers, stateKey)
@ -578,6 +596,7 @@ func (p *PDUStreamProvider) lazyLoadMembers(
} }
// Query missing membership events // Query missing membership events
filter := gomatrixserverlib.DefaultStateFilter() filter := gomatrixserverlib.DefaultStateFilter()
filter.Limit = stateFilter.Limit
filter.Senders = &wantUsers filter.Senders = &wantUsers
filter.Types = &[]string{gomatrixserverlib.MRoomMember} filter.Types = &[]string{gomatrixserverlib.MRoomMember}
memberships, err := p.DB.GetStateEventsForRoom(ctx, roomID, &filter) memberships, err := p.DB.GetStateEventsForRoom(ctx, roomID, &filter)

View file

@ -19,9 +19,11 @@ import (
"encoding/json" "encoding/json"
"sync" "sync"
"github.com/matrix-org/gomatrixserverlib"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
) )
type PresenceStreamProvider struct { type PresenceStreamProvider struct {
@ -175,6 +177,10 @@ func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID strin
// it's enough to know that we have our member event here, don't need to check membership content // it's enough to know that we have our member event here, don't need to check membership content
// as it's implied by being in the respective section of the sync response. // as it's implied by being in the respective section of the sync response.
if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID { if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID {
// ignore e.g. join -> join changes
if gjson.GetBytes(ev.Unsigned, "prev_content.membership").Str == gjson.GetBytes(ev.Content, "membership").Str {
continue
}
return true return true
} }
} }

View file

@ -23,12 +23,13 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
) )
const defaultSyncTimeout = time.Duration(0) const defaultSyncTimeout = time.Duration(0)
@ -46,15 +47,9 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
return nil, err return nil, err
} }
} }
// TODO: read from stored filters too
// Create a default filter and apply a stored filter on top of it (if specified)
filter := gomatrixserverlib.DefaultFilter() filter := gomatrixserverlib.DefaultFilter()
if since.IsEmpty() {
// Send as much account data down for complete syncs as possible
// by default, otherwise clients do weird things while waiting
// for the rest of the data to trickle down.
filter.AccountData.Limit = math.MaxInt32
filter.Room.AccountData.Limit = math.MaxInt32
}
filterQuery := req.URL.Query().Get("filter") filterQuery := req.URL.Query().Get("filter")
if filterQuery != "" { if filterQuery != "" {
if filterQuery[0] == '{' { if filterQuery[0] == '{' {
@ -76,6 +71,17 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
} }
} }
// A loaded filter might have overwritten these values,
// so set them after loading the filter.
if since.IsEmpty() {
// Send as much account data down for complete syncs as possible
// by default, otherwise clients do weird things while waiting
// for the rest of the data to trickle down.
filter.AccountData.Limit = math.MaxInt32
filter.Room.AccountData.Limit = math.MaxInt32
filter.Room.State.Limit = math.MaxInt32
}
logger := util.GetLogger(req.Context()).WithFields(logrus.Fields{ logger := util.GetLogger(req.Context()).WithFields(logrus.Fields{
"user_id": device.UserID, "user_id": device.UserID,
"device_id": device.ID, "device_id": device.ID,

View file

@ -298,8 +298,8 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
return giveup() return giveup()
case <-userStreamListener.GetNotifyChannel(syncReq.Since): case <-userStreamListener.GetNotifyChannel(syncReq.Since):
syncReq.Log.Debugln("Responding to sync after wake-up")
currentPos.ApplyUpdates(userStreamListener.GetSyncPosition()) currentPos.ApplyUpdates(userStreamListener.GetSyncPosition())
syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync after wake-up")
} }
} else { } else {
syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync immediately") syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync immediately")

View file

@ -154,8 +154,12 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
wantJoinedRooms: []string{room.ID}, wantJoinedRooms: []string{room.ID},
}, },
} }
// TODO: find a better way
time.Sleep(500 * time.Millisecond) syncUntil(t, base, alice.AccessToken, false, func(syncBody string) bool {
// wait for the last sent eventID to come down sync
path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, room.Events()[len(room.Events())-1].EventID())
return gjson.Get(syncBody, path).Exists()
})
for _, tc := range testCases { for _, tc := range testCases {
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -191,6 +195,7 @@ func TestSyncAPICreateRoomSyncEarly(t *testing.T) {
} }
func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) { func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
t.Skip("Skipped, possibly fixed")
user := test.NewUser(t) user := test.NewUser(t)
room := test.NewRoom(t, user) room := test.NewRoom(t, user)
alice := userapi.Device{ alice := userapi.Device{
@ -343,6 +348,13 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
// create the users // create the users
alice := test.NewUser(t) alice := test.NewUser(t)
aliceDev := userapi.Device{
ID: "ALICEID",
UserID: alice.ID,
AccessToken: "ALICE_BEARER_TOKEN",
DisplayName: "ALICE",
}
bob := test.NewUser(t) bob := test.NewUser(t)
bobDev := userapi.Device{ bobDev := userapi.Device{
@ -409,7 +421,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
rsAPI := roomserver.NewInternalAPI(base) rsAPI := roomserver.NewInternalAPI(base)
rsAPI.SetFederationAPI(nil, nil) rsAPI.SetFederationAPI(nil, nil)
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{bobDev}}, rsAPI, &syncKeyAPI{}) AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, &syncKeyAPI{})
for _, tc := range testCases { for _, tc := range testCases {
testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType) testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType)
@ -418,12 +430,18 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
room := test.NewRoom(t, alice, test.RoomHistoryVisibility(tc.historyVisibility)) room := test.NewRoom(t, alice, test.RoomHistoryVisibility(tc.historyVisibility))
// send the events/messages to NATS to create the rooms // send the events/messages to NATS to create the rooms
beforeJoinEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("Before invite in a %s room", tc.historyVisibility)}) beforeJoinBody := fmt.Sprintf("Before invite in a %s room", tc.historyVisibility)
beforeJoinEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": beforeJoinBody})
eventsToSend := append(room.Events(), beforeJoinEv) eventsToSend := append(room.Events(), beforeJoinEv)
if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil { if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err) t.Fatalf("failed to send events: %v", err)
} }
time.Sleep(100 * time.Millisecond) // TODO: find a better way syncUntil(t, base, aliceDev.AccessToken, false,
func(syncBody string) bool {
path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(content.body=="%s")`, room.ID, beforeJoinBody)
return gjson.Get(syncBody, path).Exists()
},
)
// There is only one event, we expect only to be able to see this, if the room is world_readable // There is only one event, we expect only to be able to see this, if the room is world_readable
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -449,14 +467,20 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
inviteEv := room.CreateAndInsert(t, alice, "m.room.member", map[string]interface{}{"membership": "invite"}, test.WithStateKey(bob.ID)) inviteEv := room.CreateAndInsert(t, alice, "m.room.member", map[string]interface{}{"membership": "invite"}, test.WithStateKey(bob.ID))
afterInviteEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After invite in a %s room", tc.historyVisibility)}) afterInviteEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After invite in a %s room", tc.historyVisibility)})
joinEv := room.CreateAndInsert(t, bob, "m.room.member", map[string]interface{}{"membership": "join"}, test.WithStateKey(bob.ID)) joinEv := room.CreateAndInsert(t, bob, "m.room.member", map[string]interface{}{"membership": "join"}, test.WithStateKey(bob.ID))
msgEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After join in a %s room", tc.historyVisibility)}) afterJoinBody := fmt.Sprintf("After join in a %s room", tc.historyVisibility)
msgEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": afterJoinBody})
eventsToSend = append([]*gomatrixserverlib.HeaderedEvent{}, inviteEv, afterInviteEv, joinEv, msgEv) eventsToSend = append([]*gomatrixserverlib.HeaderedEvent{}, inviteEv, afterInviteEv, joinEv, msgEv)
if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil { if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err) t.Fatalf("failed to send events: %v", err)
} }
time.Sleep(100 * time.Millisecond) // TODO: find a better way syncUntil(t, base, aliceDev.AccessToken, false,
func(syncBody string) bool {
path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(content.body=="%s")`, room.ID, afterJoinBody)
return gjson.Get(syncBody, path).Exists()
},
)
// Verify the messages after/before invite are visible or not // Verify the messages after/before invite are visible or not
w = httptest.NewRecorder() w = httptest.NewRecorder()
@ -511,8 +535,8 @@ func testSendToDevice(t *testing.T, dbType test.DBType) {
AccountType: userapi.AccountTypeUser, AccountType: userapi.AccountTypeUser,
} }
base, close := testrig.CreateBaseDendrite(t, dbType) base, baseClose := testrig.CreateBaseDendrite(t, dbType)
defer close() defer baseClose()
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
@ -607,7 +631,14 @@ func testSendToDevice(t *testing.T, dbType test.DBType) {
t.Fatalf("unable to send to device message: %v", err) t.Fatalf("unable to send to device message: %v", err)
} }
} }
time.Sleep((time.Millisecond * 15) * time.Duration(tc.sendMessagesCount)) // wait a bit, so the messages can be processed
syncUntil(t, base, alice.AccessToken,
len(tc.want) == 0,
func(body string) bool {
return gjson.Get(body, fmt.Sprintf(`to_device.events.#(content.dummy=="message %d")`, msgCounter)).Exists()
},
)
// Execute a /sync request, recording the response // Execute a /sync request, recording the response
w := httptest.NewRecorder() w := httptest.NewRecorder()
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{
@ -630,6 +661,42 @@ func testSendToDevice(t *testing.T, dbType test.DBType) {
} }
} }
func syncUntil(t *testing.T,
base *base.BaseDendrite, accessToken string,
skip bool,
checkFunc func(syncBody string) bool,
) {
if checkFunc == nil {
t.Fatalf("No checkFunc defined")
}
if skip {
return
}
// loop on /sync until we receive the last send message or timeout after 5 seconds, since we don't know if the message made it
// to the syncAPI when hitting /sync
done := make(chan bool)
defer close(done)
go func() {
for {
w := httptest.NewRecorder()
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{
"access_token": accessToken,
"timeout": "1000",
})))
if checkFunc(w.Body.String()) {
done <- true
return
}
}
}()
select {
case <-done:
case <-time.After(time.Second * 5):
t.Fatalf("Timed out waiting for messages")
}
}
func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input ...*gomatrixserverlib.HeaderedEvent) []*nats.Msg { func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input ...*gomatrixserverlib.HeaderedEvent) []*nats.Msg {
result := make([]*nats.Msg, len(input)) result := make([]*nats.Msg, len(input))
for i, ev := range input { for i, ev := range input {

View file

@ -37,6 +37,7 @@ var (
type StateDelta struct { type StateDelta struct {
RoomID string RoomID string
StateEvents []*gomatrixserverlib.HeaderedEvent StateEvents []*gomatrixserverlib.HeaderedEvent
NewlyJoined bool
Membership string Membership string
// The PDU stream position of the latest membership event for this user, if applicable. // The PDU stream position of the latest membership event for this user, if applicable.
// Can be 0 if there is no membership event in this delta. // Can be 0 if there is no membership event in this delta.

View file

@ -144,7 +144,6 @@ Server correctly handles incoming m.device_list_update
If remote user leaves room, changes device and rejoins we see update in sync If remote user leaves room, changes device and rejoins we see update in sync
If remote user leaves room, changes device and rejoins we see update in /keys/changes If remote user leaves room, changes device and rejoins we see update in /keys/changes
If remote user leaves room we no longer receive device updates If remote user leaves room we no longer receive device updates
If a device list update goes missing, the server resyncs on the next one
Server correctly resyncs when client query keys and there is no remote cache Server correctly resyncs when client query keys and there is no remote cache
Server correctly resyncs when server leaves and rejoins a room Server correctly resyncs when server leaves and rejoins a room
Device list doesn't change if remote server is down Device list doesn't change if remote server is down
@ -633,7 +632,6 @@ Test that rejected pushers are removed.
Trying to add push rule with no scope fails with 400 Trying to add push rule with no scope fails with 400
Trying to add push rule with invalid scope fails with 400 Trying to add push rule with invalid scope fails with 400
Forward extremities remain so even after the next events are populated as outliers Forward extremities remain so even after the next events are populated as outliers
If a device list update goes missing, the server resyncs on the next one
uploading self-signing key notifies over federation uploading self-signing key notifies over federation
uploading signed devices gets propagated over federation uploading signed devices gets propagated over federation
Device list doesn't change if remote server is down Device list doesn't change if remote server is down

View file

@ -68,7 +68,7 @@ func ListenAndServe(t *testing.T, router http.Handler, withTLS bool) (apiURL str
if withTLS { if withTLS {
certFile := filepath.Join(t.TempDir(), "dendrite.cert") certFile := filepath.Join(t.TempDir(), "dendrite.cert")
keyFile := filepath.Join(t.TempDir(), "dendrite.key") keyFile := filepath.Join(t.TempDir(), "dendrite.key")
err = NewTLSKey(keyFile, certFile) err = NewTLSKey(keyFile, certFile, 1024)
if err != nil { if err != nil {
t.Errorf("failed to make TLS key: %s", err) t.Errorf("failed to make TLS key: %s", err)
return return

View file

@ -15,6 +15,7 @@
package test package test
import ( import (
"crypto/ed25519"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
@ -44,6 +45,10 @@ func NewMatrixKey(matrixKeyPath string) (err error) {
if err != nil { if err != nil {
return err return err
} }
return SaveMatrixKey(matrixKeyPath, data[3:])
}
func SaveMatrixKey(matrixKeyPath string, data ed25519.PrivateKey) error {
keyOut, err := os.OpenFile(matrixKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) keyOut, err := os.OpenFile(matrixKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil { if err != nil {
return err return err
@ -62,15 +67,15 @@ func NewMatrixKey(matrixKeyPath string) (err error) {
Headers: map[string]string{ Headers: map[string]string{
"Key-ID": fmt.Sprintf("ed25519:%s", keyID[:6]), "Key-ID": fmt.Sprintf("ed25519:%s", keyID[:6]),
}, },
Bytes: data[3:], Bytes: data,
}) })
return err return err
} }
const certificateDuration = time.Hour * 24 * 365 * 10 const certificateDuration = time.Hour * 24 * 365 * 10
func generateTLSTemplate(dnsNames []string) (*rsa.PrivateKey, *x509.Certificate, error) { func generateTLSTemplate(dnsNames []string, bitSize int) (*rsa.PrivateKey, *x509.Certificate, error) {
priv, err := rsa.GenerateKey(rand.Reader, 4096) priv, err := rsa.GenerateKey(rand.Reader, bitSize)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -118,8 +123,8 @@ func writePrivateKey(tlsKeyPath string, priv *rsa.PrivateKey) error {
} }
// NewTLSKey generates a new RSA TLS key and certificate and writes it to a file. // NewTLSKey generates a new RSA TLS key and certificate and writes it to a file.
func NewTLSKey(tlsKeyPath, tlsCertPath string) error { func NewTLSKey(tlsKeyPath, tlsCertPath string, keySize int) error {
priv, template, err := generateTLSTemplate(nil) priv, template, err := generateTLSTemplate(nil, keySize)
if err != nil { if err != nil {
return err return err
} }
@ -136,8 +141,8 @@ func NewTLSKey(tlsKeyPath, tlsCertPath string) error {
return writePrivateKey(tlsKeyPath, priv) return writePrivateKey(tlsKeyPath, priv)
} }
func NewTLSKeyWithAuthority(serverName, tlsKeyPath, tlsCertPath, authorityKeyPath, authorityCertPath string) error { func NewTLSKeyWithAuthority(serverName, tlsKeyPath, tlsCertPath, authorityKeyPath, authorityCertPath string, keySize int) error {
priv, template, err := generateTLSTemplate([]string{serverName}) priv, template, err := generateTLSTemplate([]string{serverName}, keySize)
if err != nil { if err != nil {
return err return err
} }

View file

@ -57,7 +57,6 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f
return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), func() { return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), func() {
// cleanup db files. This risks getting out of sync as we add more database strings :( // cleanup db files. This risks getting out of sync as we add more database strings :(
dbFiles := []config.DataSource{ dbFiles := []config.DataSource{
cfg.AppServiceAPI.Database.ConnectionString,
cfg.FederationAPI.Database.ConnectionString, cfg.FederationAPI.Database.ConnectionString,
cfg.KeyServer.Database.ConnectionString, cfg.KeyServer.Database.ConnectionString,
cfg.MSCs.Database.ConnectionString, cfg.MSCs.Database.ConnectionString,

View file

@ -56,15 +56,16 @@ func NewOutputReadUpdateConsumer(
func (s *OutputReadUpdateConsumer) Start() error { func (s *OutputReadUpdateConsumer) Start() error {
if err := jetstream.JetStreamConsumer( if err := jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, s.ctx, s.jetstream, s.topic, s.durable, 1,
nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
); err != nil { ); err != nil {
return err return err
} }
return nil return nil
} }
func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var read types.ReadUpdate var read types.ReadUpdate
if err := json.Unmarshal(msg.Data, &read); err != nil { if err := json.Unmarshal(msg.Data, &read); err != nil {
log.WithError(err).Error("userapi clientapi consumer: message parse failure") log.WithError(err).Error("userapi clientapi consumer: message parse failure")

View file

@ -7,6 +7,10 @@ import (
"strings" "strings"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/pushrules"
@ -20,9 +24,6 @@ import (
"github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/userapi/util" "github.com/matrix-org/dendrite/userapi/util"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
) )
type OutputStreamEventConsumer struct { type OutputStreamEventConsumer struct {
@ -64,15 +65,16 @@ func NewOutputStreamEventConsumer(
func (s *OutputStreamEventConsumer) Start() error { func (s *OutputStreamEventConsumer) Start() error {
if err := jetstream.JetStreamConsumer( if err := jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, s.ctx, s.jetstream, s.topic, s.durable, 1,
nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
); err != nil { ); err != nil {
return err return err
} }
return nil return nil
} }
func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var output types.StreamedEvent var output types.StreamedEvent
output.Event = &gomatrixserverlib.HeaderedEvent{} output.Event = &gomatrixserverlib.HeaderedEvent{}
if err := json.Unmarshal(msg.Data, &output); err != nil { if err := json.Unmarshal(msg.Data, &output); err != nil {
@ -529,7 +531,9 @@ func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomat
case "event_id_only": case "event_id_only":
req = pushgateway.NotifyRequest{ req = pushgateway.NotifyRequest{
Notification: pushgateway.Notification{ Notification: pushgateway.Notification{
Counts: &pushgateway.Counts{}, Counts: &pushgateway.Counts{
Unread: userNumUnreadNotifs,
},
Devices: devices, Devices: devices,
EventID: event.EventID(), EventID: event.EventID(),
RoomID: event.RoomID(), RoomID: event.RoomID(),

View file

@ -28,7 +28,6 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/pushrules"
@ -454,7 +453,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
// Create a dummy device for AS user // Create a dummy device for AS user
dev := api.Device{ dev := api.Device{
// Use AS dummy device ID // Use AS dummy device ID
ID: types.AppServiceDeviceID, ID: "AS_Device",
// AS dummy device has AS's token. // AS dummy device has AS's token.
AccessToken: token, AccessToken: token,
AppserviceID: appService.ID, AppserviceID: appService.ID,